diff --git a/.codex/skills/pto-gym-vpto-validation/SKILL.md b/.codex/skills/pto-gym-vpto-validation/SKILL.md index 0e1451a614..721df32b6f 100644 --- a/.codex/skills/pto-gym-vpto-validation/SKILL.md +++ b/.codex/skills/pto-gym-vpto-validation/SKILL.md @@ -1,6 +1,6 @@ --- name: pto-gym-vpto-validation -description: Run PTO-Gym validation from this PTOAS repo. Use when the user asks to run PTO-Gym SIM or board validation from the current source tree. Always force PTOAS onto the VPTO LLVM path instead of relying on the repo default backend. +description: Run bundled PTO-Gym exercise/validation cases. Use when the user explicitly asks for PTO-Gym, 3rdparty/PTO-Gym, or the PTO-Gym validation scripts. Always force PTOAS onto the VPTO path instead of relying on the repo default backend. --- # PTO-Gym VPTO Validation @@ -8,20 +8,20 @@ description: Run PTO-Gym validation from this PTOAS repo. Use when the user asks Use this skill when the task is specifically about: - running `3rdparty/PTO-Gym/examples/pto/scripts/run_host_vpto_validation.sh` - running `3rdparty/PTO-Gym/examples/pto/scripts/run_host_vpto_validation_parallel.sh` -- validating PTO-Gym cases from this PTOAS source tree +- validating bundled PTO-Gym exercise cases ## Required Rule When PTO-Gym is run from this repo, do not rely on the default PTOAS backend. Always pass PTOAS flags that force the VPTO LLVM path. -The current `ptoas` CLI spellings in this repo are `--pto-backend=vpto` and -`--vpto-emit-hivm-llvm`; do not shorten `--pto-backend` to `--backend`. +The current `ptoas` CLI spelling in this repo is `--pto-backend=vpto`; do not +shorten `--pto-backend` to `--backend`. 
Use: ```bash -PTOAS_FLAGS='--pto-backend=vpto --vpto-emit-hivm-llvm --pto-arch a5' +PTOAS_FLAGS='--pto-backend=vpto --pto-arch a5' ``` If the caller already provides `PTOAS_FLAGS`, make sure these options are still @@ -44,7 +44,7 @@ Typical simulator environment: source /home/mouliangyu/.local/ascend/beta.2/cann-9.0.0-beta.2/set_env.sh export ASCEND_HOME_PATH=/home/mouliangyu/.local/ascend/beta.2/cann-9.0.0-beta.2 export PTOAS_BIN=$PWD/build/tools/ptoas/ptoas -export PTOAS_FLAGS='--pto-backend=vpto --vpto-emit-hivm-llvm --pto-arch a5' +export PTOAS_FLAGS='--pto-backend=vpto --pto-arch a5' ``` ## Canonical Commands diff --git a/README.md b/README.md index 0d8399783f..b3a547ab04 100644 --- a/README.md +++ b/README.md @@ -206,6 +206,11 @@ ptoas test/lit/pto/empty_func.pto --pto-arch=a5 -o outputfile.cpp # 指定构建 Level(level3 会禁用 PlanMemory/InsertSync) ptoas test/lit/pto/empty_func.pto --pto-level=level3 -o outputfile.cpp +# 启用实验性 VMI -> VPTO 语义 pipeline +# 该模式要求 --pto-backend=vpto,或输入 IR 中带 pto.backend = "vpto" +# public function signature 不能直接暴露 !pto.vmi.* 类型 +ptoas test/lit/vmi/vmi_ptoas_cli_pipeline.pto --pto-arch=a5 --pto-backend=vpto --enable-vmi --emit-vpto -o - + # 查看当前 ptoas release 版本号 ptoas --version diff --git a/docs/designs/vmi-e2b-scale-broadcast-optimization.md b/docs/designs/vmi-e2b-scale-broadcast-optimization.md new file mode 100644 index 0000000000..9777162f3b --- /dev/null +++ b/docs/designs/vmi-e2b-scale-broadcast-optimization.md @@ -0,0 +1,992 @@ +# VMI E2B Scale Broadcast Optimization Study + +本文推演 VMI 是否能把 block quant 中的 scale broadcast 自动优化成 +`E2B_B16` load。结论是: + +```text +group_slot_load + group_broadcast 足以表达逻辑语义。 + +它不足以单独触发 E2B,因为 E2B 是某个 physical chunk layout 下的 +materialization,不是 dense logical broadcast 的直接 lowering。 + +如果后续 layout 已经由 consumer requirement 或 target-specific layout +optimization 选成 E2B-compatible 形态,vmi-to-vpto 可以把对应 chunk lower +成 E2B。 + +如果想从普通 dense quant IR 自动得到 CCE 的 DINTLV/E2B 形态,需要一个 +target-specific layout 
optimization/cost selection 阶段整体选择这套计划。 +``` + +## 1. Logical Quant Semantics + +`ComputeY1ToFP8` 的 surface VMI 语义应保持 dense quant: + +```text +for i in 0..255: + y[i] = fp8(x[i] * scale[i / 32]) +``` + +也就是 8 个 scale,每个覆盖 32 个 dense logical lanes: + +```text +s0 x32, s1 x32, ..., s7 x32 +``` + +对应 VMI 形态是: + +```text +%x = pto.vmi.load %x_base[%x_off] + : !pto.ptr<f16, ub> -> !pto.vmi.vreg<256xf16> + +%scale_slots = pto.vmi.group_slot_load %scale_base[%scale_off], %c1 + {num_groups = 8} + : !pto.ptr<bf16, ub> -> !pto.vmi.vreg<8xbf16> + +%scale = pto.vmi.group_broadcast %scale_slots {num_groups = 8} + : !pto.vmi.vreg<8xbf16> -> !pto.vmi.vreg<256xbf16> +``` + +This form is the canonical logical IR. The source scale should be BF16 payload, +not FP16, because the CCE implementation loads `uint16_t` values and later +reinterprets them as `vector_bf16`. + +`num_groups = 16` would express a different algorithm: + +```text +16 scale values, each covering 16 dense lanes +``` + +That is not equivalent unless the input memory redundantly stores +`s0, s0, s1, s1, ...`, which is not what the CCE kernel does. + +## 2. E2B_B16 Semantics + +`E2B_B16` is a VPTO load distribution mode. For a b16 destination register it +loads 8 source elements and expands each one to 16 consecutive destination +lanes: + +```text +dst[j] = src[floor(j / 16)] for j = 0..127 +``` + +The result is: + +```text +s0 x16, s1 x16, ..., s7 x16 +``` + +So `E2B_B16` does not directly materialize the dense VMI broadcast +`8 -> 256`. It materializes one 128-lane physical view that becomes useful only +after the x data and later f32 computation have been split into compatible +physical chunks. + +## 3. 
Why CCE Can Use E2B + +The CCE FP16 path uses a physical implementation shape like: + +```text +vlds(x0F16, x1F16, xHalf, stride, DINTLV_B16, POST_UPDATE) +vlds(scaleForMulFP16, scale_base, 0, E2B_B16) + +vcvt(x0_even_f32, x0F16, PART_EVEN) +vcvt(x0_odd_f32, x0F16, PART_ODD) +vcvt(x1_even_f32, x1F16, PART_EVEN) +vcvt(x1_odd_f32, x1F16, PART_ODD) + +vcvt(scale_f32, (vector_bf16 &)scaleForMulFP16, PART_EVEN) +``` + +`DINTLV_B16` splits the dense 256-element row into two 128-lane physical streams. +After each stream is converted from f16 to f32, the computation is effectively +four 64-lane f32 chunks: + +```text +x0 even part +x0 odd part +x1 even part +x1 odd part +``` + +For every one of those chunks, the needed scale pattern is: + +```text +s0 x8, s1 x8, ..., s7 x8 +``` + +`E2B_B16` produces: + +```text +s0 x16, s1 x16, ..., s7 x16 +``` + +Then `vcvt PART_EVEN` produces: + +```text +s0 x8, s1 x8, ..., s7 x8 +``` + +Because every scale value is duplicated in adjacent even/odd b16 positions, +`PART_EVEN` and `PART_ODD` would produce the same f32 scale chunk. The CCE code +computes the scale chunk once and reuses it for all four x chunks. + +## 4. What Is A Legal Automatic Optimization? + +The following rewrite is not legal as a standalone local rule: + +```text +group_slot_load + group_broadcast(8 -> 256) => E2B_B16 +``` + +It is invalid because the left side is a dense 256-lane logical value, while +`E2B_B16` produces a 128-lane physical value with a different lane repetition +count. + +A legal E2B lowering must be conditional on the assigned physical layout: + +```text +if the broadcasted scale value is required in physical chunks where each chunk +needs s0 x16 ... s7 x16 at b16 width, or s0 x8 ... s7 x8 after bf16->f32, +then that chunk may be materialized with E2B_B16. 
+``` + +In other words: + +```text +group_slot_load + group_broadcast + is the logical source pattern + +consumer-required or target-selected layout + determines whether any physical chunk is E2B-compatible + +vmi-to-vpto + lowers only those compatible chunks to E2B +``` + +`group_slot_load` alone cannot lower to E2B. A group-slot value has only group +slots as semantic lanes. `E2B_B16` already produces broadcasted physical lanes. +The `group_broadcast` use is required to justify reading those lanes. + +## 5. Layout Selection Boundary + +Deinterleaved layout must not be inferred only because E2B would be cheaper. +The selected layout must be explicit before `vmi-to-vpto`. That layout can come +from either side: + +```text +consumer requirement: + a later op requires a particular layout. + +producer natural layout: + the producing op has a declared, deterministic natural layout that is legal + for all of its uses. +``` + +`group_broadcast` is a materialization op, so it may define or participate in an +E2B-friendly natural layout when that layout is part of the declared layout +contract. That is still a layout-assignment decision, not a hidden +`vmi-to-vpto` peephole. Do not reuse `block_elems` as an ad-hoc broadcast split +knob; `block_elems` belongs to the dense deinterleaved layout contract and has +existing producer/consumer meanings. + +Baseline layout assignment may still choose conservative contiguous layouts even +when a target-specific fused implementation exists. + +Therefore this optimization has two valid implementation levels. + +### 5.1 Compatible-Layout Lowering Shortcut + +If some earlier layout pass has already assigned an E2B-compatible physical +layout, `vmi-to-vpto` may lower the scale chunk with `E2B_B16`. + +This is a local deterministic lowering. It does not discover the CCE plan by +itself. It only avoids a generic `vsldb + vselr` materialization when the +assigned layout has already made the required physical chunk shape explicit. 
+ +### 5.2 Producer Natural Layout + +For simple broadcasts, the producer itself may choose an E2B-friendly natural +layout when that layout satisfies every use. + +Example for b16, using an existing DINTLV-like element-parity layout: + +```text +logical 1 -> 32: + s0 x32 + +layout: + deinterleaved = 2, block_elems = 1 + +physical part 0: + s0 x16 + +physical part 1: + s0 x16 +``` + +The two physical parts can share one E2B materialization or use two identical +E2B materializations. This is a general layout choice for the broadcast result, +not a quant-specific graph rewrite. + +For a uniform `1 -> 32` or per-group `x32` broadcast, `deinterleaved = 2, +block_elems = 1` yields 16 lanes of the same group per physical part and is +closer to an even/odd `DINTLV_B16` data layout. + +For the MX quant scale: + +```text +logical 8 -> 256: + s0 x32, s1 x32, ..., s7 x32 + +layout: + deinterleaved = 2, block_elems = 1 + +physical part 0: + s0 x16, s1 x16, ..., s7 x16 + +physical part 1: + s0 x16, s1 x16, ..., s7 x16 +``` + +Each physical part is directly `E2B_B16`-compatible. +The implementation should run the E2B compatibility query over the assigned lane +mapping. It should not infer a new meaning for `block_elems`. + +### 5.3 Target-Specific Layout Optimization + +To automatically discover the complete CCE plan from canonical dense quant IR, +add an optional target-specific layout optimization before `vmi-to-vpto`. + +That pass may select a cheaper equivalent implementation for the whole quant +subgraph: + +```text +dense x load +f16/bf16 -> f32 conversion +scale group_slot_load + group_broadcast 8 -> 256 +scale bf16 -> f32 conversion +mul +fp32 -> fp8 conversion/store +``` + +The pass must rewrite or annotate the VMI layout-assigned IR so that +`vmi-to-vpto` no longer has to infer the plan from context. 
+ +Expected selected physical plan: + +```text +x load: + vlds DINTLV_B16 into two b16 streams + +scale load: + vlds E2B_B16 into one b16 stream + +scale conversion: + vcvt PART_EVEN into one 64-lane f32 stream + +mul: + reuse that scale f32 stream for the four x f32 chunks +``` + +This is an optimization, not a correctness requirement. If the optimizer does +not fire, the canonical dense VMI program still has a valid generic lowering. + +## 6. Candidate Match Preconditions + +A target-specific optimization may match the CCE-style scale pattern only under +strict conditions: + +```text +scale_slots: + pto.vmi.group_slot_load + num_groups = 8 + source_group_stride = 1 + source element width = 16 bits + semantic type is bf16 or a bitcastable ui16 payload later interpreted as bf16 + +scale broadcast: + pto.vmi.group_broadcast + same num_groups = 8 + dense logical result has 256 b16 lanes for this case + +scale conversion: + bf16 -> f32 + conversion has no rounding/exception behavior that distinguishes duplicated + even and odd source lanes + +x path: + dense logical row has 256 f16 or bf16 lanes + the target plan can legally compute the row as four 64-lane f32 chunks + +uses: + no user observes the intermediate dense scale layout in a way that prevents + rematerialization or chunk reuse +``` + +The optimization should reject or skip the pattern if any of these conditions are +not proven. + +## 7. Correctness Sketch + +Let the logical dense lane be `i`. + +The canonical VMI scale value is: + +```text +scale_dense[i] = s[floor(i / 32)] +``` + +The CCE physical decomposition maps each dense lane into one of four f32 chunks. 
+For a chunk-local f32 lane `k`: + +```text +dense lane = 4 * k + delta +delta in {0, 1, 2, 3} +``` + +Then: + +```text +floor((4 * k + delta) / 32) = floor(k / 8) +``` + +So every f32 chunk needs: + +```text +scale_chunk[k] = s[floor(k / 8)] +``` + +`E2B_B16` plus `vcvt PART_EVEN` gives: + +```text +e2b_b16[j] = s[floor(j / 16)] for j = 0..127 +scale_f32[k] = e2b_b16[2 * k] + = s[floor((2 * k) / 16)] + = s[floor(k / 8)] +``` + +That matches the required `scale_chunk[k]` for all four f32 chunks. + +## 8. Recommendation + +Prefer adding a target-agnostic VMI `group_broadcast_load` logical memory op if +we want to make this optimization robust and local. The op should mean: + +```text +load one source value per logical group, then broadcast that value to every lane +in the group. +``` + +It must not mean `E2B`. `E2B_B16` is only one possible lowering when the +assigned layout is compatible. + +The unfused logical IR remains valid: + +```text +group_slot_load + group_broadcast +``` + +but a canonicalization/layout-prep pass may fuse it to: + +```text +group_broadcast_load +``` + +when the group-slot value has no separate semantic users. + +Then implement E2B support in phases: + +```text +1. Ensure the example PTO uses the correct logical semantics: + bf16 scale, num_groups = 8, dense 8 -> 256 broadcast. + +2. Add group_broadcast_load as a logical VMI memory op, plus canonicalization + from group_slot_load + group_broadcast when legal. + +3. Add a compatible-layout lowering shortcut: + when layout assignment already exposes an E2B-compatible chunk, lower the + group_broadcast_load chunk with vlds E2B_B16. + +4. Add an optional target-specific quant layout optimization: + recognize the whole dense quant subgraph and select the DINTLV/E2B plan when + it is legal and profitable. +``` + +This keeps VMI logical semantics independent from physical layout, while still +leaving a clear path to recover the CCE optimization automatically. + +## 9. 
Generalized E2B Broadcast Optimization + +The scale case above is one instance of a broader rule: E2B is a physical +materialization primitive for a packet of repeated group slots. It is not tied +to MX quant, but its legality depends on the physical chunk layout and on the +load distribution's carrier element width. + +### 9.1 E2B As A Packet Primitive + +For the verified `B16` case: + +```text +E2B_B16 packet: + source slots per packet = 8 + destination lanes per packet = 128 b16 lanes + repeat per source slot = 16 b16 lanes + +dst[lane] = src[base_slot + floor(lane / 16)] +``` + +This can materialize a physical chunk that needs: + +```text +s0 x16, s1 x16, ..., s7 x16 +``` + +The optimization should reason in terms of physical chunks: + +```text +logical group_broadcast + source group slot for logical lane i = floor(i / logical_group_size) + +assigned physical layout + maps physical chunk lane l to logical lane i(l) + +E2B-compatible chunk + floor(i(l) / logical_group_size) = base_slot + floor(l / 16) +``` + +If this equality holds for a b16 physical chunk, the chunk can be loaded with +`E2B_B16` instead of materializing the broadcast with `vselr`. + +### 9.2 Direct 1 -> 16 + +A logical `1 -> 16` b16 broadcast is directly compatible with one E2B group: + +```text +s0 x16 +``` + +However, `E2B_B16` is naturally an 8-group packet: + +```text +s0 x16, s1 x16, ..., s7 x16 +``` + +So a single `1 -> 16` use may lower to E2B only under one of these conditions: + +```text +packed case: + the compiler can pack eight independent 1 -> 16 broadcasts into one E2B load. + +partial-live case: + only one 16-lane group is live, and the target semantics prove inactive E2B + groups do not require valid source memory or can be safely over-read. + +full-packet case: + the logical IR actually contains eight adjacent groups, even if the current + consumer observes only one group through a layout/mask. 
+``` + +If these conditions are not proven, `BRC_B16`, `vdup`, or the existing generic +broadcast lowering is safer than E2B. In particular, do not introduce an E2B +load that reads seven extra source values unless the memory safety rule is +explicit. + +### 9.3 1 -> 32 Via Deinterleaved Reuse + +A dense logical `1 -> 32` b16 broadcast does not fit one E2B group in a single +contiguous physical chunk: + +```text +logical: s0 x32 +E2B group: s0 x16 +``` + +It becomes E2B-compatible when the assigned physical layout splits those 32 +logical lanes into two 16-lane physical uses: + +```text +physical use A: s0 x16 +physical use B: s0 x16 +``` + +This split can use the existing DINTLV-like element-parity layout: + +```text +#pto.vmi.layout<deinterleaved = 2, block_elems = 1> +``` + +For logical lanes `0..31`, this maps: + +```text +even lanes 0,2,...,30 -> physical part 0 lanes 0..15 +odd lanes 1,3,...,31 -> physical part 1 lanes 0..15 +``` + +Because all 32 logical lanes carry the same `s0`, each part still sees +`s0 x16`. The lowering rule should check the resulting group index function, +not invent a new layout spelling. + +Then the compiler has two valid strategies: + +```text +reuse: + materialize one E2B group/chunk and map both physical uses to the same value. + +duplicate: + materialize the same E2B group twice if reuse would violate scheduling, + lifetime, or destructive-update constraints. +``` + +This is the mechanism behind the MX quant scale case: + +```text +dense logical scale: 8 groups, each x32 +physical f16/bf16 streams: each group appears as x16 per stream +``` + +The optimization is legal only if the 32 logical lanes are split by layout. It +is not legal as a direct E2B chunk load for a contiguous physical chunk that +genuinely needs `s0 x32` inside one chunk; that would require a separate +duplicate/interleave/concat materialization. 
+ +### 9.4 N -> N * 16 And N -> N * 32 + +For b16 group broadcasts with consecutive slots and unit source stride: + +```text +N -> N * 16 +``` + +can be lowered by E2B in packets of 8 groups when the physical chunk sees the +groups in E2B order: + +```text +for base_slot in 0, 8, 16, ... + load src[base_slot : base_slot + 8] with E2B_B16 +``` + +Tail packets require either a proven safe masked/partial E2B form or a generic +fallback. + +For: + +```text +N -> N * 32 +``` + +E2B is profitable when the assigned layout decomposes each 32-lane logical group +into two 16-lane physical uses. That assigned layout may be the +`group_broadcast` producer's natural layout, or it may be required by a +downstream consumer. The lowering then reuses or duplicates the corresponding +E2B materialization for those two uses. This rule extends to: + +```text +N -> N * (16 * F) +``` + +when the layout decomposes each logical group into `F` physical 16-lane uses. + +### 9.5 Type Generalization + +E2B is a carrier-width load distribution. For `E2B_B16`, the load itself is +valid for 16-bit carriers: + +```text +bf16 +f16 +ui16 / si16 payloads +other 16-bit bit patterns whose consumers preserve the intended interpretation +``` + +The optimization must keep type interpretation outside the load: + +```text +bf16 scale + extf to f32: + E2B_B16 may feed vcvt bf16 -> f32. + +f16 broadcast: + E2B_B16 may materialize repeated f16 lanes if the consumer expects f16. + +ui16 payload later bitcast to bf16: + E2B_B16 may load the ui16 carrier, but the later bitcast/interpretation must + remain explicit in VMI or in the selected lowering plan. +``` + +Do not infer a floating-point type from E2B itself. `E2B_B16` only says how UB +bytes are placed into b16 lanes. + +`E2B_B32` is the b32 member of the same distribution family. The VPTO verifier +accepts `E2B_B32`, the ISA docs list E2B for `b16` and `b32`, and CCE quant code +uses `E2B_B32` in FP32 paths. 
It follows the same 8-source-slot packet rule: + +```text +E2B_B16: 8 source slots * 16 lanes/slot = 128 b16 lanes +E2B_B32: 8 source slots * 8 lanes/slot = 64 b32 lanes +``` + +The implemented E2B broadcast optimization therefore supports: + +```text +b16 contiguous: logical 1 -> 16 +b16 deinterleaved=2: logical 1 -> 32 +b32 contiguous: logical 1 -> 8 +b32 deinterleaved=2: logical 1 -> 16 +``` + +There is no `E2B_B8` in the documented load distribution family, so b8 +broadcasts should use other distributions or generic materialization. + +### 9.6 Broadcast Generalization + +E2B can optimize `pto.vmi.group_broadcast` when all of these are true: + +```text +source: + group slots come from consecutive memory slots + source_group_stride = 1 + slot type matches the E2B carrier width + +broadcast: + each physical chunk needs a run-length pattern compatible with the E2B repeat + count for that carrier width + +layout: + the run-length pattern is visible in the assigned layout before vmi-to-vpto + +uses: + rematerializing or reusing the E2B packet does not change observable memory or + arithmetic semantics +``` + +E2B is generally not the right primitive for ordinary scalar `pto.vmi.broadcast` +unless the scalar value is already in memory as an E2B packet or the compiler can +pack several independent scalar broadcasts into one E2B load. For a scalar +stored once in memory and needed in every lane, `BRC_B16/B32`, `BRC_BLK`, or a +register `vdup` is usually the more direct representation. + +### 9.7 Implementation Shape + +The recommended implementation order is: + +```text +1. Keep VMI semantics canonical: + group_slot_load + group_broadcast is the desugared meaning. + +2. Optionally canonicalize to group_broadcast_load: + this keeps memory source and broadcast semantics in one local op. + +3. 
Add an E2B compatibility query over assigned physical chunks: + given source slots, result layout, carrier width, and live lanes, answer + whether a chunk's group-index function is E2B-shaped. + +4. Lower compatible chunks to E2B packets: + generate one E2B load per needed packet, or reuse an existing packet when + multiple physical uses require identical contents. + +5. Add a later target-specific layout optimizer: + it may choose layouts that expose E2B-compatible chunks, but only by + rewriting/annotating layout-assigned VMI before vmi-to-vpto. +``` + +The compatibility query should return a reason when it rejects a candidate: + +```text +non-unit source stride +non-consecutive group slots +unsupported carrier width +tail packet lacks safe partial E2B semantics +physical lane mapping is not E2B-shaped +extra source memory read would be unsafe +consumer observes a different dense layout +``` + +This keeps the optimization auditable and prevents E2B from becoming an implicit +layout-changing peephole. + +## 10. Recognition, Solidification, Propagation, Lowering + +This section describes how an implementation should carry the optimization from +canonical VMI to VPTO without making `vmi-to-vpto` rediscover hidden context. + +### 10.1 Recognize Information + +Run recognition after hard layout assignment, when every relevant value already +has an explicit layout. 
+ +Recognize the source shape: + +```text +%slots = pto.vmi.group_slot_load %base[%off], %stride {num_groups = G} +%bcast = pto.vmi.group_broadcast %slots {num_groups = G} +``` + +or the already-fused form: + +```text +%bcast = pto.vmi.group_broadcast_load %base[%off], %stride {num_groups = G} +``` + +Collect candidate facts: + +```text +source memory: + base pointer + offset + source_group_stride + element carrier width + memory element type + +logical broadcast: + num_groups = G + logical lanes = N + logical group size S = N / G + +assigned result layout: + physical arity + physical lanes per chunk + logical lane mapped to each physical lane + +uses: + whether the broadcast feeds elementwise ops, extf/truncf, stores, or multiple + independent consumers +``` + +Then compute an E2B packet plan per physical chunk. For `E2B_B16`, a physical +chunk is compatible when: + +```text +group_index_for_physical_lane(l) = base_slot + floor(l / 16) +``` + +for all live lanes in that chunk. + +Reject the candidate if: + +```text +source_group_stride != 1 +source slots are not consecutive +carrier width is unsupported +the assigned layout does not produce E2B-shaped chunks +tail/partial packet would read memory that is not proven valid +the group_slot_load has other non-rematerializable users +``` + +This recognition is an analysis step. It must not silently change layouts. + +### 10.2 Solidify Information + +`vmi-to-vpto` should not have to look at an arbitrary +`group_slot_load -> group_broadcast` use-def chain and decide to suppress one +load while replacing another op with E2B. The optimization pass must solidify +the decision in the layout-assigned IR. 
+ +The preferred solidification is a target-agnostic logical memory op: + +```text +%bcast = pto.vmi.group_broadcast_load %base[%off], %stride {num_groups = G} + : !pto.ptr<T, ub> -> !pto.vmi.vreg<NxT> +``` + +Semantic definition: + +```text +group_size = N / G +for logical lane i: + group = floor(i / group_size) + result[i] = base[off + group * stride] +``` + +This op is not target-specific and does not promise E2B. It is exactly the +fused logical form of: + +```text +%slots = pto.vmi.group_slot_load %base[%off], %stride {num_groups = G} +%bcast = pto.vmi.group_broadcast %slots {num_groups = G} +``` + +The fused op makes lowering local because the memory source, stride, group count, +result type, and assigned layout are all available on one op. A generic lowering +can still materialize it with `vsldb + vselr`; an optimized lowering may choose +`E2B_B16` for compatible physical chunks. + +The current implementation is intentionally narrower: because +`group_broadcast_load` does not yet have a generic `vsldb + vselr` lowering, +layout assignment fuses `group_slot_load + group_broadcast` only when the fused +op is already an E2B-compatible b16 candidate. Non-E2B shapes stay in the +unfused form and continue to use the existing `group_slot_load` plus +`group_broadcast` lowering path. + +Canonicalization rules: + +```text +group_slot_load + group_broadcast -> group_broadcast_load + when the group_slot_load has exactly that broadcast use, or when cloning the + load is legal and profitable for that use. + +group_broadcast_load -> group_slot_load + group_broadcast + remains a valid conceptual expansion for verification, documentation, and + generic fallback reasoning. +``` + +Solidification must preserve semantics for multi-use values: + +```text +if all uses consume only the broadcasted value: + replace with one shared group_broadcast_load. 
+ +if only one use can benefit from the fused form: + clone/rematerialize that use-site load as group_broadcast_load and keep the + original group_slot_load for other users. + +if the group_slot_load itself has semantic group-slot users: + do not delete it; add a separate group_broadcast_load only if the extra memory + read is legal or if load cloning is otherwise proven safe. +``` + +### 10.3 Propagate Information + +After solidification, propagation should use ordinary VMI layout rules whenever +possible: + +```text +elementwise ops: + preserve the assigned layout when operands agree. + +ensure_layout: + makes layout transitions explicit when one use needs E2B-compatible chunks and + another use needs a different layout. + +rematerialization: + may clone group_broadcast_load per use-site instead of forcing a single layout + for all consumers. +``` + +For casts, propagation may need a targeted rule. The important MX quant case is: + +```text +E2B_B16 gives: + s0 x16, s1 x16, ..., s7 x16 + +bf16 -> f32 PART_EVEN gives: + s0 x8, s1 x8, ..., s7 x8 +``` + +If multiple f32 physical chunks require that same `s0 x8 ... s7 x8` pattern, +the post-assignment plan may mark them as the same rematerialized value. The +lowerer can then generate one `vcvt PART_EVEN` and map several logical physical +chunks to the same VPTO value. + +This reuse fact must be derived from the assigned lane mapping and the E2B packet +plan. It must not rely on a later CSE pass accidentally proving the duplicate. + +### 10.4 Implement Lowering + +`vmi-to-vpto` should lower `group_broadcast_load` locally. It may choose E2B +only when the op's assigned layout and source facts produce an explicit +E2B-compatible packet plan. + +For each E2B packet: + +```text +1. compute the source pointer: + base + packet_base_slot + +2. emit: + pto.vlds {dist = "E2B_B16"} + +3. map the emitted VPTO value to the physical result chunk(s) recorded in the + group_broadcast_load packet plan. 
+``` + +For `1 -> 32` under `deinterleaved = 2, block_elems = 1`: + +```text +logical group: + s0 x32 + +physical part 0: + s0 x16 + +physical part 1: + s0 x16 + +lowering: + emit one E2B packet if reuse is legal, or two identical E2B packets if + scheduling/lifetime constraints require duplication. +``` + +For MX quant scale after bf16->f32: + +```text +1. emit E2B_B16 for the b16 scale packet. +2. emit vcvt PART_EVEN to produce the f32 packet. +3. map that f32 packet to every physical f32 chunk whose lane mapping requires + s0 x8, s1 x8, ..., s7 x8. +4. lower mulf normally using the assigned physical chunks. +``` + +### 10.5 Where Layout Choices Happen + +There are three levels of optimization: + +```text +level 0: no E2B + canonical group_broadcast lowers through generic vselr materialization. + +level 1: E2B for already-compatible layouts + recognition sees the assigned layout is E2B-shaped and solidifies an E2B + materialization. + +level 2: choose E2B-compatible layouts + an optional layout optimization changes/rematerializes layouts before + recognition, for example selecting deinterleaved=2/block_elems=1 for a + broadcast use when all consumers can accept that layout. +``` + +The full CCE-like optimization for `ComputeY1ToFP8` is level 2: + +```text +x path: + select DINTLV-compatible layout for the dense x load/cast path. + +scale path: + select an E2B-compatible broadcast materialization. + +compute path: + keep mul/trunc/store in the selected physical chunk layout or insert explicit + layout materialization where required. +``` + +### 10.6 Test Plan + +Add focused tests in phases: + +```text +positive: + bf16 group_slot_load stride=1 + group_broadcast 8->256 assigned to + deinterleaved=2/block_elems=1 lowers scale chunks with E2B_B16. + +positive: + f16 1->16 or packed 8*(1->16) lowers to E2B only when source memory safety is + proven by full packet or supported partial semantics. 
+ +positive: + 1->32 assigned to deinterleaved=2/block_elems=1 maps two physical uses to one + E2B packet or to two explicit duplicate packets. + +positive: + f32 1->8 lowers to E2B_B32, and f32 1->16 under deinterleaved=2/block_elems=1 + maps two physical uses to one E2B_B32 packet. + +negative: + source_group_stride != 1 falls back or diagnoses the E2B optimization. + +negative: + non-E2B-shaped assigned layout falls back to generic group_broadcast lowering. + +negative: + partial packet without proven safe memory read does not emit E2B. + +deferred: + E2B_B32 remains disabled until simulator/spec tests confirm the exact lane + mapping. NOTE(review): this conflicts with section 9.5 and the f32 positive case above, which treat E2B_B32 as supported; reconcile the two before relying on either. +``` diff --git a/docs/designs/vmi-implementation-manual.md b/docs/designs/vmi-implementation-manual.md new file mode 100644 index 0000000000..98cd3e9ee5 --- /dev/null +++ b/docs/designs/vmi-implementation-manual.md @@ -0,0 +1,4615 @@ +# VMI 实现手册 + +本文配套 `docs/designs/vmi-introduction.md` 和当前 VMI lowering 设计,回答 +“按什么顺序改哪些文件、每一步做到什么程度才算完成”。 + +本文不替代最终 ODS / C++ verifier / lit 测试。实现时如果发现本文和 ODS 或 verifier 冲突,以 +更精确的 verifier 约束为准,并同步刷新本文。 + +## 0. 当前仓库约束 + +当前仓库只有一个 MLIR dialect: + +```text +dialect name: pto +cpp namespace: ::mlir::pto +``` + +VPTO 低层 op/type 也在同一个 `pto` dialect 里,通过 `VPTOOps.td`、`VPTOTypeDefs.td` 等文件组织。 +因此第一版 VMI 不新建独立 dialect,采用同一 dialect 下的嵌套 mnemonic: + +```text +types: + !pto.vmi.vreg<...> + !pto.vmi.mask<...> + +attrs: + #pto.vmi.layout<...> + +ops: + pto.vmi.addf + pto.vmi.subf + pto.vmi.mulf + pto.vmi.ensure_layout +``` + +落地方式是:`PTO_Dialect` 仍是唯一 dialect,VMI 只是 `pto` dialect 内的一组 type/attr/op。 +如果后续要拆成真正独立的 `pto.vmi` dialect,必须先保证所有 pass、type converter、parser 测试 +和公开文档同步迁移;第一版不要做这个拆分。 + +风险点:带点 mnemonic 例如 `vmi.vreg`、`vmi.addf` 必须在 Slice 0 先用 parser round-trip 测试 +证明。如果 TableGen 的默认 type/attr parser 不接受该 spelling,就在 VMI type/attr 上实现 +custom assembly format,而不是改公开 spelling。 + +## 1. 
文件布局 + +新增文件: + +```text +include/PTO/IR/VMIAttrs.td +include/PTO/IR/VMITypeDefs.td +include/PTO/IR/VMIOps.td +lib/PTO/IR/VMI.cpp +lib/PTO/Transforms/VMILayoutAssignment.cpp +lib/PTO/Transforms/VMIToVPTO.cpp +lib/PTO/Transforms/PTOValidateVMIIR.cpp +test/lit/vmi/ +``` + +修改文件: + +```text +include/PTO/IR/PTOAttrs.td +include/PTO/IR/PTOTypeDefs.td +include/PTO/IR/PTOOps.td +include/PTO/IR/CMakeLists.txt +lib/PTO/IR/CMakeLists.txt +include/PTO/Transforms/Passes.td +lib/PTO/Transforms/CMakeLists.txt +``` + +推荐 include 关系: + +```text +PTOAttrs.td + include "PTO/IR/VMIAttrs.td" + +PTOTypeDefs.td + include "PTO/IR/VMITypeDefs.td" + +PTOOps.td + include "PTO/IR/VMIOps.td" +``` + +放置顺序: + +```text +VMIAttrs.td: + include PTODialect.td, AttrTypeBase.td, EnumAttr.td + must not include PTOAttrs.td + +VMITypeDefs.td: + include PTODialect.td and can rely on PTOAttrs.td having included VMIAttrs.td + +VMIOps.td: + include after PTO_Op is defined in PTOOps.td + do not include VPTOOps.td from VMIOps.td +``` + +这样现有 `LLVM_TARGET_DEFINITIONS PTOOps.td` 的 TableGen 生成路径可以继续覆盖 VMI type、attr +和 op。只有当 TableGen 生成目标不能正确收集新增 td 时,才单独新增 `mlir_tablegen` 目标。 + +`lib/PTO/IR/VMI.cpp` 放 VMI type/attr/op verifier、parse/print helper 和公共 lane-map helper。 +不要把 VMI verifier 塞进 `VPTO.cpp`。 + +Pass 注册要求: + +```text +include/PTO/Transforms/Passes.td: + add VMILayoutAssignment + add VMIToVPTO + add PTOValidateVMIIR + +include/PTO/Transforms/Passes.h: + add explicit create*Pass declarations if generated declarations are not enough + +lib/PTO/Transforms/CMakeLists.txt: + add the three new .cpp files to PTOTransforms + keep DEPENDS PTOPassesIncGen and PTOOpsIncGen + add missing MLIR dialect libraries only when a new source actually includes them +``` + +Driver wiring is explicit and opt-in. 
`ptoas --enable-vmi` runs the VMI semantic pipeline before the VPTO backend +pipeline: + +```text +pto-validate-vmi-ir +vmi-layout-assignment +canonicalize/cse +vmi-layout-fold +canonicalize/cse +vmi-layout-rematerialize +canonicalize/cse +vmi-layout-sink-materialization +canonicalize/cse +vmi-legalize-arith-select +pto-validate-vmi-layout-ir +vmi-to-vpto +canonicalize/cse +``` + +`--enable-vmi` requires `--pto-backend=vpto` or `pto.backend = "vpto"` because the pipeline produces physical VPTO +values and ops. It is not part of the default PTOAS pipeline; existing PTO/VPTO inputs keep their previous behavior +unless the flag is set. + +The `ptoas --enable-vmi` user-facing entry also rejects public functions whose signature contains `!pto.vmi.*`. +Internal/private VMI-typed functions are materialized at explicit boundary +helpers by baseline `vmi-layout-assignment` and physicalized by `vmi-to-vpto`. +A later optimization pass may specialize private signatures. A public VMI ABI +requires an explicit materialization plan and must not be inferred from the +layout solver. 
+ +CLI coverage: + +```text +vmi_ptoas_cli_pipeline.pto: + --pto-backend=vpto + --enable-vmi lowers the VMI pipeline + pto.backend = "vpto" also selects the VPTO-compatible path + explicit --pto-backend=emitc with --enable-vmi is rejected + f16->f32 store lowers through the fold-consumers path, proving the driver + uses the optimized pipeline rather than only the hard skeleton + +vmi_ptoas_backend_required_invalid.pto: + default emitc backend with --enable-vmi and no pto.backend = "vpto" is rejected + +vmi_ptoas_public_abi_invalid.pto / vmi_ptoas_public_result_abi_invalid.pto: + public VMI argument/result signatures are rejected before layout assignment +``` + +## MLIR Framework Usage + +三个 correctness stage 和若干 layout optimization pass 不应该用同一种 MLIR 机制硬套。 +这里先定义实现框架选择,避免后续把 layout 求解、优化重写、结构化控制流改写和 1:N +physicalization 混在一个 pattern pass 里。 + +当前实现框架按下面的职责切开: + +```text +pto-validate-vmi-ir: + Operation::walk verifier。只看 IR 是否满足阶段不变量,不改 IR,不使用 conversion framework。 + +vmi-layout-assignment: + module-level per-SSA-value constraint solver。先收集等价类、producer natural layout 和 consumer request, + 再把结果写回 VMI type/helper op。它可以使用 IRRewriter 改 IR,但不以 TypeConverter 为主模型。 + +vmi-layout-fold / vmi-layout-rematerialize / vmi-layout-sink-materialization: + legal-to-legal VMI optimization passes。它们只消费 layout-assigned VMI IR,并继续产出 + layout-assigned VMI IR;所有新选择必须体现在 current op、type 或 helper IR 中。 + +vmi-legalize-arith-select: + canonicalize 之后的 hygiene pass。它把 scalar-condition arith.select with VMI result + 恢复成 VMI pipeline 可控的结构化控制流形态。 + +vmi-to-vpto: + MLIR OneToNTypeConversion。每个 layout-assigned VMI value 按统一 physical ordering 展开成多个 + VPTO value,并依靠 OneToN structural patterns 重写函数、return、region result、block argument 和 + branch operand。 +``` + +这三个 pass 的边界必须通过 IR 可见状态传递:layout 写在 `!pto.vmi.*` type 上,必要 materialization +写成 `pto.vmi.ensure_*`,physicalization 后不允许残留 `pto.vmi.*`、`!pto.vmi.*` 或 +`unrealized_conversion_cast`。不能把 layout 决策藏在 pass-private side table 里让后续 pass 猜。 + 
+源码级实现应该进一步拆成七个独立层次: + +```text +IR layer: + include/PTO/IR/VMIAttrs.td + include/PTO/IR/VMITypeDefs.td + include/PTO/IR/VMIOps.td + lib/PTO/IR/VMI.cpp + + 只定义语义、parse/print、type/op verifier 和公共 lane-map helper。 + 这一层不能知道 layout assignment 的全局选择,也不能直接依赖 VPTO lowering pass。 + +Semantic validation layer: + lib/PTO/Transforms/PTOValidateVMIIR.cpp + + 只检查阶段输入/输出是否满足 contract。它是 hard gate,不做 repair。 + +Layout solving layer: + lib/PTO/Transforms/VMILayoutAssignment.cpp + + 负责从 producer/consumer/control-flow/call 关系解出每个 logical value 的 layout, + 然后把结果写回 type 或 ensure_* helper。 + +Layout support query layer: + include/PTO/Transforms/VMILayoutSupport.h + lib/PTO/Transforms/VMILayoutSupport.cpp + + 只放跨阶段共享的纯查询:cast layout fact、group_reduce layout fact、 + ensure_* materialization support、layout-aware store support 等。它可以被 + assignment、validation、layout optimization 和 vmi-to-vpto 调用,但不能保存 + per-value 状态,不能返回 VPTO 指令计划,不能决定 clone/rematerialize,也不能 + 通过 producer/user/control-flow context 恢复 lowering 决策。 + + 加新 query 的标准是:至少两个阶段需要同一个语义事实,并且重复实现会导致 + assignment、validation、lowering 对同一个 layout shape 得出不同结论。只有 + 一个 lowering pattern 自己使用的分支应该留在该 pattern 内。 + +Layout optimization layer: + lib/PTO/Transforms/VMILayoutFold.cpp + lib/PTO/Transforms/VMILayoutRematerialize.cpp + lib/PTO/Transforms/VMILayoutSinkMaterialization.cpp + lib/PTO/Transforms/VMILegalizeArithSelect.cpp + + 负责在 layout-assigned VMI IR 内做 legal-to-legal 改写。它可以让公共 canonicalize/cse + 协助清理和合并 IR,但不能把决策藏到 side table 里。 + +Physicalization layer: + lib/PTO/Transforms/VMIToVPTO.cpp + + 负责把 layout-assigned VMI value 通过 OneToNTypeConversion 展成 VPTO physical values, + 并把每个 pto.vmi.* semantic op 改写成 VPTO op 序列。 + +Driver/test layer: + tools/ptoas/ptoas.cpp + tools/pto-test-opt/ + test/lit/vmi/ + + ptoas 只暴露 opt-in pipeline;pto-test-opt 保留单 pass 和中间 IR 的调试入口。 +``` + +每层的 MLIR 框架选择如下: + +```text +ODS/TableGen: + 定义 type/attr/op surface 和 verifier hook。 + +Operation::walk: + 用于 validation 和 layout constraint collection。 + 
+Union-find + DenseMap: + 用于 layout assignment 的 per-SSA-value 等价类求解。 + +IRRewriter/RewriterBase: + 用于 layout assignment 之后的 type rewrite、helper insertion;cheap producer + rematerialization 属于后续 layout optimization pass。 + +OneToNTypeConverter + OneToNOpConversionPattern: + 只用于 vmi-to-vpto,把一个 logical VMI value 展成多个 VPTO value。 + +Upstream OneToN structural helpers: + func.func / func.call / func.return / common SCF region-result conversion。 + +Project-local OneToN structural patterns: + cf.br / cf.cond_br / cf.switch / scf.execute_region / scf.index_switch。 +``` + +不要把这些层次合并成一个万能 pattern pass。特别是: + +```text +layout assignment 不能依赖 OneToNTypeConverter: + 因为 layout 不是 type-only 决策,同一个 !pto.vmi.vreg<128xf32> 的不同 SSA value + 可能因 producer/consumer/control-flow 约束得到不同 layout。 + +vmi-to-vpto 不能重新做 layout solving: + 它只消费已经写在 type/helper 上的 layout 决策。遇到未 assignment 的 VMI type 必须失败。 + +structural OneToN pattern 不能知道 VMI 语义: + 它们只负责 flatten/rebuild operands、results、successor operands 和 block arguments。 + 具体 lane 语义只属于 pto.vmi.* op lowering pattern。 + +verifier 不能偷偷修 IR: + 否则后续 pass 会依赖 verifier 的隐式 repair 行为,导致 pipeline 顺序不可推理。 +``` + +一个可以直接对照代码的 pass 边界表: + +```text +pass input output +--------------------------- ---------------------------- ---------------------------- +pto-validate-vmi-ir surface VMI IR same IR, or hard failure +vmi-layout-assignment surface/layout-partial VMI layout-assigned VMI IR +layout optimization passes layout-assigned VMI IR layout-assigned VMI IR +vmi-legalize-arith-select layout-assigned VMI IR layout-assigned VMI IR +pto-validate-vmi-layout-ir layout-assigned VMI IR same IR, or hard failure +vmi-to-vpto layout-assigned VMI IR physical VPTO IR +final residual verifier physical VPTO candidate no pto.vmi.*, no !pto.vmi.* +``` + +### 代码级落点 + +当前实现应该能按文件直接审计。每个 pass 的核心类、MLIR 机制和失败边界如下: + +```text +lib/PTO/Transforms/PTOValidateVMIIR.cpp + pass: + PTOValidateVMIIRPass + PTOValidateVMILayoutIRPass + public helpers: + validateVMIProducerBoundaryIR + 
validateVMILayoutAssignedIR + MLIR API: + Operation::walk + func::FuncOp function type inspection + recursive TypeAttr / TypedAttr / ArrayAttr / DictionaryAttr scan + must not: + rewrite IR + create unrealized_conversion_cast + create ConversionTarget + repair illegal helper/type leakage + +lib/PTO/Transforms/VMILayoutAssignment.cpp + pass: + VMILayoutAssignmentPass + core object: + LayoutSolver + state: + DenseMap + SmallVector + SmallVector + SmallVector + SmallVector + MLIR API: + Operation::walk for fact collection + SymbolTable for direct internal calls + concrete cf/scf handlers for control-flow equivalence + IRRewriter/OpBuilder only after solving + must not: + use TypeConverter as the layout decision model + rewrite while collecting constraints + hide chosen layout in a pass-private side table + infer external VMI ABI + +lib/PTO/Transforms/VMILayoutFold.cpp +lib/PTO/Transforms/VMILayoutRematerialize.cpp +lib/PTO/Transforms/VMILayoutSinkMaterialization.cpp +lib/PTO/Transforms/VMILegalizeArithSelect.cpp + pass: + VMILayoutFoldPass + VMILayoutRematerializePass + VMILayoutSinkMaterializationPass + VMILegalizeArithSelectPass + role: + legal-to-legal layout-assigned VMI optimization and hygiene + MLIR API: + Operation::walk for local discovery + OpBuilder/RewriterBase for explicit IR rewrites + canonicalize/cse between passes for cleanup and deduplication + must not: + introduce physical VPTO register types + require vmi-to-vpto to inspect producers, users, or CFG + preserve optimization decisions outside IR + +lib/PTO/Transforms/VMIToVPTO.cpp + pass: + VMIToVPTOPass + converter: + VMIToVPTOTypeConverter : OneToNTypeConverter + pattern families: + OneToNOpConversionPattern for pto.vmi.* semantic ops + upstream func/scf OneToN structural patterns + project-local cf/scf structural OneToN patterns + MLIR API: + populateFuncTypeConversionPatterns + scf::populateSCFStructuralOneToNTypeConversions + applyPartialOneToNConversion + final residual walk + must not: + redo 
layout solving + inspect defining ops to recover physical parts + allow pto.vmi.pack/unpack/ensure_* to survive final output + allow unrealized_conversion_cast to survive final output +``` + +这里最重要的分界是:`vmi-layout-assignment` 解决的是 value-level layout,`vmi-to-vpto` +解决的是 type/value 1:N physicalization。前者的结果必须已经写回 `!pto.vmi.*` type 或显式 +`pto.vmi.ensure_*`;后者只能消费这些 IR-visible facts。 + +这也回答了“有没有充分利用 MLIR 自带能力”:结构化 1:N signature/control-flow conversion 必须用 +MLIR OneToN conversion;layout assignment 则不能强行塞进 converter,因为 converter 看不到 +producer natural layout、consumer request、CFG join 和 call-return slot 这些 value-level facts。 + +### Pass 级实现细则 + +这几个 pass 对 MLIR 自带能力的使用方式应该是“各用其长”,而不是都套成 converter pattern。 +实现时按下面的判断标准拆: + +```text +只检查阶段不变量: + 用 Operation::walk。不要创建 ConversionTarget,也不要 rewrite。 + +需要根据 SSA value、CFG join、call boundary 和 consumer request 决策 layout: + 用 module-level solver。MLIR conversion framework 没有 per-value layout 决策模型。 + +需要把一个 logical value 展成多个 physical value,并同步改 function/block/control-flow signature: + 用 OneToNTypeConversion。这里是 converter framework 最应该发挥作用的地方。 +``` + +#### Pass 框架细化 + +第一版实现按下面的源码和 MLIR infra 对齐。这个表是实现时的边界,不只是文档分层: + +```text +source file pass primary MLIR facility +----------------------------------------- --------------------------- --------------------------------------------- +lib/PTO/Transforms/PTOValidateVMIIR.cpp pto-validate-vmi-ir Operation::walk + recursive type/attr scan +lib/PTO/Transforms/PTOValidateVMIIR.cpp pto-validate-vmi-layout-ir Operation::walk + recursive type/attr scan +lib/PTO/Transforms/VMILayoutAssignment.cpp vmi-layout-assignment module-level union-find solver + IRRewriter +lib/PTO/Transforms/VMILayoutFold.cpp + vmi-layout-fold Pattern-free local IR rewrite +lib/PTO/Transforms/VMILayoutRematerialize.cpp + vmi-layout-rematerialize Pattern-free local IR rewrite +lib/PTO/Transforms/VMILayoutSinkMaterialization.cpp + vmi-layout-sink-materialization + Pattern-free local IR rewrite 
+lib/PTO/Transforms/VMILegalizeArithSelect.cpp + vmi-legalize-arith-select Operation::walk + OpBuilder rewrite +lib/PTO/Transforms/VMIToVPTO.cpp vmi-to-vpto OneToNTypeConverter + OneToNOpConversionPattern +``` + +这意味着每个 pass 的输入输出 contract 是固定的: + +```text +pto-validate-vmi-ir: + input: + surface VMI IR + legal: + pto.vmi semantic ops + !pto.vmi.vreg + !pto.vmi.mask + func/scf/cf structural ops carrying those types + illegal: + layout-assigned !pto.vmi.* type + physical !pto.vreg / !pto.mask / !pto.align type + pto.vmi.ensure_* / pack / unpack helper + VMI or physical type hidden in non-signature attribute + output: + exactly the same IR, or failure + +vmi-layout-assignment: + input: + verifier-clean surface VMI IR + legal work: + solve per-SSA layout/granularity constraints + rewrite VMI value/function/block types with explicit layout + insert pto.vmi.ensure_* only for use-site materialization + rematerialize cheap producers instead of inserting ensure_* when semantics are replay-safe + illegal work: + physicalize to !pto.vreg / !pto.mask + introduce pto.vmi.pack / pto.vmi.unpack + keep layout only in a pass-private side table + output: + layout-assigned VMI IR, or failure + +pto-validate-vmi-layout-ir: + input: + layout-assigned VMI IR + legal: + pto.vmi semantic ops + pto.vmi.ensure_layout / ensure_mask_layout / ensure_mask_granularity + !pto.vmi.vreg + !pto.vmi.mask + illegal: + surface !pto.vmi.vreg + surface !pto.vmi.mask + physical VPTO register types before vmi-to-vpto + pto.vmi.pack / pto.vmi.unpack + VMI or physical type hidden in non-signature attribute + output: + exactly the same IR, or failure + +vmi-to-vpto: + input: + layout-assigned VMI IR + legal work: + convert each VMI value to an ordered list of physical VPTO values + rewrite function signatures, block arguments, branch operands, region results and calls + lower pto.vmi semantic/helper ops to VPTO ops + illegal work: + infer missing layouts + change a chosen layout because one pattern finds a 
cheaper lowering + leave pto.vmi.* / !pto.vmi.* / unrealized_conversion_cast in final IR + output: + physical VPTO IR, or failure +``` + +`vmi-layout-assignment` 和 `vmi-to-vpto` 的关键差异是:前者解决“这个 SSA value 应该是什么 layout”, +后者解决“这个已经有 layout 的 SSA value 展开成哪些 physical value”。同一个 surface type 不能用 +`TypeConverter` 得到唯一答案: + +```mlir +%a = pto.vmi.broadcast %s : f32 -> !pto.vmi.vreg<128xf32> +%b = pto.vmi.extf %x : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> +%c = scf.if %cond -> !pto.vmi.vreg<128xf32> { + scf.yield %a : !pto.vmi.vreg<128xf32> +} else { + scf.yield %b : !pto.vmi.vreg<128xf32> +} +``` + +这里 `%a` 可以按 consumer 需要 rematerialize 成 contiguous 或 deinterleaved;`%b` 的 natural layout 是 +`deinterleaved=2`;`%c` 的 layout 必须由两个 yield 和后续 consumer 共同约束。这个选择依赖 Value、 +def-use、control-flow join 和 use-site request,不是 `!pto.vmi.vreg<128xf32> -> ...` 的 type-only 规则。 + +因此 layout pass 的代码形态应该固定为: + +```cpp +LogicalResult LayoutSolver::run() { + if (failed(collectAllVMIValues())) + return failure(); + if (failed(collectEquivalenceConstraints())) + return failure(); + if (failed(collectProducerNaturalLayouts())) + return failure(); + if (failed(collectConsumerRequests())) + return failure(); + if (failed(rewriteDataTypes())) + return failure(); + if (failed(insertDataUseMaterializations())) + return failure(); + if (failed(inferAndRewriteMaskTypes())) + return failure(); + if (failed(insertMaskUseMaterializations())) + return failure(); + rewriteFunctionTypesFromSolvedValues(); + return validateVMILayoutAssignedIR(module); +} +``` + +其中 `collect*` 阶段只能记录事实,不能边 walk 边改 IR。原因是控制流和 call boundary 会把后面才遇到的 +operand/result 合并到前面的 value class;边收集边改 type 会让后续约束看到混合状态,错误诊断也会依赖 +walk 顺序。 + +`vmi-to-vpto` 则必须是 converter pass。第一版使用的是 `OneToNTypeConversion`,因为它要同时处理 +value type 和结构签名: + +```text +!pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +func.func @f(%arg0: !pto.vmi.vreg<128xf32, layout>) -> !pto.vmi.vreg<128xf32, layout> + -> func.func 
@f(%arg0_0: !pto.vreg<64xf32>, %arg0_1: !pto.vreg<64xf32>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) +``` + +这里不能用普通 1:1 `TypeConverter`,也不能靠每个 VMI op pattern 自己拆 operand。否则 `func.return`、 +`cf.br`、`scf.for` iter arg 这种没有 VMI defining op 的边界会漏转换。`OneToN` adaptor 才是 semantic +pattern 获取 physical parts 的唯一来源: + +```cpp +ValueRange lhsParts = adaptor.getLhs(); +ValueRange rhsParts = adaptor.getRhs(); +TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); +``` + +结构化转换的实现分工如下: + +```text +upstream helper: + populateFuncTypeConversionPatterns + covers func.func / func.return / direct func.call signature conversion + + scf::populateSCFStructuralOneToNTypeConversions + covers common SCF result/yield/block-argument structural conversions + +project-local OneToN patterns: + cf.br + cf.cond_br + cf.switch + scf.execute_region + scf.index_switch +``` + +项目内 structural pattern 只能做结构搬运: + +```text +1. read OneToNTypeMapping for each original operand/result +2. flatten successor operands or region result types +3. rebuild the same cf/scf op with converted types +4. 
inline/move original regions when required +``` + +它们不能做下面这些事: + +```text +infer layout from operand defining op +emit vadd/vcvt/vlds/vsts +decide contiguous vs deinterleaved +special-case pto.vmi semantic op +``` + +VMI 语义只能出现在 `OneToNOpConversionPattern` 里。这样才能保证 block argument、function +argument、loop-carried value 和 branch target argument 都按同一套 physical ordering 转换。 + +`vmi-to-vpto` 的 legality 由 preflight + conversion + final gate 三段组成,而不是单靠 +`ConversionTarget`: + +```text +preflight: + verifyVMIToVPTOInputIR + rejects layout-free VMI types + verifySupportedVMIToVPTOOps + rejects unsupported semantic/materialization cases before rewrite starts + +conversion: + applyPartialOneToNConversion + applies structural and semantic OneToN patterns + +final gate: + verifyNoResidualVMIIR + rejects pto.vmi.* + rejects !pto.vmi.* in operand/result/block/function/attribute type trees + rejects pto.vmi.pack/unpack materialization helpers + rejects unrealized_conversion_cast +``` + +这比只设置 `ConversionTarget` 更直接,因为当前 OneToN 工具链的重点是 type/value expansion 和 pattern +rewriter;最终合法性必须递归检查 attribute/type tree,防止 VMI type 被藏在 nested attr 里。 + +#### `pto-validate-vmi-ir` / `pto-validate-vmi-layout-ir` + +这两个 pass 是 hard gate,不是 legalization pass。 + +使用的 MLIR 能力: + +```text +Operation::walk: + 遍历 module 内所有 op、region、block argument、operand/result type 和 attribute。 + +TypeAttr / TypedAttr recursive scan: + 拒绝把 VMI/physical VPTO type 藏在 nested attribute 中。 + +func::FuncOp function type special case: + function_type attr 是签名本身,可以按当前阶段规则检查;其它 attr 不能携带 VMI/physical type。 +``` + +不使用 `ConversionTarget` 的原因: + +```text +ConversionTarget 适合表达“哪些 op/type legal,哪些 pattern 能改掉”。 +这里我们只想回答“当前 IR 是否已经处在某个阶段边界”,失败后必须停机,而不是尝试 repair。 +如果 verifier 顺手改 IR,pipeline 的阶段不变量会变成隐式行为,后续 pass 很难审计。 +``` + +这两个 pass 的输出只能是原 IR 或 failure: + +```cpp +void runOnOperation() override { + if (failed(verifyStageInvariant(getOperation()))) + signalPassFailure(); +} +``` + +#### `vmi-layout-assignment` + +这个 pass 使用 MLIR 
的 IR 遍历和 rewrite 基础设施,但不使用 `TypeConverter` 作为主模型。 + +核心原因: + +```text +TypeConverter 的输入是 Type。 +layout assignment 的输入是 Value。 + +同一个 !pto.vmi.vreg<128xf32> 可以因为不同 producer/consumer 关系得到不同 layout: + f16->f32 widen result -> deinterleaved=2 + f8 ->f32 widen result -> deinterleaved=4 + only contiguous store value -> contiguous +``` + +实现应拆成两个阶段,不要边 walk 边 rewrite: + +```text +collect: + 1. 收集所有 VMI data/mask SSA value 和 block argument。 + 2. 用 union-find 合并必须同 layout 的 value。 + 3. 记录 producer natural layout。 + 4. 记录 consumer layout/granularity request。 + 5. 记录 function return slot、call operand/result、branch operand/block argument 关系。 + +rewrite: + 1. 为每个 equivalence class 选 layout。 + 2. 改写 value/function/block/result type。 + 3. 对 use-site mismatch 插入 ensure_* 或 rematerialize cheap producer。 + 4. 运行 pto-validate-vmi-layout-ir。 +``` + +建议的数据结构边界: + +```cpp +struct DataNode { + Value value; + VMIVRegType type; + unsigned parent; + VMILayoutAttr naturalLayout; +}; + +struct MaskNode { + Value value; + VMIMaskType type; + unsigned parent; + VMILayoutAttr requestedLayout; + std::string requestedGranularity; +}; + +struct DataUseRequest { + OpOperand *operand; + VMILayoutAttr layout; +}; + +struct MaskUseRequest { + OpOperand *operand; + VMILayoutAttr layout; + std::string granularity; +}; +``` + +这里可以充分使用 MLIR 的接口,但它们只是 constraint source: + +```text +BranchOpInterface / concrete cf.* handlers: + successor operand[i] == destination block argument[i] + +RegionBranchOpInterface / concrete scf.* handlers: + region yield operand[i] == parent result[i] + loop init/result/iter_arg/yield 同 slot 等价 + +CallOpInterface + SymbolTable: + direct internal call operand/result 和 callee argument/return slot 等价 + external/indirect VMI call 先拒绝,因为缺 ABI materialization + +IRRewriter: + 只在 solve 完成后统一改 type、插 ensure_*、clone cheap producer。 +``` + +`vmi-layout-assignment` 的 pass invariant 是:所有 layout 决策必须写回 IR。后续 `vmi-to-vpto` +只能读取 `!pto.vmi.*` type 和显式 `pto.vmi.ensure_*`,不能依赖 layout solver 的 
side table。 + +#### `vmi-to-vpto` + +这个 pass 应该充分使用 MLIR converter framework,具体是 `OneToNTypeConversion`,不是普通 +`DialectConversion`。 + +普通 1:1 dialect conversion 不够的地方: + +```text +!pto.vmi.vreg<128xf32, deinterleaved=2> + -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +!pto.vmi.vreg<256xf8, deinterleaved=4> + -> !pto.vreg<256xf8>, !pto.vreg<256xf8>, !pto.vreg<256xf8>, !pto.vreg<256xf8> +``` + +函数参数、返回值、block argument、branch operand、region result 都必须做同样的 1:N 展开。 +这正是 `OneToNTypeConverter`、`OneToNOpConversionPattern` 和结构化 OneToN helper 的职责。 + +实现骨架: + +```cpp +void runOnOperation() override { + ModuleOp module = getOperation(); + + if (failed(verifyVMIToVPTOInputIR(module)) || + failed(verifySupportedVMIToVPTOOps(module))) + return signalPassFailure(); + + VMIToVPTOTypeConverter typeConverter; + RewritePatternSet patterns(&getContext()); + + populateFuncTypeConversionPatterns(typeConverter, patterns); + scf::populateSCFStructuralOneToNTypeConversions(typeConverter, patterns); + populateProjectLocalCFOneToNPatterns(typeConverter, patterns); + populateVMISemanticOneToNPatterns(typeConverter, patterns); + + if (failed(applyPartialOneToNConversion(module, typeConverter, + std::move(patterns))) || + failed(verifyNoResidualVMIIR(module))) + signalPassFailure(); +} +``` + +`VMIToVPTOTypeConverter` 只做一种事:把 layout-assigned VMI type 映射到 canonical physical value list。 +它不能重新推导 layout。 + +```text +contiguous: + chunk0, chunk1, ... in logical order + +deinterleaved=2: + part0 chunks for logical lanes 0,2,4,... + part1 chunks for logical lanes 1,3,5,... + +deinterleaved=4: + part0 chunks for lanes 0,4,8,... + part1 chunks for lanes 1,5,9,... + part2 chunks for lanes 2,6,10,... + part3 chunks for lanes 3,7,11,... 
+ +num_groups=G: + group-slot reduce result layout + physical storage is contiguous chunk order + only canonical group_slot(g) lanes contain semantic values +``` + +每个 semantic pattern 必须从 adaptor 拿 physical parts,不允许从 defining op 反推: + +```cpp +LogicalResult matchAndRewrite(VMIAddFOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange lhs = adaptor.getLhs(); + ValueRange rhs = adaptor.getRhs(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + + if (lhs.size() != rhs.size() || lhs.size() != resultTypes.size()) + return rewriter.notifyMatchFailure(op, "physical arity mismatch"); + + SmallVector results; + for (auto [i, resultType] : llvm::enumerate(resultTypes)) { + results.push_back( + rewriter.create(op.getLoc(), resultType, lhs[i], rhs[i]) + .getResult()); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); +} +``` + +这个约束对控制流是关键的:`scf.for` iter arg、branch target argument、function argument 都没有可用的 +defining op;它们的 physical parts 只能来自 OneToN signature/block argument conversion。 + +`vmi-to-vpto` 应有三层失败点,诊断不要混在一起: + +```text +preflight: + layout 未 assignment、unsupported semantic op、unsupported materialization path + +conversion: + pattern 缺失、arity mismatch、结构化控制流展开失败 + +final residual verifier: + 任何 pto.vmi.*、!pto.vmi.*、pto.vmi.pack/unpack/ensure_*、unrealized_conversion_cast 残留 +``` + +### `pto-validate-vmi-ir` + +`pto-validate-vmi-ir` 是边界 verifier,不使用 DialectConversion。 + +推荐使用: + +```text +Operation::walk +TypeSwitch / isa / dyn_cast +emitOpError / InFlightDiagnostic +SymbolTable, for function/call boundary checks +CallGraph or manual call graph collection, if recursive SCC needs diagnostics +DominanceInfo, if helper placement or resource dominance is checked +``` + +这个 pass 只检查 VMI producer boundary 和阶段不变量: + +```text +before layout assignment: + VMI data/mask values use surface type + no layout-assigned VMI type leaks in unless the test explicitly starts after 
assignment + no physical VPTO op appears in the semantic VMI region + no VMI helper op appears before the pass that is allowed to create it + no non-signature op/module TypeAttr or TypedAttr payload contains VMI or physical VPTO types + +after layout assignment: + pass: pto-validate-vmi-layout-ir + every VMI data value has a layout + every VMI mask has layout and concrete granularity + control-flow joins have stable type/layout + no non-signature op/module TypeAttr or TypedAttr payload contains VMI or physical VPTO types + +after VMI-to-VPTO: + no VMI op/type/helper remains + no unrealized_conversion_cast remains +``` + +不要把这个 pass 写成 rewrite pass。它可以收集 context 用于诊断,但不能通过局部修补让非法 IR +继续前进;否则后续 pass 会开始依赖 verifier 的隐式 repair 行为。 + +实现上要扫描的不只是 operand/result/block argument: + +```text +func.func function type: + 作为函数签名本身检查,允许出现当前阶段合法的 VMI type。 + +non-signature attributes: + module/op attribute 中只要递归包含 VMI type 或 physical VPTO type 都拒绝。这里包括 TypeAttr、 + TypedAttr,以及 ArrayAttr/DictionaryAttr 这类容器中的 nested attribute/type payload。 +``` + +这样可以堵住 hidden-state 形式的 side table,例如把 `!pto.vmi.vreg<...>` 偷存在 module attribute +里。`func.func` 的内建 `function_type` attr 是唯一例外,因为它只是函数签名的 MLIR 表达,不是额外 +隐藏状态。 + +### `vmi-layout-assignment` + +`vmi-layout-assignment` 不以 MLIR `TypeConverter` 作为主机制。 + +原因是 layout 选择不是单纯的 `Type -> TypeRange` 映射: + +```text +same surface type: + !pto.vmi.vreg<128xf32> + +possible per-value decisions: + value produced by f16->f32 widen: deinterleaved=2 + value loaded only for contiguous store: contiguous + value feeding fp8-like->f32 consumer path: deinterleaved=4 +``` + +两个 SSA value 可以有完全相同的 surface type,但因为 producer natural layout、consumer demand、 +控制流 join 和 target capability 不同,得到不同 layout。因此主模型应该是 per-SSA-value 的约束图, +而不是类型转换表。 + +推荐内部结构: + +```text +DenseMap +DenseMap +DenseMap +SmallVector +SmallVector +``` + +推荐使用的 MLIR 基础能力: + +```text +RegionBranchOpInterface: + collect scf.if/scf.for-like region entry, yield, result relations + +BranchOpInterface: 
+ collect cf.br/cf.cond_br predecessor operand -> block argument relations + +CallOpInterface, CallableOpInterface, FunctionOpInterface: + collect call operand/result and function argument/result relations + +SymbolTable: + resolve direct calls and reject unresolved VMI signature assumptions + +DominanceInfo: + choose legal insertion points for ensure_layout, mask conversion, and rematerialization + +IRRewriter / RewriterBase: + rewrite types, insert helper ops, clone rematerializable producers +``` + +求解结果必须 materialize 回 IR,不能留在 side table: + +```text +1. Rewrite every VMI value type to a layout-assigned type. +2. Rewrite mask type to layout + b8/b16/b32 granularity. +3. Insert pto.vmi.ensure_layout where a consumer requires a different layout. +4. Insert pto.vmi.ensure_mask_layout / ensure_mask_granularity where predicate layout or granularity differs. +5. Clone rematerializable producers such as constant, broadcast, create_mask, iota-like producers when cheaper. +6. Re-run the VMI stage verifier. 
+``` + +这个 pass 可以用 `RewritePatternSet` 辅助局部 canonicalization,例如删除同 layout 的 +`ensure_layout`,但不能让 greedy pattern driver 决定全局 layout。全局约束必须先收敛,再做改写。 + +更具体地说,这里不用 `TypeConverter` 的原因不是 MLIR converter 不好用,而是此阶段的问题不是 +“一个旧 type 机械变成一个新 type”: + +```text +%a : !pto.vmi.vreg<128xf32> // 只被 contiguous store 消费 +%b : !pto.vmi.vreg<128xf32> // 来自 f16->f32 widen,后续继续 vadd +%c : !pto.vmi.vreg<128xf32> // 控制流 join,两个 predecessor 必须统一 layout +``` + +这三个 value 的 surface type 完全相同,但 layout 决策分别可能是 contiguous、deinterleaved=2、 +以及由 join 两侧约束共同决定。`TypeConverter` 看不到“这个 SSA value 的 producer/consumer/CFG +关系”,所以它只能作为后续 physicalization 的工具,不能作为 layout assignment 的主算法。 + +该 pass 对 MLIR 基础能力的使用边界是: + +```text +Operation::walk: + 收集所有 VMI SSA value、block argument、函数签名和 op transfer facts。 + +Union-find / DenseMap: + 表达必须同 layout 的 equivalence class。 + +SymbolTable: + 解析 direct internal func.call;带 VMI type 的 external/indirect call 先拒绝。 + +IRRewriter: + 改写 function/block/result type,插入 ensure_*。 + +validateVMILayoutAssignedIR: + pass 末尾 hard gate,确认所有决策已经 materialize 到 IR。 +``` + +### `vmi-to-vpto` + +`vmi-to-vpto` 应该使用 MLIR 的 1:N conversion framework,而不是普通 `DialectConversion`。 +这个 pass 的核心问题正是一个 logical VMI value physicalize 成多个 VPTO value: + +```text +!pto.vmi.vreg -> !pto.vreg... +!pto.vmi.mask -> !pto.mask... 
+``` + +普通 `DialectConversion` 的 `OpConversionPattern` 对 1:N fixed operand/result 支持不够直接: +pattern adaptor 可能拿到 source materialization,也可能拿到 flat converted operands;`func.return` +这类“一个 logical operand 展开成多个 physical operands”的场景也容易出现不完整展开。因此这里采用 +MLIR `OneToNTypeConversion` 工具: + +推荐组件: + +```text +OneToNTypeConverter +OneToNOpConversionPattern +OneToNPatternRewriter +OneToNTypeMapping +populateFuncTypeConversionPatterns +scf::populateSCFStructuralOneToNTypeConversions +applyPartialOneToNConversion +final residual verifier +``` + +`OneToNTypeConverter` 负责 layout-assigned VMI type 到 ordered physical VPTO value list: + +```cpp +typeConverter.addConversion([](VMIVRegType type, SmallVectorImpl<Type> &results) { + // Use getVMIPhysicalArity(type) and the shared lane-map helper. + // Append one physical !pto.vreg per part/chunk. +}); + +typeConverter.addConversion([](VMIMaskType type, SmallVectorImpl<Type> &results) { + // Use mask granularity and physical arity helper. + // Append one physical !pto.mask per part/chunk. 
+}); +``` + +source/target materialization 可以用 VMI helper 承接中间状态: + +```text +VMI value -> physical values: + pto.vmi.unpack + +physical values -> VMI value: + pto.vmi.pack +``` + +但它们只是 conversion materialization,不是最终 IR 的合法残留。final gate 必须拒绝: + +```text +pto.vmi.pack +pto.vmi.unpack +pto.vmi.ensure_layout +pto.vmi.ensure_mask_layout +pto.vmi.ensure_mask_granularity +unrealized_conversion_cast +``` + +`applyPartialOneToNConversion` 本身不是 legality framework;它负责应用 1:N patterns 并替换内部 +`unrealized_conversion_cast`。因此 `vmi-to-vpto` 必须在 conversion 后运行 final residual verifier, +把下面这些全部作为 hard failure: + +```text +any pto.vmi.* op +any !pto.vmi.* type +any pto.vmi.pack/unpack materialization helper +any pto.vmi.ensure_* helper +any unrealized_conversion_cast +``` + +结构转换必须覆盖: + +```text +func arguments/results and return operands: + use populateFuncTypeConversionPatterns + +call operands/results: + convert callee signature and call sites together + +block arguments and branch operands: + convert target block arguments and predecessor operands in the same conversion + current implementation provides project-local OneToN patterns for cf.br, + cf.cond_br, and cf.switch because MLIR only provides the generic + BranchOpInterface helper for ordinary 1:1 dialect conversion, not for VMI + 1:N physicalization. 
+ +scf.if/scf.for region yields and results: + use scf::populateSCFStructuralOneToNTypeConversions + otherwise write explicit OneToN patterns around RegionBranchOpInterface relations +``` + +如果当前 LLVM/MLIR 版本没有提供对应 OneToN helper,就补项目内 custom `OneToNConversionPattern`。 +选择标准不是“少写代码”,而是能否正确处理 1:N result、block argument、region yield 和 +recursive/function SCC。 + +当前实现的结构转换分工如下: + +```text +upstream OneToN helper: + func.func / func.return / func.call + scf.if / scf.for / scf.while and common SCF structural cases + +project-local OneToN structural patterns: + cf.br + cf.cond_br + cf.switch + scf.execute_region + scf.index_switch +``` + +项目内 structural pattern 只做一件事:按照 `OneToNTypeMapping` 展平/重建 operand、result、 +successor operand 和 block argument。它们不能内嵌 VMI layout 语义,也不能通过 defining op +重新推导物理寄存器列表。VMI 语义只出现在各个 `pto.vmi.*` 的 `OneToNOpConversionPattern` 中。 + +OneToN conversion 的执行顺序: + +```text +1. Populate structural conversion patterns. +2. Populate VMI semantic op lowering patterns. +3. Populate helper lowering/materialization patterns. +4. applyPartialOneToNConversion on the module. +5. Run final residual verifier as the hard legality gate. 
+``` + +如果 conversion 或 final gate 失败,诊断必须区分: + +```text +unsupported VMI semantic op +unsupported layout materialization path +unconverted function/control-flow boundary +unexpected VMI helper residual +unexpected unrealized_conversion_cast +``` + +这样 pass 边界就是清楚的: + +```text +pto-validate-vmi-ir: + verifier/walk, no conversion + +vmi-layout-assignment: + global per-value layout solver, then IR materialization + +vmi-to-vpto: + OneToNTypeConversion-based 1:N physicalization and final legality gate +``` + +### Concrete Pass Skeleton + +整个 pipeline 按下面的 hard contract 串起来: + +```text +raw VMI producer + -> pto-validate-vmi-ir + -> vmi-layout-assignment + -> canonicalize/cse + -> vmi-layout-fold + -> canonicalize/cse + -> vmi-layout-rematerialize + -> canonicalize/cse + -> vmi-layout-sink-materialization + -> canonicalize/cse + -> vmi-legalize-arith-select + -> pto-validate-vmi-layout-ir + -> vmi-to-vpto + -> canonicalize/cse + -> final residual verifier +``` + +The `ptoas --enable-vmi` driver entry uses this sequence before the existing VPTO backend pipeline. +The test-opt entry remains useful for isolated pass debugging, while the `ptoas` flag proves the same sequence is +wired through the user-facing compiler driver. The optimization passes are legal-to-legal VMI rewrites; removing one +may affect quality or accept fewer optimized forms, but it must not make `vmi-to-vpto` recover hidden context. 
+ +各阶段之间只通过 IR 传递状态,不通过 pass-private side table 传递语义。也就是说: + +```text +layout assignment output: + VMI value type already contains layout + VMI mask type already contains layout + concrete b8/b16/b32 granularity + required layout conversion already appears as pto.vmi.ensure_* or rematerialized producer + +vmi-to-vpto input: + may contain pto.vmi.* semantic ops and helper ops + must not contain layout-free VMI type + function signatures and op/module TypeAttr or TypedAttr payloads are part of this invariant, + not just SSA operands/results + +vmi-to-vpto output: + must not contain pto.vmi.* op/type/helper + must not contain unrealized_conversion_cast + function type attributes and any other op/module TypeAttr or TypedAttr payloads must not contain !pto.vmi.* +``` + +This prevents a fragile design where `vmi-to-vpto` has to rediscover layout decisions from defining ops. A VMI value +may be a function argument, block argument, `scf.if` result, `scf.for` carried value, or branch target argument; none +of those has a useful defining op. + +#### Layout Assignment State + +`vmi-layout-assignment` should be implemented as one module-level solver object: + +```cpp +struct DataValueState { + Value value; + VMIVRegType surfaceType; + UnionFindNode eqClass; + VMILayoutAttr naturalLayout; // producer-preferred layout + SmallVector uses; // consumer requirements +}; + +struct MaskValueState { + Value value; + VMIMaskType surfaceType; + UnionFindNode eqClass; + VMILayoutAttr requestedLayout; + StringRef requestedGranularity; // b8/b16/b32 after inference + SmallVector uses; // consumer layout/granularity requests +}; + +struct LayoutUseRequest { + Operation *consumer; + VMILayoutAttr layout; + StringRef reason; // add/select/store/widen-source/etc. +}; +``` + +The solver runs in phases: + +```text +1. collect all VMI data/mask SSA values, including block arguments +2. add equivalence constraints +3. add producer natural-layout constraints +4. 
add consumer layout/granularity requests +5. solve each equivalence class +6. insert ensure_* for non-class-compatible uses +7. rewrite value types and function signatures +8. run pto-validate-vmi-layout-ir +``` + +Equivalence is only for cases where two logical values must have the same physical lane order: + +```text +add/sub/mul: + lhs == rhs == result + +cmpf/cmpi: + lhs == rhs + result mask requests lhs layout + element-width granularity + +select: + true_value == false_value == result + mask operand gets a use-site request for result layout + element-width granularity + +scf.if: + result[i] == then yield[i] == else yield[i] + +scf.for: + init_arg[i] == region_iter_arg[i] == yield[i] == result[i] + +cf.br/cf.cond_br: + successor operand[i] == successor block argument[i] + +direct internal func.call: + call operand[i] == callee argument[i] + call result[i] == all callee return operand[i] +``` + +Natural layout is not equivalence. For example: + +```text +extf f16 -> f32: + result natural layout = deinterleaved=2 + +extf f8 -> f32: + result natural layout = deinterleaved=4 + +truncf f32 -> f16: + result natural layout = contiguous + +truncf f32 -> fp8-like: + result natural layout = contiguous + +store: + consumer requests contiguous externally visible order +``` + +If one equivalence class has incompatible natural layouts, the pass must diagnose `VMI-LAYOUT-CONTRACT` unless an +explicit use-site `ensure_*` can represent the requested materialization. Baseline layout assignment does not +clone/rematerialize producers. The separate `vmi-layout-rematerialize` optimization may replace an `ensure_*` +with a cloned trivially replayable producer after the materialization request is visible in IR: + +```text +constant +broadcast +constant_mask +create_mask +``` + +For non-rematerializable producers, insert `pto.vmi.ensure_layout` immediately before the consumer that requested the +different layout. This is the conservative first implementation rule. 
It works for ordinary SSA values, block +arguments, loop-carried values, branch arguments, and call results because the helper is dominated by the value at the +use site and does not need to be hoisted across control flow. `DominanceInfo` may be used later to hoist duplicated +helpers as an optimization, but it must not be required for correctness in the first implementation. + +That helper is a real IR marker: if `vmi-to-vpto` cannot lower its requested conversion, the program fails with an +explicit unsupported materialization diagnostic. + +#### Layout Assignment Implementation Frame + +This pass is a normal `OperationPass`. It deliberately does not use `DialectConversion`, because there is +no stable `Type -> Type` rule until the pass has solved producer preference, consumer demand, and control-flow joins. +The implementation should look like this: + +```cpp +struct LayoutSolver { + ModuleOp module; + MLIRContext *ctx; + + DenseMap dataIds; + SmallVector dataNodes; + DenseMap maskIds; + SmallVector maskNodes; + + SmallVector dataUseRequests; + SmallVector maskUseRequests; + DenseMap> firstReturnOperandsByFunc; + + LogicalResult collectConstraints(); + LogicalResult rewriteIR(); +}; +``` + +The concrete state objects should carry only facts that are materialized back into IR: + +```cpp +struct DataNode { + Value value; + VMIVRegType surfaceType; + unsigned parent; + VMILayoutAttr naturalLayout; // null means no producer preference yet +}; + +struct MaskNode { + Value value; + VMIMaskType surfaceType; + unsigned parent; + VMILayoutAttr requestedLayout; + std::string requestedGranularity; // empty until b8/b16/b32 is known +}; + +struct DataUseRequest { + OpOperand *operand; + VMILayoutAttr layout; +}; + +struct MaskUseRequest { + OpOperand *operand; + VMILayoutAttr layout; + std::string granularity; +}; +``` + +Do not store hidden layout state that `vmi-to-vpto` must rediscover. 
After this pass, a debugger should be able to read +the IR and know the chosen layout for every VMI value from its type alone. + +The pass body should stay simple: + +```cpp +void runOnOperation() override { + LayoutSolver solver(getOperation()); + if (failed(solver.collectConstraints()) || + failed(solver.rewriteIR()) || + failed(verifyLayoutAssignedVMIIR(getOperation()))) + signalPassFailure(); +} +``` + +The current implementation should map directly to this phase order: + +```cpp +LogicalResult LayoutSolver::run() { + if (failed(collect())) + return failure(); + if (failed(addConstraints())) + return failure(); + + rewriteDataTypes(); + if (failed(insertDataUseMaterializations())) + return failure(); + + if (failed(inferMaskRequests())) + return failure(); + rewriteMaskTypes(); + if (failed(insertMaskUseMaterializations())) + return failure(); + + rewriteFunctionType(); + return validateVMILayoutAssignedIR(module); +} +``` + +This order is intentional: + +```text +collect: + only discovers VMI values and block arguments. + +addConstraints: + only records equivalence, natural layout and consumer request facts. + It must not rewrite IR, because later CFG/call constraints may still merge + two values that were already seen. + +rewriteDataTypes: + commits solved data layouts to !pto.vmi.vreg type. + +insertDataUseMaterializations: + repairs use-site layout mismatch after the producer's committed type is known. + +inferMaskRequests: + uses already committed data layouts and element widths to infer concrete mask + layout/granularity requests. + +rewriteMaskTypes: + commits mask layout and b8/b16/b32 granularity. + +insertMaskUseMaterializations: + repairs mask layout/granularity mismatch. + +rewriteFunctionType: + updates function signatures last, after argument/result value types have been + rewritten. +``` + +Do not move `rewriteFunctionType` before use-site materialization. 
A function signature is the public shape of the +solved value class; changing it early makes call/return diagnostics depend on walk order and can hide an unresolved +use-site mismatch. + +Constraint collection is a module walk with explicit handlers. The important point is that each handler only records +facts; it must not rewrite while walking: + +```text +Data equivalence: + pto.vmi.addf/addi: lhs == rhs == result + pto.vmi.cmpf/cmpi: lhs == rhs + pto.vmi.select: true_value == false_value == result + pto.vmi.ensure_layout: source and result are not equivalent if layouts differ + +Data natural layout: + pto.vmi.extf f16->f32: result natural = deinterleaved=2 + pto.vmi.extf fp8-like->f32: result natural = deinterleaved=4 + pto.vmi.truncf: result natural = contiguous + pto.vmi.channel_merge with C inputs: result natural = deinterleaved=C + +Data use request: + pto.vmi.store: value requested as contiguous + pto.vmi.channel_split with C results: source requested as deinterleaved=C + op requiring a common operand/result layout: request producer class layout + +Mask request: + cmp result: same data layout as operands, granularity from element width + select mask: same data layout as selected value, granularity from element width + store mask path: same data layout as stored value, granularity from element width +``` + +Control flow should be handled as equivalence, not as local op preference: + +```text +scf.if: + result[i] == then yield[i] == else yield[i] + +scf.for: + init_arg[i] == body iter_arg[i] == yield[i] == result[i] + +scf.while: + before argument[i] == condition forwarded operand[i] == after argument[i] + after yield[i] == result[i] + +scf.execute_region: + every nested scf.yield operand[i] == execute_region result[i] + +scf.index_switch: + every case/default yield operand[i] == index_switch result[i] + +cf.br: + operand[i] == destination block argument[i] + +cf.cond_br: + true operand[i] == true destination block argument[i] + false operand[i] == false 
destination block argument[i] + +cf.switch: + default operand[i] == default destination block argument[i] + case k operand[i] == case k destination block argument[i] + +func.call: + only direct internal callees are supported in the first implementation + call operand[i] == callee argument[i] + call result[i] == every corresponding callee return operand[i] +``` + +Function returns need one extra bookkeeping rule. A function result slot has one public layout in the function type, so +all `func.return` operands at the same index must be equivalent: + +```text +first return operand[i] == every later return operand[i] +function result type[i] is rewritten from the solved type of return operand[i] +call result[i] == every corresponding callee return operand[i] +``` + +If two return paths naturally produce incompatible layouts, the pass should report `VMI-LAYOUT-CONTRACT` instead of +silently choosing one path: + +```mlir +^a: + %x = pto.vmi.extf %f16 : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + return %x : !pto.vmi.vreg<128xf32> // natural deinterleaved=2 + +^b: + %y = pto.vmi.extf %f8 : !pto.vmi.vreg<256xf8E4M3FN> -> !pto.vmi.vreg<256xf32> + return %y : !pto.vmi.vreg<256xf32> // different result shape/layout, invalid by verifier/type first +``` + +For equal result shape but incompatible producer preferences, the same rule applies: + +```text +return slot 0 from f16->f32 path: natural deinterleaved=2 +return slot 0 from f8E4M3FN->f32 path with the same logical result shape: natural deinterleaved=4 +diagnostic: VMI-LAYOUT-CONTRACT: conflicting natural layouts ... +``` + +External declarations with VMI types are not a layout problem; they are ABI materialization. The first implementation +must reject them before rewriting: + +```text +VMI-LAYOUT-CONTRACT: VMI typed function declaration requires an explicit external ABI materialization plan +``` + +The rewrite phase has three ordered steps: + +```text +1. Rewrite all data SSA value types to !pto.vmi.vreg. +2. 
Rewrite all mask SSA value types to !pto.vmi.mask. +3. Repair use-site mismatches by either rematerializing a cheap producer or inserting an explicit helper. +``` + +Rematerialization is allowed only when replaying the producer cannot change memory, control flow, or execution count +semantics: + +```text +allowed: + pto.vmi.constant splat + pto.vmi.broadcast + pto.vmi.constant_mask + pto.vmi.create_mask + +not allowed in the first implementation: + load + arithmetic result + conversion result + shuffle/channel_split/channel_merge result + value crossing a call boundary or block argument +``` + +If rematerialization is not legal, insert: + +```text +pto.vmi.ensure_layout +pto.vmi.ensure_mask_layout +pto.vmi.ensure_mask_granularity +``` + +These helpers make the unresolved materialization explicit. `vmi-layout-assignment` is allowed to create them; +`vmi-to-vpto` is responsible for proving and lowering them. If lowering cannot prove the physical transform, the final +diagnostic should be an unsupported layout/materialization diagnostic, not silent incorrect code. + +Layout assignment completion checks: + +```text +1. No surface !pto.vmi.vreg remains. +2. No surface !pto.vmi.mask remains. +3. Every VMI function argument, result, block argument, branch operand, call operand, and return operand has the + layout-assigned type selected by the solved equivalence class. +4. Every consumer-specific mismatch is represented by an explicit pto.vmi.ensure_* op immediately before that + consumer. Optional optimization passes may later replace selected helpers with rematerialized cheap producers. +5. External declarations with VMI types are rejected; they are not rewritten into an implicit ABI. +``` + +#### OneToN Conversion Details + +`vmi-to-vpto` should use MLIR `OneToNTypeConversion` for all structural rewriting that involves VMI values: + +```text +OneToNTypeConverter: + !pto.vmi.vreg -> !pto.vreg... + !pto.vmi.mask -> !pto.mask... 
+ +Patterns: + framework structural OneToN patterns for func/return/scf + explicit OneToNOpConversionPattern for each pto.vmi semantic op + explicit helper patterns for pack/unpack/ensure_* + +Final gate: + reject residual pto.vmi.*, !pto.vmi.*, function signatures containing !pto.vmi.*, and unrealized_conversion_cast +``` + +The implementation is an `OperationPass` with this shape: + +```cpp +struct VMIToVPTOTypeConverter final : OneToNTypeConverter { + VMIToVPTOTypeConverter() { + addConversion([](Type t) { return t; }); + addConversion(convertVMIVRegType); + addConversion(convertVMIMaskType); + + TypeConverter::addSourceMaterialization(materializeVPTOToVMI); + TypeConverter::addArgumentMaterialization(materializeVPTOToVMI); + OneToNTypeConverter::addTargetMaterialization(materializeVMIToVPTO); + } +}; + +void runOnOperation() override { + ModuleOp module = getOperation(); + if (failed(verifyVMIToVPTOInputIR(module)) || + failed(verifySupportedVMIToVPTOOps(module))) + return signalPassFailure(); + + VMIToVPTOTypeConverter typeConverter; + RewritePatternSet patterns(module.getContext()); + populateVMIOneToNConversionPatterns(typeConverter, patterns); + + if (failed(applyPartialOneToNConversion(module, typeConverter, + std::move(patterns))) || + failed(verifyNoResidualVMIIR(module))) + signalPassFailure(); +} +``` + +The type converter must define one canonical physical ordering and every pattern must use that ordering: + +```text +!pto.vmi.vreg + -> chunks in logical order: + chunk0 lanes [0..P-1], chunk1 lanes [P..2P-1], ... + +!pto.vmi.vreg + -> part-major chunks: + part0 chunk0 lanes [0,2,4,...] + part0 chunk1 next even lanes + part1 chunk0 lanes [1,3,5,...] + part1 chunk1 next odd lanes + +!pto.vmi.vreg + -> part-major chunks: + part0 lanes [0,4,8,...] + part1 lanes [1,5,9,...] + part2 lanes [2,6,10,...] + part3 lanes [3,7,11,...] 
+ +!pto.vmi.vreg + -> chunks in contiguous physical storage order + only derived group_slot(g) lanes contain semantic values + this layout is valid only for group reduce/broadcast exchange values + +!pto.vmi.mask + -> same part/chunk ordering as its data layout, one !pto.mask per physical part/chunk +``` + +`materializeVPTOToVMI` and `materializeVMIToVPTO` should use only `pto.vmi.pack` and `pto.vmi.unpack`. These ops are +conversion scaffolding; they are never valid final output. This makes accidental framework materialization visible in +the IR and easy to reject. + +Pattern population should be explicit: + +```cpp +void populateVMIOneToNConversionPatterns(VMIToVPTOTypeConverter &converter, + RewritePatternSet &patterns) { + populateFuncTypeConversionPatterns(converter, patterns); + scf::populateSCFStructuralOneToNTypeConversions(converter, patterns); + + patterns.add(converter, ctx); + + patterns.add(converter, ctx); + + patterns.add(converter, ctx); +} +``` + +Use upstream OneToN helpers where they exist: + +```text +func.func / func.return / func.call: + populateFuncTypeConversionPatterns + +scf.if / scf.for / scf.while and common structural SCF: + scf::populateSCFStructuralOneToNTypeConversions +``` + +Use project-local OneToN patterns where the current MLIR version does not provide a complete 1:N structural rewrite: + +```text +cf.br +cf.cond_br +cf.switch +scf.execute_region +scf.index_switch +``` + +These project-local structural patterns should not know VMI semantics. They only flatten operands/results according to +`OneToNTypeMapping`, convert successor block argument lists, and rebuild the same control-flow op. + +#### Pattern Authoring Checklist + +Every new `pto.vmi.*` lowering pattern should answer the same questions before it is added to +`populateVMIOneToNConversionPatterns`: + +```text +1. Does the op require all data operands/results to have identical physical arity? 
+ If yes, check every ValueRange size against the result mapping before emitting VPTO ops. + +2. Does the op consume a mask? + If yes, the mask must already have concrete granularity and the same physical ordering expected by the data + operand. The pattern must not reinterpret a pred mask by lane count alone. + +3. Does the op observe contiguous logical order outside the register file? + If yes, require contiguous layout or explicitly lower the ensure_layout/materialization before using load/store + style VPTO ops. + +4. Does the op have padding lanes? + If yes, prove padding is unobservable. For load-like ops this requires a full-read safety proof or a fallback. + For store-like ops this requires a true predicate that disables padding writes. + +5. Does the op have target-specific side effects or ordering, such as squeeze/compact/store coupling? + If yes, put that check in verifySupportedVMIToVPTOOps before conversion starts, so the pass fails before partial + rewriting. + +6. Can it create pto.vmi.pack/unpack or unrealized_conversion_cast through framework materialization? + If yes, the semantic pattern still may be correct, but final residual verification must reject any leftover helper. +``` + +This gives a concrete division of labor: + +```text +verifySupportedVMIToVPTOOps: + shape/target/path support checks that should fail before any rewrite. + +OneToNOpConversionPattern: + mechanical lowering for a preflight-approved case. + +verifyNoResidualVMIIR: + final hard gate for missed patterns, illegal materializations and hidden VMI type payloads. +``` + +Do not put target capability probing in a structural pattern. For example, a `cf.br` pattern must never ask whether +`deinterleaved=4` can be materialized. It only converts successor operands. The semantic op that created or consumes +the value is responsible for proving the VPTO lowering path. 
+ +#### Converter Use By Pass + +The implementation should be reviewable with the following rule: + +```text +pto-validate-vmi-ir: + no TypeConverter, no ConversionTarget, no rewrite. + +vmi-layout-assignment: + no TypeConverter for choosing layouts. + It may use RewriterBase after solving, but not DialectConversion as the solving model. + +vmi-to-vpto: + must use OneToNTypeConverter for VMI types. + must use OneToNOpConversionPattern for semantic VMI ops. + should use upstream func/scf OneToN helpers when available. + may add project-local structural OneToN patterns only for missing framework coverage. +``` + +The main reason is not style. It is correctness across values without defining ops: + +```mlir +^bb0(%x: !pto.vmi.vreg<128xf32, #pto.vmi.layout>): + cf.br ^bb1(%x : !pto.vmi.vreg<128xf32, #pto.vmi.layout>) + +^bb1(%y: !pto.vmi.vreg<128xf32, #pto.vmi.layout>): + %z = pto.vmi.addf %y, %y + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + ... +``` + +`%y` has no defining VMI op. Its physical values are the converted block arguments produced by OneToN block signature +conversion. Any implementation that tries to recover physical parts from a defining op is therefore incomplete for +control flow, function arguments and loop-carried values. + +When writing semantic `OneToNOpConversionPattern`, do not infer physical parts from a defining op. Use the OneToN +adaptor's per-original-operand `ValueRange`: + +```cpp +LogicalResult matchAndRewrite(VMIAddFOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange lhsParts = adaptor.getLhs(); + ValueRange rhsParts = adaptor.getRhs(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + ... 
+ rewriter.replaceOp(op, physicalResults, adaptor.getResultMapping()); +} +``` + +Every VMI semantic lowering then follows the same shape: + +```cpp +ValueRange lhsParts = adaptor.getLhs(); +ValueRange rhsParts = adaptor.getRhs(); +TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + +for each physical part index i: + emit physical VPTO op for lhsParts[i], rhsParts[i] -> resultTypes[i] + +replace op with all physical results using adaptor.getResultMapping() +``` + +This convention is mandatory for values crossing control flow. For example an `scf.for` iter arg has no defining op; +its physical parts are the converted block arguments created by OneToN signature conversion. + +The concrete pattern shape is: + +```cpp +LogicalResult matchAndRewrite(SourceOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange in0 = adaptor.getIn0(); + ValueRange in1 = adaptor.getIn1(); + TypeRange outTypes = adaptor.getResultMapping().getConvertedTypes(0); + + if (in0.size() != in1.size() || in0.size() != outTypes.size()) + return rewriter.notifyMatchFailure(op, "physical arity mismatch"); + + SmallVector results; + for (auto [i, outType] : llvm::enumerate(outTypes)) { + results.push_back(rewriter.create(op.getLoc(), outType, + in0[i], in1[i]).getResult()); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); +} +``` + +For non-VMI operands, use a helper like `getSingleValue(op, adaptor.getOffset(), "...")` and fail if the framework +unexpectedly expanded them. This catches malformed conversion rules early. 
+ +#### Semantic Lowering Buckets + +The first implementation should split VMI op lowering into six buckets: + +```text +identity/helper: + pack, unpack, ensure_layout identity/materialization cases, ensure_mask_* identity case + +per-part elementwise: + addf, addi, subf, subi, mulf, muli, divf, minf, maxf, negf, absf, absi, sqrt, exp, ln, relu, andi, ori, xori, shli, shrui, not, cmpf, cmpi, select + +per-part predicate: + mask_and, mask_or, mask_xor, mask_not + +layout-producing conversion: + extf, truncf, bitcast + +externally ordered memory: + load, store + +value-indexed accumulation: + dhist, chist +``` + +Per-part elementwise ops are straightforward only when all operands/results already share the same assigned layout: + +```text +logical deinterleaved=2 value: + part0 contains logical lanes 0, 2, 4, ... + part1 contains logical lanes 1, 3, 5, ... + +vmi.addf/subf/mulf on two such values: + emit the matching VPTO per-part op for part0_lhs, part0_rhs + emit the matching VPTO per-part op for part1_lhs, part1_rhs +``` + +This preserves logical lane semantics because each physical part contains the same logical lane subset for all +operands and the result. + +Memory ops are different because their observable semantics are contiguous logical order: + +```text +vmi.store of deinterleaved=2: + cannot blindly store part0 then part1 as the final memory order + must use a store plan that writes logical lane 0,1,2,3,... order + or materialize source to contiguous before physical store +``` + +Therefore `store` lowering must either: + +```text +1. consume contiguous layout directly, or +2. lower ensure_layout(deinterleaved -> contiguous), then store, or +3. 
use target store instructions whose dist mode proves contiguous external order +``` + +The first implementation uses option 2 for full physical chunks: + +```text +vmi.load: + emit contiguous physical vlds chunks in memory order + materialize contiguous -> assigned result layout + +vmi.masked_load: + only when the full physical read footprint is proven safe + emit contiguous physical vlds chunks in memory order + select loaded lanes against passthru with the VMI mask + if enable-stable-gather-masked-load is set, reject pto.vmi.masked_load with + a stable TODO diagnostic until the VGATHER2-based strict no-read path is + implemented + +vmi.store: + materialize assigned source layout -> contiguous + emit physical vsts chunks in memory order +``` + +Current direct memory lowering may only emit VPTO vector memory ops for +UB-backed memory. Concretely, a `!pto.ptr<..., ub>` is legal, a +`!pto.ptr<..., gm>` is not; a memref with `#pto.address_space<vec>` is legal, +and a memref without a memory-space attribute is treated as unknown/local to +this stage to preserve existing local-view tests. A memref explicitly marked +GM or another non-VEC space is rejected by `vmi-to-vpto`. + +GM-backed VMI memory is still a valid semantic source/sink before this pass, +but direct lowering does not perform GM<->UB movement. That must be represented +by an earlier/lower memory access plan, scratch materialization, or UB view +normalization before `vmi-to-vpto`; otherwise the diagnostic is +`VMI-UNSUPPORTED` and names the GM-backed source/destination. + +For `deinterleaved=2`, `vldsx2 DINTLV_B*` and `vstsx2 INTLV_B*` are valid optimization candidates because the ISA has +an explicit two-stream de/interleave memory distribution mode. 
This should be implemented only as a peephole inside +`vmi-to-vpto` after the generic plan is correct: + +```text +vmi.load result layout deinterleaved=2: + vldsx2 DINTLV_B* can directly produce part0/part1 chunks + +vmi.store source layout deinterleaved=2: + vstsx2 INTLV_B* can directly store part0/part1 chunks in logical memory order +``` + +Do not generalize this to `deinterleaved=4` unless the two-level dist composition is proven against the ISA. The +fallback for `deinterleaved=4` remains generic layout materialization plus ordinary memory ops. + +Direct `vmi.load` is lowered as full VPTO physical reads when the source memory kind/layout is supported and the +element type has a known physical lane width, even for non-full logical vectors. Masked/expand/gather read-style +operations still require the lowering to prove that the full physical read footprint is safe, or to use a future +true masked/non-faulting fallback. The current proof handles: + +```text +source is a statically shaped memref +offset is a constant non-negative index +offset + physical_arity(result) * lanes_per_physical_part <= static memref element count +``` + +When this proof holds, masked/expand read-style operations may still issue full `pto.vlds` chunks. The extra padding +lanes are not logical VMI lanes and must remain unobservable through later VMI materialization rules. Pointer sources, +dynamic offsets, dynamic memrefs, and insufficient static footprints remain unsupported for those stricter read-style +operations: + +```text +VMI-UNSUPPORTED: pto.vmi.<op> requires full physical chunks without padding lanes or a statically safe full-read +footprint (...; safe-read proof failed: ...) +VMI-UNSUPPORTED: pto.vmi.<op> ... (source is GM-backed, but current direct VMI-to-VPTO memory lowering emits +pto.vlds/pto.vsts and requires UB-backed memory) +``` + +Store-style ops are different because inactive lanes can be made write-free with true predicates. 
`vmi.store`, +`vmi.masked_store` therefore support the explicit contiguous/deinterleaved tail-store +materialization paths described below. + +## 2. Slice 0: Type / Attr Bootstrap + +第一步只实现 VMI type、layout attr 和纯 helper,不实现任何 conversion pass。 + +### 2.1 `#pto.vmi.layout` + +定义 `VMILayoutAttr`: + +```mlir +#pto.vmi.layout +#pto.vmi.layout +#pto.vmi.layout +``` + +建议内部参数: + +```text +kind: enum { contiguous, deinterleaved } +factor: int64_t +``` + +Verifier: + +```text +contiguous: + factor must be 1 + +deinterleaved: + factor must be 2 or 4 +``` + +禁止接受其它 spelling,例如 `stride2`、`stride4`、`parity`、`mod_split`、`blocked`。 + +### 2.2 `!pto.vmi.vreg` + +定义 `VMIVRegType`: + +```mlir +!pto.vmi.vreg<128xf32> +!pto.vmi.vreg<128xf32, #pto.vmi.layout> +!pto.vmi.vreg<128xf32, #pto.vmi.layout> +``` + +建议参数: + +```text +elementCount: int64_t +elementType: Type +layout: Attribute // null means surface type before layout assignment +``` + +Verifier: + +```text +elementCount > 0 +elementType is scalar-like integer / float / index supported by VMI +layout is null or VMILayoutAttr +deinterleaved=4 only allowed when target registry later supports it; type verifier only checks shape +``` + +不要要求 `elementCount * bitwidth(elementType)` 是 256B 整数倍。 + +### 2.3 `!pto.vmi.mask` + +定义 `VMIMaskType`: + +```mlir +!pto.vmi.mask<128xpred> +!pto.vmi.mask<128xb32, #pto.vmi.layout> +!pto.vmi.mask<128xb32, #pto.vmi.layout> +``` + +建议参数: + +```text +elementCount: int64_t +granularity: enum/string { pred, b8, b16, b32 } +layout: Attribute +``` + +Verifier: + +```text +elementCount > 0 +surface mask may use pred and no layout +layout-assigned mask must use b8/b16/b32 and must have VMILayoutAttr +pred mask must not carry layout +``` + +### 2.4 Lane Map Helper + +在 C++ 中提供纯函数 helper,供 verifier、layout assignment、VMI-to-VPTO 和测试共用: + +```text +getDataLanesPerPart(elementType) +getMaskLanesPerPart(granularity) +getVMIPhysicalArity(type) +mapLogicalLaneToPhysical(type, logicalLane) 
+mapPhysicalLaneToLogical(type, part, chunk, lane) +isPaddingLane(type, part, chunk, lane) +``` + +这些 helper 是 hard dependency。任何 pass 不能重新手写一套 arity 公式。 + +Slice 0 完成条件: + +```text +1. VMI type/attr 能 parse/print round-trip。 + Covered by vmi_type_attr_parse.pto. +2. 非法 layout factor、非法 mask granularity、非法 element count 有 verifier diagnostic。 + Covered by vmi_layout_factor_invalid.pto, + vmi_mask_granularity_invalid.pto, vmi_type_element_count_invalid.pto, + and vmi_mask_concrete_without_layout_invalid.pto / + vmi_mask_pred_with_layout_invalid.pto. +3. helper 单测或 lit 测试覆盖 contiguous/deinterleaved=2/deinterleaved=4 和非整 tile。 + Covered by vmi_to_vpto_type_only.pto and + vmi_to_vpto_type_arity.pto. +``` + +## 3. Slice 1: Minimal VMI Op Set + +不要一次实现 75 个 semantic op。第一批只实现能跑通 widening + elementwise + store 的闭环。 + +### 3.1 必选 semantic op + +Construction: + +```text +pto.vmi.constant +pto.vmi.broadcast +pto.vmi.iota +pto.vmi.create_mask +pto.vmi.constant_mask +``` + +`pto.vmi.from_elements` belongs to the eventual construction surface, but it is +not part of Slice 1. Do not synthesize it from ad hoc scalar lane inserts until +there is an explicit vreg immediate, scalar-insert, or scratch materialization +contract. + +Mask: + +```text +pto.vmi.mask_and +pto.vmi.mask_or +pto.vmi.mask_xor +pto.vmi.mask_not +``` + +Arithmetic / conversion: + +```text +pto.vmi.addf +pto.vmi.addi +pto.vmi.subf +pto.vmi.subi +pto.vmi.mulf +pto.vmi.muli +pto.vmi.fma +pto.vmi.divf +pto.vmi.minf +pto.vmi.maxf +pto.vmi.negf +pto.vmi.absf +pto.vmi.absi +pto.vmi.sqrt +pto.vmi.exp +pto.vmi.ln +pto.vmi.relu +pto.vmi.andi +pto.vmi.ori +pto.vmi.xori +pto.vmi.shli +pto.vmi.shrui +pto.vmi.not +pto.vmi.cmpf +pto.vmi.cmpi +pto.vmi.select +pto.vmi.extf +pto.vmi.truncf +pto.vmi.bitcast +``` + +`pto.vmi.shrui` represents logical right shift and lowers to `pto.vshr`. 
+`pto.vmi.shrsi` is intentionally not defined until VPTO exposes or documents +an arithmetic right-shift contract distinct from logical right shift. +Integer div/rem, integer casts, int-float casts, and index casts are also +intentionally outside the current VMI surface until signedness, rounding, +saturation, overflow/remainder, and target lowering contracts are explicit. + +Memory: + +```text +pto.vmi.load +pto.vmi.masked_load +pto.vmi.gather +pto.vmi.expand_load +pto.vmi.store +pto.vmi.masked_store +pto.vmi.scatter +pto.vmi.compress_store +``` + +Value-indexed accumulation: + +```text +pto.vmi.dhist +pto.vmi.chist +``` + +`pto.vmi.dhist` is a first-stage semantic op when histogram support is enabled. +`pto.vmi.chist` may share the surface verifier, but its final lowering must be +gated until the target CHISTv2 high-range cumulative semantics are verified. + +Current implementation scope note: + +```text +pto.vmi.gather / scatter +pto.vmi.active_prefix_index / compress / compress_store +future scan / contract style ops +``` + +These families are not first-stage completion blockers. The dialect surface may +define them, and the lowering may keep narrow direct paths when the target VPTO +contract is already explicit. Full semantic coverage for these families remains +out of scope until cross-chunk state, duplicate-index ordering, prefix carry, +compaction state, or contraction accumulation contracts are explicitly designed. +Unsupported shapes must fail before OneToN rewrite with `VMI-UNSUPPORTED`; they +must not fall through to residual-op diagnostics. 
+ +Permutation: + +```text +pto.vmi.shuffle +pto.vmi.channel_split +pto.vmi.channel_merge +``` + +Internal helper: + +```text +pto.vmi.ensure_layout +pto.vmi.ensure_mask_layout +pto.vmi.ensure_mask_granularity +pto.vmi.unpack +pto.vmi.pack +``` + +### 3.2 Op Verifier Rules + +Construction op verifier: + +```text +constant value must be a dense elements attr, and its element type/count must match the result vreg +broadcast scalar type must match the result element type +constant_mask value must be a dense elements attr, must have i1 element type, and its element count must match the +result mask +create_mask may produce surface pred mask or concrete layout-assigned mask +mask_and/mask_or/mask_xor/mask_not require all mask operands/results to have the same logical lane count; if any +mask is layout-assigned, all masks must carry the same layout and granularity +``` + +Elementwise op verifier: + +```text +all data operands have same logical lane count +all data operands have same element type except documented conversion op +if any operand has layout, all layouted operands/results must agree +surface op may have no layout before vmi-layout-assignment +``` + +`select` verifier: + +```text +mask lane count == true/false/result lane count +mask layout must match data layout after layout assignment +mask granularity must match selected element width after layout assignment +``` + +`extf/truncf` verifier: + +```text +source/result lane count equal +source/result element types are float +bitwidth changes in the expected direction +truncf rounding attr, when present, must be A/H and currently only applies to + f32 -> !pto.hif8 +``` + +Memory op verifier: + +```text +load memory element type must match result VMI data element type when the source is PtrType or MemRefType +store memory element type must match stored VMI data element type when the destination is PtrType or MemRefType +``` + +Histogram op verifier: + +```text +dhist/chist acc type must be !pto.vmi.vreg<256xui16> 
+dhist/chist result type must match acc type +source type must be !pto.vmi.vreg +mask logical lane count must match source logical lane count +surface mask may be pred; after layout assignment it must be b8 contiguous +source/result/acc must not carry layout before vmi-layout-assignment +layout-assigned dhist/chist requires contiguous source, mask, acc, and result +``` + +`shuffle` verifier: + +```text +static mask length == result lane count +each mask index selects an existing source logical lane +result element type == source element type +no padding lane may be selected +``` + +`channel_split` verifier: + +```text +result count C >= 2 +input lane count N == C * M +each result is vreg +channel c result semantics: out[c][i] = input[i * C + c] +if any source/result carries layout, all must carry layout +for C=2/4, layout-assigned source must be contiguous or deinterleaved=C +layout-assigned results must be contiguous +``` + +`channel_merge` verifier: + +```text +operand count C >= 2 +all operands have same M and element type T +result is vreg +result semantics: result[i * C + c] = input[c][i] +if any input/result carries layout, all must carry layout +layout-assigned inputs must be contiguous +for C=2/4, layout-assigned result must be contiguous or deinterleaved=C +``` + +`ensure_layout` verifier: + +```text +source/result are both VMIVRegType +same elementCount and elementType +source/result both layout-assigned +source layout may equal result layout; that is a canonical no-op +``` + +`ensure_mask_layout` verifier is identical except it uses `VMIMaskType` and preserves granularity. 
+ +`ensure_mask_granularity` verifier: + +```text +source/result are both VMIMaskType +same elementCount +same layout +source/result granularity are b8/b16/b32 +logical predicate value must be preserved +``` + +`pack/unpack` verifier: + +```text +VMI side must be layout-assigned +physical operand/result count == getVMIPhysicalArity(VMI type) +physical data types are !pto.vreg +physical mask types are !pto.mask +ordering is the shared Physical Arity helper order +``` + +Slice 1 完成条件: + +```text +1. Every Slice 1 op parses, prints, and has negative verifier tests. + Arithmetic/mask/helper verifier coverage includes vmi_elementwise_kind_invalid.pto, + vmi_mask_logic_invalid.pto, vmi_ensure_layout_surface_invalid.pto, + vmi_unpack_arity_invalid.pto, and vmi_pack_arity_invalid.pto. +2. Helper ops are marked internal in docs and rejected by final VMI-to-VPTO gate if residual. +3. `channel_split/channel_merge` have tests proving shuffle-equivalent lane order. +``` + +## 4. Slice 2: VMI Producer Boundary Verifier + +VMI core implementation starts from VMI IR. Producer-specific import is outside this manual's core path. + +实现 `PTOValidateVMIIR.cpp` 中的 VMI boundary verifier: + +```text +recommended pass name: pto-validate-vmi-ir +anchor: func::FuncOp or ModuleOp +source file: lib/PTO/Transforms/PTOValidateVMIIR.cpp +``` + +Boundary verifier checks: + +```text +all logical vector values use !pto.vmi.vreg / !pto.vmi.mask +all logical vector behavior is represented by pto.vmi semantic ops +surface VMI values before layout assignment do not carry layout +no physical VPTO op appears before vmi-to-vpto +no hidden side table is required to interpret VMI values +scalar/tensor/debug/transform boundary has already been resolved by producer +``` + +Slice 2 完成条件: + +```text +1. VMI-native positive tests pass boundary verification. + Covered by vmi_producer_boundary_valid.pto. +2. Physical VPTO op before VMI-to-VPTO is rejected. 
+ Covered by vmi_producer_boundary_physical_invalid.pto, including both + physical function types and physical VPTO ops. +3. Layout-assigned type before layout assignment is rejected unless the test explicitly starts after layout assignment. + Covered by vmi_producer_boundary_layout_invalid.pto and + vmi_producer_boundary_mask_layout_invalid.pto. +4. Missing VMI type/op invariants produce `VMI-PASS-INVARIANT` or a more specific diagnostic. + Covered by vmi_producer_boundary_non_vmi_op_invalid.pto, + vmi_producer_boundary_helper_invalid.pto, and the producer-boundary + TypeAttr nested/surface/layout invalid tests. +``` + +## 5. Slice 3: `vmi-layout-assignment` + +推荐实现为 pass: + +```text +recommended pass name: vmi-layout-assignment +anchor: ModuleOp +source file: lib/PTO/Transforms/VMILayoutAssignment.cpp +``` + +`vmi-layout-assignment` 必须是 module 级 pass。函数参数、`func.return` operand、 +`func.call` operand/result 和 callee signature 需要在同一个约束图里求解;函数级 pass +只能看到局部 body,无法安全地同步 callsite 和 callee。 + +### 5.1 Internal Data Model + +Build one layout node per VMI SSA value: + +```text +Operation result +BlockArgument +Region yield operand +Function argument/result +Call operand/result +``` + +Each node records: + +```text +logical type: VMIVRegType or VMIMaskType +allowed layouts: bitset {contiguous, deinterleaved2, deinterleaved4} +required mask granularity: pred/b8/b16/b32 or unknown +natural layout preference +hard constraints +``` + +No information required by later passes may live only in this data structure. After the pass, type/attr/op +operands must fully describe the result. 
+ +### 5.2 Transfer Functions + +Minimum Slice 3 transfer functions: + +```text +constant/broadcast/create_mask/constant_mask: + rematerializable in any legal consumer layout + +mask_and/mask_or/mask_xor/mask_not: + all mask operands/results same layout and granularity + +addf/addi/subf/subi/mulf/muli/divf/minf/maxf/negf/absf/absi/sqrt/exp/ln/relu/andi/ori/xori/shli/shrui/not/cmpf/cmpi/select: + all data operands/results same layout + mask layout follows data layout + +extf f16 -> f32: + result natural layout = deinterleaved=2 + source requires contiguous layout for the direct vcvt part=EVEN/ODD path + partial/tail source chunks are supported when they still fit in one physical + source chunk and produce the natural two-part result; source padding lanes map + only to result padding lanes + +extf f8 -> f32: + result natural layout = deinterleaved=4 + source requires contiguous layout for the direct vcvt part=P0/P1/P2/P3 path + partial/tail source chunks are supported under the same one-source-chunk + contract; source padding lanes map only to result padding lanes + +truncf f32 -> f16: + can consume deinterleaved=2 and produce contiguous + current implementation records a deinterleaved=2 source use-site request and + inserts pto.vmi.ensure_layout when the source value solved to contiguous. + partial/tail source pairs are supported when the two deinterleaved source + parts pack into one contiguous result chunk; source padding lanes map only to + result padding lanes + +truncf f32 -> fp8-like: + can consume deinterleaved=4 and produce contiguous + current implementation records a deinterleaved=4 source use-site request and + inserts pto.vmi.ensure_layout when the source value solved to contiguous. + The lowering emits four pto.vcvt operations with part=P0/P1/P2/P3, then ORs + the mutually exclusive partial destination registers into one contiguous fp8 + result. 
This mirrors the hardware packed-4 contract: each source part owns + one quarter of the destination byte lanes, so the final externally visible + vector remains logical lane order 0..N-1 after the merge. + default round mode is result-type specific: f8E4M3/f8E5M2 use rnd=R, hif8 + uses rnd=A. hif8 may explicitly request hybrid lowering with + pto.vmi.truncf {rounding = "H"}, which forwards rnd=H to every packed part. + +bitcast: + source and result layouts must match + source/result total logical bits must match + current implementation supports contiguous/deinterleaved layouts with identical + physical arity when every source/result physical chunk carries the same number + of logical bits. This covers full chunks and partial/tail chunks such as + 65xf32 -> 130xi16, where the second physical chunk carries 32 logical bits on + both sides, and uneven deinterleaved tails such as 129xf32 -> 129xi32. + Partial/tail bitcast remains unsupported if source padding bits would become + result logical bits. group_slots bitcast follows the same rule: it is valid + only when the source/result group_slots layout is identical and every + physical group-slot chunk carries the same logical bit footprint. + +load: + baseline result layout is deterministic from explicit layout attrs or the + producer natural layout; consumer-specific alternatives are represented by + ensure_layout and optimized later + +store: + baseline requests contiguous source layout + current implementation records a contiguous use-site request for vmi.store and + inserts pto.vmi.ensure_layout when the stored value class solved to a + non-contiguous layout. This makes externally visible memory order explicit in + IR before vmi-to-vpto. If explicit IR reaches vmi-to-vpto with a + deinterleaved=2/4 tail value, the direct lowering may still materialize it to + contiguous physical chunks first, but only when every deinterleaved part has + the same physical chunk count and therefore forms complete intlv groups. 
+ +shuffle/channel_split/channel_merge: + default result layout contiguous unless the current op explicitly carries a + supported layout-preserving contract + current implementation supports pto.vmi.shuffle when every result physical + chunk forwards one source physical chunk with identical lane positions for + all non-padding result lanes. Result padding lanes are ignored by the + forwarding proof and remain unobservable after physicalization. This allows + whole-chunk projection/reordering under contiguous or explicit deinterleaved + layouts, including tail-prefix projections such as `[0, 1, 2, 3] -> + !pto.vmi.vreg<4xf32>`. Arbitrary lane permutation remains unsupported unless + the vselr index-vector path below can materialize it. + current implementation supports channel_split/channel_merge for 2 or 4 + channels. channel_split consumes a natural deinterleaved=C source and produces + contiguous per-channel results; channel_merge consumes contiguous per-channel + inputs and produces a natural deinterleaved=C result. The direct path also + accepts partial/tail channel groups when the virtual deinterleaved=C channel + layout has the same physical arity as the source/result representation, so + every physical group can be materialized with complete intlv/dintlv pairs. + Arity-changing partial groups such as splitting 4xf32 into two 2xf32 channels + remain unsupported. If a producer/consumer + requires dense contiguous layout, pto.vmi.ensure_layout materializes the + pto.vdintlv/pto.vintlv tree explicitly. Non-matching layouts and other channel + counts remain unsupported. +``` + +### 5.3 Solver Order + +Implement deterministic solving: + +```text +1. Collect region/SCC constraints, including scf/cf/function/call boundaries. +2. Propagate impossible layouts and required mask granularities. +3. 
Pick one layout per node using deterministic priority, not a cost model: + explicit layout already present on the VMI type, then unique natural layout, + then hard non-contiguous request, then contiguous. +4. Rewrite result/block/function types to layout-assigned VMI types. +5. Insert ensure_layout / ensure_mask_layout / ensure_mask_granularity at uses that need conversion. +6. Run verifier gate. +``` + +Current implementation status: + +```text +implemented: + extf source -> contiguous use-site request for supported f16/fp8-like to f32 paths + truncf f32->f16 source -> deinterleaved=2 use-site request + truncf f32->fp8-like source -> deinterleaved=4 use-site request + single-use pto.vmi.load results can adopt a consumer-requested + layout before type rewrite; this covers direct memory producers such as + load -> truncf without inserting a redundant ensure_layout + vmi.store data operand -> contiguous use-site request + explicit VMI vreg layout is preserved as an initial solver constraint + explicit concrete VMI mask layout/granularity is preserved as an initial solver constraint + channel_split source -> deinterleaved=C use-site request + channel_split results -> contiguous natural layout + channel_merge inputs -> contiguous use-site request + channel_merge result -> deinterleaved=C natural layout + shuffle without explicit layouts -> contiguous source use-site request and contiguous result natural layout + shuffle with explicit source/result layouts -> preserve explicit layouts and let vmi-to-vpto prove chunk forwarding + pto.vmi.ensure_layout insertion for non-contiguous store operands + pto.vmi.ensure_layout insertion for truncf source materialization + pto.vmi.ensure_mask_layout / ensure_mask_granularity insertion for select mask operands + pto.vmi.create_mask / constant_mask rematerialization for select mask operands when the consumer needs a + different mask layout/granularity + splat pto.vmi.constant rematerialization for data operands when the consumer 
needs + a different layout + pto.vmi.broadcast rematerialization for data operands when the consumer needs + a different layout + scf.execute_region result/yield layout equivalence + scf.index_switch result/yield layout equivalence + scf.while state layout equivalence + +not yet implemented: + generic per-consumer layout request table for every VMI op + producer rematerialization for non-splat data constants and other cheap producers + cost model / target capability registry +``` + +Do not implement a local greedy pattern pass that ignores block arguments or function signatures. + +### 5.4 CFG Rules + +CFG 处理分两层。第一层是必须做的 layout equivalence:同一个控制流值在 +result、yield、region/block argument 之间必须形成同一个 layout/mask 约束组。第二层才是 +layout conflict resolution:当同一个 producer 的不同 consumers 希望不同 layout 时,插入 +`ensure_layout` 或 `ensure_mask_layout`。后续 `vmi-layout-rematerialize` 可以把部分 helper +替换成重放的纯构造 producer。 + +当前可落地的最小实现先做第一层。它不尝试在 branch 边界自动插入 conversion,因此下面这些 +关系一旦因为 natural layout 或 mask granularity 冲突无法合并,必须报 `VMI-LAYOUT-CONTRACT`, +不能默默选择某一边。 + +`scf.if` equivalence: + +```text +for each result index i: + scf.if result[i] + == then scf.yield operand[i] + == else scf.yield operand[i] +``` + +如果 value 是 `!pto.vmi.vreg`,合并 data layout 约束;如果 value 是 +`!pto.vmi.mask`,合并 mask layout 和 granularity 请求。这样 `%m = scf.if ... 
-> +!pto.vmi.mask` 后被 `vmi.select` 消费时,select 对 `%m` 推出的 `b8/b16/b32 + layout` +会传播回两边 yield 的 mask producer。 + +`scf.for` equivalence: + +```text +for each iter_arg index i: + init_arg[i] + == region_iter_arg[i] + == scf.yield operand[i] + == scf.for result[i] +``` + +这条规则避免 loop-carried value 每次迭代改变 layout。对于 `extf f16->f32` 作为 init、 +loop body 内部 `addf` 并 yield 的 case,`extf` 的 natural layout `deinterleaved=2` +必须稳定传递到 `%acc` region arg、`scf.yield` 和 loop result。 + +`cf.br` / `cf.cond_br` equivalence: + +```text +for each successor operand index i: + branch successor operand[i] + == successor block argument[i] +``` + +当前实现覆盖标准 `cf.br`、`cf.cond_br` 和 `cf.switch`。其中 `cf.switch` 的 default operands +与 default destination block arguments 按 index 建 layout 等价关系;每个 case operand segment +与对应 case destination block arguments 按 index 建 layout 等价关系。更泛化的 +`BranchOpInterface` op 如果携带 VMI type,后续要么补对应 mapping,要么在 layout assignment +阶段明确 diagnostic,不能让 hidden default layout 穿过去。 + +当前实现支持携带 VMI value 的 `scf.execute_region`:execute_region result 与直属 region terminator +`scf.yield` operands 按 result index 合并到同一个 layout 等价类。嵌套 region 内属于其他 op 的 +`scf.yield` 不参与 execute_region 的等价关系。 + +当前实现支持携带 VMI value 的 `scf.index_switch`:default/case region `scf.yield` operands 与 +index_switch results 按 result index 合并到同一个 layout 等价类。 + +当前实现支持携带 VMI value 的 `scf.while`:init operand、before region argument、`scf.condition` +forwarded operand、after region argument、after region `scf.yield` operand 和 while result 按状态 +index 合并到同一个 layout 等价类。`scf.condition` 的 i1 condition 本身不参与 VMI layout 约束。 + +Function boundary: + +```text +internal functions may get specialized layouted signatures +external ABI must not expose VMI layout +recursive SCC requires fixed-point signature layout +``` + +当前实现支持 direct `func.call` 到同一 module 内带 body 的 `func.func`: + +```text +call operand[i] == callee argument[i] +call result[i] == every callee return operand[i] +same-result-index return operands inside one callee are 
equivalent +``` + +如果携带 VMI type 的 call 无法解析到带 body 的 direct callee,layout assignment 必须报 +`VMI-LAYOUT-CONTRACT`。后续如需支持 public/external ABI,必须先定义 VMI 值如何在 ABI +边界 materialize,不能把 layouted VMI type 暴露出去。 +当前实现明确拒绝携带 VMI type 的 `func.call_indirect`,因为它没有可解析的 direct internal +callee signature/body 可参与 layout constraint solving。 + +当前实现对携带 VMI type 的 external function declaration 报 `VMI-LAYOUT-CONTRACT`,因为还没有 +定义 VMI value 的外部 ABI materialization plan。没有 VMI type 的 external declaration 必须在 +`rewriteFunctionType` 中保持原签名,不能因为没有 entry block arguments 被改写成空签名。 + +`ptoas --enable-vmi` 额外拒绝 public `func.func` 的 VMI-typed signature: + +```text +VMI-LAYOUT-CONTRACT: public VMI typed function requires an explicit external ABI materialization plan +``` + +这样 test-opt 仍可覆盖 internal/private function signature physicalization,用户入口则不会把 +layout-assigned VMI 值隐式暴露成 public ABI。 + +Slice 3 完成条件: + +```text +1. All VMI values have layout-assigned types after the pass. +2. All masks have b8/b16/b32 granularity after the pass. +3. CFG and call tests prove branch/yield/signature layout equality. +4. Multi-use rematerializable producer tests prove broadcast, constant, iota, + create_mask, and constant_mask rematerialization vs ensure_layout / + ensure_mask_* is deterministic. +5. The pass runs the layout-assigned VMI hard gate before returning, including + recursive TypeAttr/TypedAttr rejection; covered by + vmi_layout_assignment_post_gate_type_attr_invalid.pto. +``` + +## 6. 
Slice 4: `vmi-to-vpto` + +推荐实现为 pass: + +```text +recommended pass name: vmi-to-vpto +anchor: ModuleOp +source file: lib/PTO/Transforms/VMIToVPTO.cpp +``` + +第一步实现必须先落地 MLIR OneToN conversion 框架: + +```text +VMIToVPTOTypeConverter : OneToNTypeConverter: + !pto.vmi.vreg -> ordered !pto.vreg list + !pto.vmi.mask -> ordered !pto.mask list + +Structural patterns: + populateFuncTypeConversionPatterns + scf::populateSCFStructuralOneToNTypeConversions + project-local OneToN patterns for cf.br/cf.cond_br/cf.switch + project-local OneToN patterns for scf.execute_region/scf.index_switch + +VMI patterns: + OneToNOpConversionPattern for pack/unpack/ensure_*/semantic ops + +Final residual gate: + reject pto.vmi.*, !pto.vmi.*, unrealized_conversion_cast + scan SSA types, block argument types, function signatures, and op/module TypeAttr or TypedAttr payloads +``` + +这一步可以先支持 type-only physicalization 和 `pack/unpack` helper physicalization,但不能让未实现的 VMI semantic op 静默通过。 +如果还有 `pto.vmi.*` 或 VMI type 残留,必须报 `VMI-RESIDUAL-OP`。 + +当前 slice 支持 VMI function/input/block argument 展开成 physical arguments,并支持: + +```text +pto.vmi.unpack(layouted VMI aggregate) -> physical parts: + replace with OneToN adaptor source parts + +pto.vmi.pack(physical parts) -> layouted VMI aggregate: + replace with the physical parts through resultMapping + +pto.vmi.ensure_layout / ensure_mask_layout / ensure_mask_granularity: + ensure_layout must compare the original VMI source/result layout attrs, not only the converted physical type list. + If source/result layouts are identical, replace with source parts. This identity case supports partial/tail physical + chunks because no lane reordering or packing is performed. + If deinterleaved=2 -> contiguous, emit one pto.vintlv. + If contiguous -> deinterleaved=2, emit one pto.vdintlv. + If deinterleaved=4 -> contiguous, emit the two-level pto.vintlv tree. + If contiguous -> deinterleaved=4, emit the reverse two-level pto.vdintlv tree. 
+ ensure_mask_layout supports the same contiguous <-> deinterleaved=2/4 layout conversions with predicate + rearrange ops: + deinterleaved=2 -> contiguous: pto.pintlv_b8/b16/b32 + contiguous -> deinterleaved=2: pto.pdintlv_b8/b16/b32 + deinterleaved=4 -> contiguous: two-level pto.pintlv_b8/b16/b32 tree + contiguous -> deinterleaved=4: two-level pto.pdintlv_b8/b16/b32 tree + ensure_mask_granularity supports concrete b8/b16/b32 logical predicate-preserving conversion: + widening b8 -> b16 -> b32: split each physical chunk with pto.punpack LOWER/HIGHER + narrowing b32 -> b16 -> b8: pack physical chunk pairs with pto.ppack LOWER/HIGHER and merge halves with pto.por + b8 <-> b32 conversions are lowered as two adjacent steps through b16. + +pto.vmi.broadcast: + current direct lowering requires the physical result element width to be 8, + 16, or 32 bits, because the vdup is predicated by pto.mask. + Other semantic element types need a dedicated materialization contract before + vmi-to-vpto may lower them. + for each physical result part: + materialize pto.pset_b8/b16/b32 "PAT_ALL" from the physical result element width + emit pto.vdup(scalar, all_true_mask) + This is layout-independent because every logical lane has the same scalar value. A deinterleaved layout simply + receives one identical vdup per partition/chunk; no vintlv/vdintlv is needed. + +pto.vmi.iota: + semantics: + ASC: result[lane] = base + lane + DESC: result[lane] = base - lane + supported element types follow pto.vci: + integer 8/16/32 and f16/f32 + contiguous full-chunk direct path: + for each physical chunk c: + chunk_base = base +/- c * lanes_per_part + emit pto.vci chunk_base {order = ASC|DESC} + deinterleaved layout requires strided index materialization because physical part p contains logical lanes: + p, p + factor, p + 2 * factor, ... 
+ The required formula is: + ASC: base + p + factor * local_lane + DESC: base - p - factor * local_lane + The current lowering materializes this per physical chunk: + local = pto.vci 0 + scaled = pto.vmuls local, factor + ASC: result = pto.vadds scaled, base + part_offset + DESC: result = pto.vsub pto.vdup(base - part_offset), scaled + Partial/tail chunks are allowed. The physical padding lanes receive the natural continuation of the generated iota + sequence and remain padding/undef at the VMI semantic level; memory writes, masks, reductions, and other + externally-visible consumers must still obey the VMI padding rules. + +pto.vmi.constant_mask: + support dense bool constants for concrete b8/b16/b32 masks. For each physical chunk: + if the active lanes form a prefix: + emit pto.pset_b8/b16/b32 PAT_ALL, PAT_ALLF, or supported PAT_VL* + if a prefix count has no supported PAT_VL token, fall back to pto.plt_b8/b16/b32 with a constant i32 count + otherwise decompose the static bitset into active runs: + run [lo, hi) = prefix(hi) & ~prefix(lo) + combine runs with pto.por under an all-true predicate + pred-only masks remain unsupported until they have a concrete b8/b16/b32 consumer granularity. + +pto.vmi.mask_and / mask_or / mask_xor / mask_not: + for each physical predicate part: + materialize pto.pset_b8/b16/b32 "PAT_ALL" from the physical mask granularity + mask_and emits pto.pand(lhs_part, rhs_part, all_true_mask) + mask_or emits pto.por(lhs_part, rhs_part, all_true_mask) + mask_xor emits pto.pxor(lhs_part, rhs_part, all_true_mask) + mask_not emits pto.pnot(source_part, all_true_mask) + +pto.vmi.addf / addi / subf / subi / mulf / muli / divf / minf / maxf / negf / absf / absi / sqrt / exp / ln / relu / andi / ori / xori / shli / shrui / not: + current direct lowering requires the physical element width to be 8, 16, or + 32 bits, because every emitted VPTO op is predicated by a materialized + pto.mask. 
VMI types such as index or f64 remain valid semantic + surface types only after a dedicated lowering contract exists; until then + vmi-to-vpto must report VMI-UNSUPPORTED before OneToN conversion. + This common predicate-maskability rule is necessary but not sufficient for + every target op. Direct lowering must also preflight the concrete VPTO/VISA + element contract before OneToN rewriting: + addf/subf/mulf -> pto.vadd/vsub/vmul support f16/bf16/f32 floating types + divf -> pto.vdiv supports f16/f32 floating types + minf/maxf -> pto.vmin/vmax support f16/bf16/f32 floating types + negf/absf/sqrt/exp/ln/relu -> pto.vneg/vabs/vsqrt/vexp/vln/vrelu support f16/f32 floating types + absi -> pto.vabs supports signless/signed i8/i16/i32 integer types + bf16/f8 remain legal VMI float-like semantic types for the ops whose VMI + semantics allow them, but vmi-to-vpto must report VMI-UNSUPPORTED until a + materialization plan or wider target contract exists. + for each physical part: + materialize pto.pset_b8/b16/b32 "PAT_ALL" from the physical element width + addf/addi emit pto.vadd(lhs_part, rhs_part, all_true_mask) + subf/subi emit pto.vsub(lhs_part, rhs_part, all_true_mask) + mulf/muli emit pto.vmul(lhs_part, rhs_part, all_true_mask) + divf emits pto.vdiv(lhs_part, rhs_part, all_true_mask) + minf emits pto.vmin(lhs_part, rhs_part, all_true_mask) + maxf emits pto.vmax(lhs_part, rhs_part, all_true_mask) + negf emits pto.vneg(source_part, all_true_mask) + absf/absi emit pto.vabs(source_part, all_true_mask) + sqrt emits pto.vsqrt(source_part, all_true_mask) + exp emits pto.vexp(source_part, all_true_mask) + ln emits pto.vln(source_part, all_true_mask) + relu emits pto.vrelu(source_part, all_true_mask) + andi emits pto.vand(lhs_part, rhs_part, all_true_mask) + ori emits pto.vor(lhs_part, rhs_part, all_true_mask) + xori emits pto.vxor(lhs_part, rhs_part, all_true_mask) + shli emits pto.vshl(lhs_part, rhs_part, all_true_mask) + shrui emits pto.vshr(lhs_part, rhs_part, 
all_true_mask) + not emits pto.vnot(source_part, all_true_mask) + +pto.vmi.fma: + semantic: + result = fused_multiply_add(lhs, rhs, acc) + It must not be decomposed to pto.vmi.mulf + pto.vmi.addf because VPTO VMULA + may produce different floating-point results from separate multiply and add. + layout assignment: + lhs, rhs, acc, and result belong to one data layout equivalence class. + current direct lowering: + source/result element type must be f16, bf16, or f32 + for each physical part: + materialize pto.pset_b16/b32 "PAT_ALL" from the physical element width + emit pto.vmula(acc_part, lhs_part, rhs_part, all_true_mask) + The VMI operand order is lhs, rhs, acc; the VPTO operand order is acc, lhs, rhs. + +pto.vmi.cmpf / cmpi: + current direct lowering has the same 8/16/32-bit physical element-width + precondition as elementwise arithmetic, so the result predicate can be + materialized as b8/b16/b32. + target element contract: + cmpf: f16/bf16/f32, matching VISA VCMP floating-point element types + cmpi: signless/signed/unsigned i8/i16/i32, matching VISA VCMP integer element types + for each physical part: + materialize pto.pset_b8/b16/b32 "PAT_ALL" as the seed predicate + canonicalize predicate to VPTO cmp_mode eq/ne/lt/le/gt/ge + emit pto.vcmp(lhs_part, rhs_part, seed_mask, cmp_mode) + supported cmpf ordered aliases: + oeq -> eq + one -> ne + olt -> lt + ole -> le + ogt -> gt + oge -> ge + supported cmpi signed aliases: + slt -> lt + sle -> le + sgt -> gt + sge -> ge + unsupported floating-point predicates such as ord/uno/ult/ule/ugt/uge must emit VMI-UNSUPPORTED until NaN-aware + predicate construction is designed. + unsupported unsigned integer predicates ult/ule/ugt/uge must emit VMI-UNSUPPORTED until VPTO integer signedness + materialization is explicit. + +pto.vmi.active_prefix_index: + semantic: + idx[i] = popcount(mask[0 .. i)) + result element type must be signless i8/i16/i32, and concrete mask granularity must match the result element width. 
+ current direct lowering: + only contiguous layout + only one physical result/mask chunk + result and mask chunks must be full, with no padding logical lanes + materialize a zero vreg carrier with pto.vdup + emit pto.vusqz(carrier, mask) + unsupported cases: + partial/tail chunks because padding mask lanes could affect the observable prefix + multi-chunk contiguous values need cross-chunk prefix carry + deinterleaved layouts need logical-lane-order prefix reconstruction + all of these cases must report VMI-UNSUPPORTED before OneToN conversion + +pto.vmi.compress: + semantic: + keep source lanes whose mask lane is true and compact them in logical lane order; inactive tail lanes are zero/undef + at the VMI semantic level unless consumed by an operation that defines them. + current direct lowering: + source/result/mask must be contiguous + source/result/mask must each materialize to one physical chunk + source chunk must be full, with no padding logical lanes + emit pto.vsqz(source, mask) + unsupported cases: + partial/tail chunks because padding mask lanes could be squeezed into the observable result prefix + multi-chunk values need cross-chunk compaction and SQZN/carry planning + deinterleaved layouts need logical-lane-order compaction before physical part placement + compress_store is not implied by register compress; store-coupled VSQZ #st=1 and VSTUR require a separate + producer/consumer pairing plan + +pto.vmi.compress_store: + semantic: + store source lanes whose mask lane is true as a dense logical memory stream: + k = 0 + for lane in logical order: + if mask[lane]: + base[offset + k] = value[lane] + k += 1 + layout assignment: + value use is requested as contiguous + mask use is requested as contiguous with granularity derived from value element width + current direct lowering: + value and mask must be contiguous + value and mask must each materialize to one physical chunk + the value chunk must be full, with no padding logical lanes + destination must be a UB !pto.ptr
because pto.vstur is pointer-only and UB-only + lower as: + store_base = pto.addptr destination, offset + squeezed = pto.vsqz(value, mask) + align0 = pto.init_align + align1 = pto.vstur align0, squeezed, store_base, "POST_UPDATE" + pto.vstar align1, store_base + The pto.vstur user is the required consumer that lets the VPTO LLVM emitter + set VSQZ #st=1. A plain register pto.vsqz must not be assumed to enqueue + SQZN for store. + unsupported cases: + memref or GM destination until an explicit pointer/materialization plan exists + partial/tail physical chunks, because padding mask lanes could be squeezed into memory + multi-chunk values, because they need cross-chunk active-count compaction and SQZN/VSTUR state planning + deinterleaved layouts, because compaction must be in logical lane order + +pto.vmi.reduce_addi: + semantic: + acc = init[0] + for lane in logical order: + if mask[lane]: + acc = acc + source[lane] // integer wraparound addition + result[0] = acc + layout assignment: + source use is requested as contiguous + init use is requested as contiguous + result natural layout is contiguous + mask use is requested as contiguous with granularity derived from source element width + current direct lowering: + source element width must be 32 bits; narrower vcadd widens its result and needs a separate result type plan + source must materialize to one or more full physical chunks with no padding logical lanes + init/result must be 1-lane VMI vectors and each materialize to one physical chunk + mask must materialize to the same number of physical chunks as source + lower as: + first_lane = pto.pge_b32 "PAT_VL1" + acc = init + for each source_chunk, mask_chunk in physical order: + reduced = pto.vcadd(source_chunk, mask_chunk) + acc = pto.vadd(reduced, acc, first_lane) + result = acc + unsupported cases: + i8/i16 until widening result and init conversion are designed + partial/tail source chunks because padding lanes must not participate + floating-point add reduction 
without pto.vmi.reduce_addf {reassoc} + +pto.vmi.reduce_addf: + semantic: + requires {reassoc}; without it the verifier rejects the op + acc = init[0] + for lane in any reassociated tree over active logical lanes: + acc = acc + source[lane] + result[0] = acc + layout assignment: + source use is requested as contiguous + init use is requested as contiguous + result natural layout is contiguous + mask use is requested as contiguous with granularity derived from source element width + current direct lowering: + source element type must be f32 + source must materialize to one or more full physical chunks with no padding logical lanes + init/result must be 1-lane VMI vectors and each materialize to one physical chunk + mask must materialize to the same number of b32 physical chunks as source + lower as: + first_lane = pto.pge_b32 "PAT_VL1" + acc = init + for each source_chunk, mask_chunk in physical order: + reduced = pto.vcadd(source_chunk, mask_chunk) + acc = pto.vadd(reduced, acc, first_lane) + result = acc + unsupported cases: + missing reassoc attr + f16 until accumulator precision and rounding contract are designed + partial/tail source chunks because padding lanes must not participate + +pto.vmi.group_load / pto.vmi.group_store: + semantic: + num_groups is the only static grouping attribute. + N = logical lane count; G = num_groups; S = N / G. + group_load reads each logical group as one contiguous row: + result[g * S + i] = source[offset + g * row_stride + i] + for 0 <= g < G and 0 <= i < S + group_store writes the inverse row mapping: + destination[offset + g * row_stride + i] = value[g * S + i] + row_stride is an index operand, measured in elements, and may be dynamic. + Tail/valid-lane information is not an attr; it must be represented by a + mask in the producing/consuming computation. The current direct + group_load/group_store path is for full physical chunks. 
+ layout assignment: + group_load result natural layout is contiguous + group_store value use is requested as contiguous + current direct lowering: + source/value element width must be maskable by b8/b16/b32 + layout must be contiguous with full physical chunks + num_groups must evenly divide N, and the derived group size S must be a + multiple of the physical lanes + per part, so every physical chunk belongs to exactly one group + lower each physical chunk with pto.vlds/pto.vsts at: + offset + group * row_stride + chunk_in_group * lanes_per_part + unsupported cases: + derived group size splitting a physical chunk, because this needs partial-vreg + lane insertion/extraction or a gather/scatter plan + partial/tail physical chunks + GM-backed direct vector load/store paths not already accepted by the normal + VMI memory access plan + +pto.vmi.group_reduce_addf: + semantic: + requires {reassoc} + N = logical lane count; G = num_groups; S = N / G + L = physical lanes per 256B chunk for the element type. + The result carries #pto.vmi.layout, a group-slot + layout with K group slots per physical chunk. It is not a dense vector layout: only slot lanes have + semantic values. Supported K values are: + K = 8 for VCGADD-style packed results, where group g is stored in + physical chunk floor(g / 8), lane g % 8. + K = 1 for row-local VCADD results, where group g is stored in physical + chunk g, lane 0. + for each group g: + result[group_slot(g)] = + reduce_add(source[g * S .. (g + 1) * S), mask in same range) + Non-slot lanes are not consumed by pto.vmi.group_broadcast. The current + direct lowering materializes them as zero where the hardware path does not + already define them. + The result remains a VMI vector with the same element type as the source, + but its logical lane count is G: one scalar result per group. Its layout + is an explicit group-slot layout that describes where those G scalars are + placed in physical registers.
+ layout assignment: + source use is requested as contiguous + result natural layout is #pto.vmi.layout + mask use is requested as contiguous with granularity derived from source + element width + current direct lowering: + source/result element type must be f32 + source and mask must have compatible full physical chunks. The result is + `GxT` group-slot data and may have different physical arity from the + source tile. + if S=8 for f32, lower each physical chunk with pto.vcgadd. This is the + hardware 32B VLane group reduction path for f32: each source chunk produces + eight 8-lane group sums in the low lanes of that physical chunk. The + lowering preserves this natural no-pack result. + Otherwise: + derived group size S must be a multiple of physical lanes per part + lower each source chunk with pto.vcadd, combine chunks in the same group + with pto.vadd under PAT_VL1, then place group g in the slot lane defined by + K. All other result chunks/lane values + are zero. + unsupported cases: + missing reassoc attr + f16 or integer group reductions until accumulator and result contracts are + designed + derived group size S that neither divides nor is a multiple of L + +pto.vmi.group_broadcast: + semantic: + source logical lane count is G; result logical lane count is N. + S = N / G. + source must carry #pto.vmi.layout. For each + group g, the source value is read from the slot lane defined by K. The + result broadcasts it back to each logical group: + result[g * S + i] = source[group_slot(g)] + layout assignment: + source use is requested as #pto.vmi.layout + result is consumer-driven. If no consumer requests another layout, it + defaults to contiguous. 
+ current direct lowering: + source must carry #pto.vmi.layout with one + logical lane per group + result may be contiguous with full physical chunks + result may also be deinterleaved when S is large enough that every physical + result chunk stays inside one logical group, for example N=512, G=2, S=256, + L=64, deinterleaved=4. If the source is + #pto.vmi.layout, the source physical part is + selected by group id rather than by source chunk id. + derived group size S must divide or be a multiple of L for canonical + group-slot addressing + if result is contiguous and S < L, each physical chunk contains multiple group + slots. Lower by + creating an index vector [0...0, 1...1, ...] and applying pto.vselr to the + corresponding source chunk. + if S >= L and each result physical chunk belongs to one group, lower by + duplicating the first lane of that group's source chunk with pto.vdup LOWEST. + unsupported cases: + partial/tail physical chunks + derived group size S that neither divides nor is a multiple of L + deinterleaved small-group broadcast where one physical result chunk needs + values from multiple source chunks + +pto.vmi.reduce_maxf / pto.vmi.reduce_minf: + semantic: + acc = init[0] + for each active logical lane in logical lane order: + reduce_maxf: acc = max(acc, source[lane]) + reduce_minf: acc = min(acc, source[lane]) + result[0] = acc + inactive lanes inside each physical chunk follow VPTO identities: + reduce_maxf uses pto.vcmax, where inactive FP lanes behave as -INF + reduce_minf uses pto.vcmin, where inactive FP lanes behave as +INF + NaN and signed-zero behavior follows pto.vcmax/pto.vcmin for the chunk + reduction and pto.vmax/pto.vmin for serial chunk accumulation. The index + lane produced by pto.vcmax/pto.vcmin is ignored because VMI exposes only the + 1-lane value result. 
+ layout assignment: + source use is requested as contiguous + init use is requested as contiguous + result natural layout is contiguous + mask use is requested as contiguous with granularity derived from source element width + current direct lowering: + source element type must be f16 or f32 + source must materialize to one or more full physical chunks with no padding logical lanes + init/result must be 1-lane VMI vectors and each materialize to one physical chunk + mask must materialize to the same number of physical chunks as source + lower reduce_maxf as: + first_lane = pto.pge_b16/b32 "PAT_VL1" + acc = init + for each source_chunk, mask_chunk in physical order: + reduced = pto.vcmax(source_chunk, mask_chunk) + acc = pto.vmax(reduced, acc, first_lane) + result = acc + lower reduce_minf as: + first_lane = pto.pge_b16/b32 "PAT_VL1" + acc = init + for each source_chunk, mask_chunk in physical order: + reduced = pto.vcmin(source_chunk, mask_chunk) + acc = pto.vmin(reduced, acc, first_lane) + result = acc + unsupported cases: + bf16/fp8/f64 until VPTO reduction and combine semantics are designed + partial/tail source chunks because padding lanes must not participate + integer min/max until signed/unsigned and inactive identity contracts are explicit + +pto.vmi.select: + current direct lowering is a storage-width select rather than a semantic + arithmetic op: source/result physical elements must be b8/b16/b32-maskable, + but signedness and float-vs-integer interpretation are not inspected. 
+ for each physical part: + consume the corresponding physical predicate part + emit pto.vsel(true_part, false_part, predicate_part) + +pto.vmi.extf, direct path: + support 16-bit float-like contiguous source part -> f32 deinterleaved=2 result parts + materialize pto.pset_b16 "PAT_ALL" + emit pto.vcvt(source_part, mask, part=EVEN/ODD) + partial/tail is valid when the logical lanes fit in one physical source + part; PAT_ALL may convert padding lanes, but those lanes remain padding in + the deinterleaved result + support 8-bit contiguous source part -> f32 deinterleaved=4 result parts + materialize pto.pset_b8 "PAT_ALL" + emit pto.vcvt(source_part, mask, part=P0/P1/P2/P3) + the same padding rule applies + reject other extf width/layout shapes until their exact part plan is implemented + +pto.vmi.truncf, direct path: + support f32 deinterleaved=2 source parts -> 16-bit contiguous result part + materialize pto.pset_b32 "PAT_ALL" for the source conversion + emit pto.vcvt(even_f32_part, mask, rnd=R, sat=SAT, part=EVEN) + emit pto.vcvt(odd_f32_part, mask, rnd=R, sat=SAT, part=ODD) + materialize pto.pset_b16 "PAT_ALL" + merge mutually exclusive part results with pto.vor + partial/tail is valid when the two source parts pack into one physical + result part; converted padding lanes remain result padding + support f32 deinterleaved=4 source parts -> 8-bit contiguous result part + materialize pto.pset_b32 "PAT_ALL" for the source conversion + emit pto.vcvt(p0_f32_part, mask, rnd=<round>, sat=SAT, part=P0) + emit pto.vcvt(p1_f32_part, mask, rnd=<round>, sat=SAT, part=P1) + emit pto.vcvt(p2_f32_part, mask, rnd=<round>, sat=SAT, part=P2) + emit pto.vcvt(p3_f32_part, mask, rnd=<round>, sat=SAT, part=P3) + <round> is the result round: R for f8E4M3/f8E5M2, A for default hif8, or H for + hif8 truncf with {rounding = "H"} + materialize pto.pset_b8 "PAT_ALL" + merge mutually exclusive part results with pto.vor + partial/tail is valid when the four source parts pack into one physical + result part; converted padding lanes
remain result padding + reject other truncf width/layout shapes until their exact pack plan is implemented + +pto.vmi.bitcast: + for each physical part: + emit pto.vbitcast(source_part) -> result_part_type + source/result layouts must match, physical arity must match, and every + corresponding physical chunk must carry the same number of logical bits. + This includes contiguous, deinterleaved, and identical group_slots layouts. + Padding bits may map only to result padding bits; any shape where source + padding would become result logical data remains unsupported. + +pto.vmi.channel_split / pto.vmi.channel_merge: + support 2-way and 4-way channel transforms for contiguous per-channel values + and matching deinterleaved=C merged values. + + channel_split C=2: + if the source layout is already deinterleaved=2, forward physical chunks + directly to the two contiguous channel results. + if the source layout is contiguous, source logical vector must physicalize + as 2*N contiguous chunks. For each pair of dense chunks: + %ch0_i, %ch1_i = pto.vdintlv %dense_2i, %dense_2i_plus_1 + Results are returned in per-channel order: + channel0 chunks..., channel1 chunks... + + channel_split C=4: + if the source layout is already deinterleaved=4, forward physical chunks + directly to the four contiguous channel results. + if the source layout is contiguous, source logical vector must physicalize + as 4*N contiguous chunks. The lowering is the same two-level pto.vdintlv + tree used by contiguous -> deinterleaved=4 materialization, but the + partition-major output is interpreted as four separate contiguous channel + results. + + channel_merge C=2/C=4: + inputs are consumed as per-channel contiguous chunks. + If the result layout is deinterleaved=C, the physical chunks are forwarded + directly in partition-major order. + If the result layout is contiguous, the lowering uses the reverse + pto.vintlv tree and returns dense contiguous chunks for the merged result. 
+ + Unsupported: + channel counts other than 2 or 4 + non-matching channel input/result layouts + arity-changing or uneven partial physical channel groups that cannot form + complete intlv/dintlv groups + +pto.vmi.shuffle: + first try whole physical chunk forwarding cases: + source/result layouts are assigned + every non-padding lane in a result physical chunk maps to the same source physical chunk + source lane number equals result lane number inside the physical chunk + result padding lanes are ignored and remain semantically unobservable + + If forwarding fails, try vci-materializable vselr per physical chunk: + every result physical chunk has no padding lane + every lane in a result physical chunk maps to the same source physical chunk + source lane indices inside the chunk form one ASC or DESC consecutive sequence + materialize the index vector with pto.vci(base_lane, ASC|DESC) + emit pto.vselr(source_chunk, index_vector) + + Examples: + identity 128xf32 -> 128xf32: + indices = [0, 1, ..., 127] + forward dense chunks 0 and 1 + + second physical chunk 128xf32 -> 64xf32: + indices = [64, 65, ..., 127] + forward dense chunk 1 + + tail prefix 128xf32 -> 4xf32: + indices = [0, 1, 2, 3] + forward dense chunk 0 + lanes 4..63 of the physical result are padding lanes and are not part of + the logical vmi value + + chunk swap 128xf32 -> 128xf32: + indices = [64, 65, ..., 127, 0, 1, ..., 63] + forward dense chunks in order 1, 0 + + reverse one 64xf32 chunk: + indices = [63, 62, ..., 0] + index = pto.vci 63 {order = DESC} : i32 -> !pto.vreg<64xi32> + result = pto.vselr source_chunk, index + + Unsupported: + partial physical chunk projection whose observable result lanes are not + padding-safe forwarding, e.g. 
[1, 2, 3, 4] -> 4xf32 when it would require + shifting lanes rather than forwarding a whole physical chunk + broadcast, duplicate lanes, arbitrary non-affine permutation + current implementation emits VMI-UNSUPPORTED for these cases before + OneToN conversion, instead of leaving a generic residual VMI op. +``` + +`func.return` 携带 VMI operand 时必须通过 OneToN func/return structural pattern 展开成 physical +return operands。不能只取第一个 physical part;这种错误会导致函数类型已经返回两个 physical value, +但 `func.return` 只返回一个 value。 + +### 6.1 Type Conversion + +Use one shared physicalization helper: + +```text +VMIVRegType -> N physical !pto.vreg +VMIMaskType -> N physical !pto.mask +``` + +Physical result ordering must be: + +```text +contiguous: + chunk0, chunk1, ... + +deinterleaved=K: + p0_chunk0, p0_chunk1, ..., p1_chunk0, ..., p(K-1)_chunkN +``` + +### 6.2 Structural Conversion + +The pass must convert: + +```text +operation results +block arguments +branch operands +cf.br / cf.cond_br successor block signatures +scf.if results and yields +scf.for iter_args and yields +func arguments/results +call operands/results +return operands +cf.br / cf.cond_br / cf.switch block arguments and successor operands +scf.execute_region results and yields: + current implementation uses a project-local OneToN structural pattern. +scf.index_switch results and yields: + current implementation uses a project-local OneToN structural pattern. +``` + +Do not rely on a defining op to recover parts. Any VMI value may come from a block argument or function +argument, so `unpack` must be valid on arbitrary layout-assigned VMI SSA values before final lowering. 
+ +### 6.3 Op Lowering + +Internal helper lowering: + +```text +unpack: + replace with physical values in helper ordering + +pack: + materialize one logical VMI aggregate before it is immediately consumed by another VMI helper + must not remain after final gate + +ensure_layout: + preflight: + source/result must have computable physical arity + source/result physical arity must match + identity source/result layouts do not require full chunks + if source/result layouts differ, either: + every source/result physical chunk is full, with no padding lanes; or + source/result both have complete contiguous/deinterleaved=2/4 materialization groups and their materialized + physical arity still equals the original VMI physical arity + arity-changing partial/tail layout conversion remains unsupported because it would need an explicit padding + packing/drop plan + otherwise report VMI-UNSUPPORTED before OneToN conversion + + compare the original VMI source/result layout attrs: + same layout: + forward the converted source parts + deinterleaved=2 -> contiguous: + %d0, %d1 = pto.vintlv %p0, %p1 + contiguous -> deinterleaved=2: + %p0, %p1 = pto.vdintlv %d0, %d1 + deinterleaved=4 -> contiguous: + %a0, %a1 = pto.vintlv %p0, %p2 + %b0, %b1 = pto.vintlv %p1, %p3 + %d0, %d1 = pto.vintlv %a0, %b0 + %d2, %d3 = pto.vintlv %a1, %b1 + contiguous -> deinterleaved=4: + %a0, %b0 = pto.vdintlv %d0, %d1 + %a1, %b1 = pto.vdintlv %d2, %d3 + %p0, %p2 = pto.vdintlv %a0, %a1 + %p1, %p3 = pto.vdintlv %b0, %b1 + + It is a bug to treat layout conversion as identity merely because both sides convert to the same + number of physical !pto.vreg values with the same type. For example: + !pto.vmi.vreg<128xf32, deinterleaved=2> + !pto.vmi.vreg<128xf32, contiguous> + both physicalize to two !pto.vreg<64xf32> values, but their logical lane order differs. 
+ +ensure_mask_layout: + preflight: + source/result must have computable physical arity + source/result physical arity must match + if source/result layouts differ, every source/result physical predicate chunk must be full, with no padding lanes + identity source/result layouts do not require full chunks + otherwise report VMI-UNSUPPORTED before OneToN conversion + + same-layout: + forward source parts + deinterleaved=2 -> contiguous: + use pto.pintlv_b8/b16/b32 on each partition pair + contiguous -> deinterleaved=2: + use pto.pdintlv_b8/b16/b32 on each dense pair + deinterleaved=4 -> contiguous: + use the same two-level tree as data layout conversion, replacing pto.vintlv with pto.pintlv_b8/b16/b32 + contiguous -> deinterleaved=4: + use the reverse two-level tree, replacing pto.vdintlv with pto.pdintlv_b8/b16/b32 + source/result granularity must be identical; granularity conversion belongs to ensure_mask_granularity. + +ensure_mask_granularity: + source/result layout and logical lane count must match. + source/result granularity must be concrete b8/b16/b32. + identity conversion forwards physical parts. + widening conversion: + b8 -> b16 or b16 -> b32 uses pto.punpack LOWER/HIGHER for each source physical chunk. + each source physical mask chunk can produce up to two result chunks in logical order. + narrowing conversion: + b32 -> b16 or b16 -> b8 uses pto.ppack LOWER for the low source chunk. + if a high source chunk exists, use pto.ppack HIGHER and merge the two partial masks with pto.por under PAT_ALL. + this handles odd tail groups because the missing high half is padding and remains zero. + multi-step conversion: + b8 -> b32 is b8 -> b16 -> b32. + b32 -> b8 is b32 -> b16 -> b8. 
+``` + +Elementwise lowering: + +```text +for each physical part: + lower add/cmp/select to corresponding VPTO op sequence + preserve source/result physical ordering + cmp predicates must be canonicalized before creating pto.vcmp: + eq/ne/lt/le/gt/ge pass through + ordered FP aliases oeq/one/olt/ole/ogt/oge map to eq/ne/lt/le/gt/ge + signed integer aliases slt/sle/sgt/sge map to lt/le/gt/ge + unordered/NaN-sensitive FP predicates are unsupported until represented explicitly + unsigned integer predicates are unsupported until signedness is represented explicitly +``` + +Producer lowering: + +```text +broadcast: + TypeConverter gives the ordered result physical types. + For each result physical vreg: + create all-true mask with the vreg element width + emit pto.vdup scalar -> that physical vreg + + This is valid for contiguous and deinterleaved layouts because splat has no lane-order dependence. + +constant: + Splat dense constants use the same path as broadcast: + create scalar arith.constant from the splat attribute + emit pto.vdup per physical result part + require the same 8/16/32-bit physical result element-width precondition as + broadcast + Non-splat dense constants need an explicit constant materialization strategy or must remain unsupported with a + precise diagnostic; do not synthesize an arbitrary lane sequence by scalar inserts unless that path is designed. 
+ +create_mask / constant_mask: + constant active_lanes create_mask lowers per physical mask part: + clamp active_lanes to [0, logical lane count] + compute active prefix count for each physical mask chunk with the VMI lane-map helper + emit pto.pge_b8/b16/b32 PAT_ALL, PAT_ALLF, or supported PAT_VL* + if a chunk prefix count has no supported PAT_VL token, fall back to pto.plt_b8/b16/b32 with a constant i32 count + Dynamic active_lanes with contiguous layout lowers by chaining pto.plt_b8/b16/b32 over the physical chunks: + active_i32 = arith.index_cast active_lanes : index to i32 + active_i32 = minui(maxsi(active_i32, 0), logical_lane_count) + mask0, remaining0 = pto.plt_b* active_i32 + mask1, remaining1 = pto.plt_b* remaining0 + ... + Dynamic active_lanes with deinterleaved layout remaps one logical prefix into per-part dynamic lane counts before + chaining pto.plt_b*: + active_i32 = minui(maxsi(index_cast(active_lanes), 0), logical_lane_count) + part_count(part) = (active_i32 + factor - 1 - part) / factor + then chain pto.plt_b* independently for each partition in VMI physical order: + p0 chunks..., p1 chunks..., ... + dense constant_mask lowers per physical mask part: + first map logical lanes to physical predicate lanes using the assigned VMI layout + prefix chunks emit pto.pset_b8/b16/b32 PAT_ALL, PAT_ALLF, or supported PAT_VL* + if a prefix count has no supported PAT_VL token, emit pto.plt_b8/b16/b32 with a constant i32 count + non-prefix chunks are decomposed into static active runs: + prefix(hi) = pto.pge/plt for the run end + prefix(lo) = pto.pge/plt for the run begin + run = prefix(hi) & ~prefix(lo) using pto.pnot + pto.pand + chunk = run0 | run1 | ... 
using pto.por + +Unsupported diagnostics: + unexpected residual dynamic pto.vmi.create_mask after OneToN conversion: + VMI-UNSUPPORTED: dynamic pto.vmi.create_mask active_lanes could not be lowered by the current runtime predicate + generation plan + This is a final-gate diagnostic for malformed or newly unsupported dynamic shapes. The supported dynamic + contiguous/deinterleaved=2/deinterleaved=4 paths above must lower before this residual gate. + + non-splat pto.vmi.constant: + VMI-UNSUPPORTED: non-splat pto.vmi.constant requires a vreg immediate or scratch materialization plan + + unsupported partial/tail masked/expand read-style op: + VMI-UNSUPPORTED: pto.vmi.<op> requires full physical chunks without padding lanes or a statically safe + full-read footprint (...; safe-read proof failed: ...) + GM-backed direct pto.vmi.load/masked_load/expand_load: + VMI-UNSUPPORTED: pto.vmi.<op> ... (source is GM-backed, but current direct VMI-to-VPTO memory lowering + emits pto.vlds/pto.vsts and requires UB-backed memory) + unsupported partial/tail pto.vmi.store/masked_store: + VMI-UNSUPPORTED: pto.vmi.<op> 
requires an 8/16/32-bit predicate-maskable element type and either full + physical chunks or contiguous/deinterleaved tail-store materialization, with UB-backed destination; unsupported + cases include values such as f64/index that have no b64 predicate representation, GM-backed destinations that + still need a memory movement/materialization plan, and uneven deinterleaved physical groups that cannot form + complete intlv groups + + unsupported non-identity partial/tail pto.vmi.ensure_layout: + VMI-UNSUPPORTED: pto.vmi.ensure_layout cannot materialize the requested data layout conversion; unsupported cases + include arity-changing partial/tail conversion and uneven deinterleaved groups that cannot form complete intlv + groups + If the helper has a single consumer, the main diagnostic is emitted on the + consumer op and operand, including both the actual operand VMI type and the + required VMI type. For example, pto.vmi.truncf operand #0 can report + `!pto.vmi.vreg<128xf32, contiguous>` vs. + `!pto.vmi.vreg<128xf32, deinterleaved=4>` for f32->fp8. The failed + pto.vmi.ensure_layout conversion is attached as a note. + + unsupported non-identity partial/tail pto.vmi.ensure_mask_layout: + VMI-UNSUPPORTED: pto.vmi.ensure_mask_layout cannot materialize the requested mask layout conversion; unsupported + cases include arity-changing partial/tail conversion and uneven deinterleaved groups that cannot form complete + predicate intlv groups + + unsupported pto.vmi.ensure_mask_granularity: + VMI-UNSUPPORTED: non-identity mask granularity materialization requires concrete b8/b16/b32 masks with matching + lane count and layout (...) 
+ + unsupported pto.vmi.extf direct path shape: + VMI-UNSUPPORTED: pto.vmi.extf supports only one contiguous 16-bit float-like or fp8-like physical source chunk to f32 + deinterleaved=2/4 results; partial/tail is allowed only when source padding maps to result padding + + unsupported pto.vmi.truncf direct path shape: + VMI-UNSUPPORTED: pto.vmi.truncf supports only f32 deinterleaved=2 source parts to one contiguous f16 result chunk + or f32 deinterleaved=4 source parts to one contiguous fp8-like result chunk + + unsupported pto.vmi.bitcast shape: + VMI-UNSUPPORTED: pto.vmi.bitcast requires matching source/result layouts with identical physical + arity and matching per-chunk logical bit footprints (...) + + unsupported pto.vmi.channel_split / pto.vmi.channel_merge channel count: + VMI-UNSUPPORTED: pto.vmi.channel_split supports only 2 or 4 channels + VMI-UNSUPPORTED: pto.vmi.channel_merge supports only 2 or 4 channels + unsupported pto.vmi.channel_split / pto.vmi.channel_merge layout: + VMI-UNSUPPORTED: pto.vmi.channel_split requires source layout to be contiguous or matching deinterleaved channel + layout, and every result layout to be contiguous + VMI-UNSUPPORTED: pto.vmi.channel_merge requires every input layout to be contiguous and result layout to be + contiguous or matching deinterleaved channel layout +``` + +Width conversion lowering: + +```text +f16 -> f32: + supported direct path when source is contiguous and result is deinterleaved=2: + pto.vcvt part=EVEN produces logical lanes 0,2,4,... + pto.vcvt part=ODD produces logical lanes 1,3,5,... 
+ source/result physical arity must be 1 -> 2 + +f8 -> f32: + supported direct path when source is contiguous and result is deinterleaved=4: + pto.vcvt part=P0/P1/P2/P3 produces the four modulo-4 lane partitions + source/result physical arity must be 1 -> 4 + +f32 -> f16: + supported direct path when source is deinterleaved=2 and result is contiguous: + pto.vcvt part=EVEN consumes even/source part 0 + pto.vcvt part=ODD consumes odd/source part 1 + pto.vor merges mutually exclusive f16 part results into one contiguous vreg + source/result physical arity must be 2 -> 1 + current default conversion attrs are rnd=R, sat=SAT + +f32 -> 8-bit fp-like: + supported direct path when source is deinterleaved=4 and result is contiguous: + pto.vcvt part=P0/P1/P2/P3 consumes the four source partitions + pto.vor merges mutually exclusive byte-lane part results into one + contiguous vreg + source/result physical arity must be 4 -> 1 + current default conversion attrs are rnd=R for f8E4M3/f8E5M2 and rnd=A for + hif8. pto.vmi.truncf {rounding = "H"} is accepted only for f32 -> hif8 + and forwards rnd=H to the emitted pto.vcvt operations. +``` + +Memory lowering: + +```text +vmi.load: + current direct memory path first reads contiguous physical chunks. The logical lane count must be an exact multiple + of the physical vreg lane count. + For each contiguous physical chunk i: + offset_i = base_offset + i * lanesPerPart + dense_i = pto.vlds base[offset_i] + + If the requested VMI result layout is contiguous, return the dense chunks directly. + If the requested VMI result layout is deinterleaved=2: + prefer pto.vldsx2 "DINTLV_B8/B16/B32" per physical chunk group: + %p0_i, %p1_i = pto.vldsx2 base[offset_i], "DINTLV_B*" + return results in VMI partition-major order: + p0_chunk0, p0_chunk1, ..., p1_chunk0, p1_chunk1, ... + If the requested VMI result layout is deinterleaved=4 with exactly four physical parts: + use dense pto.vlds chunks followed by the reverse two-level pto.vdintlv tree. 
+ + For larger multi-chunk deinterleaved=4 loads, apply the same conversion per contiguous chunk group and return + physical parts in VMI partition-major order: + deinterleaved=4: p0_chunks..., p1_chunks..., p2_chunks..., p3_chunks... + +vmi.store: + direct lowering requires value element width to be 8, 16, or 32 bits so the + emitted pto.vsts/pto.vstsx2 predicate can be materialized as b8/b16/b32. + contiguous layout with full physical chunks: + offset_i = base_offset + i * lanesPerPart + mask_i = pto.pset_b8/b16/b32 "PAT_ALL" + pto.vsts value_i, base[offset_i], mask_i + contiguous layout with a final partial physical chunk: + full chunks still use PAT_ALL + the final chunk computes valid_lanes = logical_lane_count - chunk_i * lanesPerPart + tail_mask_i = pto.plt_b8/b16/b32(valid_lanes) + pto.vsts tail_value_i, base[offset_i], tail_mask_i + padding lanes therefore have no externally visible store effect. + +deinterleaved store: + deinterleaved=2 with full physical chunks: + prefer pto.vstsx2 "INTLV_B8/B16/B32" per physical chunk group: + pto.vstsx2 p0_i, p1_i, base[offset_i], "INTLV_B*", all_true_mask + offset_i = base_offset + i * 2 * lanesPerPart + the vstsx2 dist mode writes logical lane 0,1,2,3,... order externally. + + current safe path lowers through proven register materialization before store: + deinterleaved=4 with exactly four physical parts: + use the two-level pto.vintlv tree, then store %d0/%d1/%d2/%d3 as contiguous chunks + + Larger multi-chunk deinterleaved=4 values use the same conversion per chunk group. The final store order is dense + chunk order, so external memory observes logical lane 0,1,2,... order. 
+ +vmi.masked_load: + semantics: + if mask[lane] is true, result[lane] = memory[base + lane] + if mask[lane] is false, result[lane] = passthru[lane] + inactive mask lanes do not by themselves permit unsafe memory reads + current direct path: + result, passthru, and mask are requested as contiguous + full physical chunks can always use pto.vlds because every loaded lane is logical + partial/tail chunks require the same statically safe full-read proof as vmi.load + for each contiguous physical chunk i: + loaded_i = pto.vlds base[offset_i] + result_i = pto.vsel loaded_i, passthru_i, mask_i + unsupported cases: + non-contiguous layouts + unsafe partial/tail read footprints + target true masked/non-faulting load and guarded/scratch fallback + +vmi.stride_load: + semantics: + result lane order is contiguous VMI logical order + source addresses are described by the VPTO block/repeat stride operands + mask false lanes are inactive for the underlying block-strided load + layout assignment: + result natural layout is contiguous + mask use is requested as contiguous with granularity derived from result element width + current direct path: + source must be !pto.ptr + result and mask must be one contiguous physical chunk + base = pto.addptr source, offset + result = pto.vsldb base, block_stride, repeat_stride, mask + unsupported cases: + multi-chunk result or mask + non-contiguous layouts + memref/gm source + +vmi.stride_store: + semantics: + value lane order is contiguous VMI logical order + destination addresses are described by the VPTO block/repeat stride operands + mask false lanes do not write memory + layout assignment: + value use is requested as contiguous + mask use is requested as contiguous with granularity derived from value element width + current direct path: + destination must be !pto.ptr + value and mask must be one contiguous physical chunk + base = pto.addptr destination, offset + updated_base = pto.vsstb value, base, block_stride, repeat_stride, mask + The 
updated base result is intentionally unused by VMI lowering, but the + post-update VPTO form matches CCE block-strided staging behavior. + unsupported cases: + multi-chunk value or mask + non-contiguous layouts + memref/gm destination + +vmi.gather: + semantics: + if mask[lane] is true, result[lane] = memory[base + indices[lane]] + if mask[lane] is false, result[lane] = passthru[lane] and no memory read occurs for that lane + indices are interpreted in element units, not bytes + layout assignment: + result natural layout is contiguous + indices and passthru uses are requested as contiguous + mask use is requested as contiguous with granularity derived from result element width + current direct path: + source must be !pto.ptr + supported 32-bit mode: + T must be a 32-bit element type + indices must be signless or unsigned i32 + result / indices / passthru / mask must be contiguous full physical chunks + mask granularity must be b32 + for each physical chunk i: + gathered_i = pto.vgather2_bc source, indices_i, mask_i + result_i = pto.vsel gathered_i, passthru_i, mask_i + supported ui16 mode: + T must be ui16 + indices must be unsigned i16 + result / indices / passthru / mask must be one contiguous physical chunk + mask granularity must be b16 + gathered = pto.vgather2 source, indices, mask + result = pto.vsel gathered, passthru, mask + VPTO LLVM emitter bitcasts the physical index register from <128xi16> + to the installed Bisheng intrinsic ABI <64xi32>; this is the same + 256B register payload viewed as the wrapper-level vector_u16 index + container. + reason for vsel: + VPTO gather false predicate lanes do not read memory but produce zero; VMI false lanes preserve passthru. 
+ unsupported cases: + f16/b16/f8/i8 result element types + partial/tail chunks + non-contiguous layouts + memref/gm source + guarded/scratch fallback + +vmi.scatter: + semantics: + if mask[lane] is true, memory[base + indices[lane]] = value[lane] + if mask[lane] is false, no memory write occurs for that lane + indices are interpreted in element units, not bytes + all active lanes must have pairwise-distinct indices; duplicate active indices violate the VMI scatter contract + layout assignment: + value and indices uses are requested as contiguous + mask use is requested as contiguous with granularity derived from value element width + current direct path: + destination must be !pto.ptr + T must be a 32-bit element type + indices must be signless or unsigned i32 + value / indices / mask must be contiguous full physical chunks + mask granularity must be b32 + for each physical chunk i: + pto.vscatter value_i, destination, indices_i, mask_i + unsupported cases: + f16/b16/f8/i8 value element types + partial/tail chunks + non-contiguous layouts + memref/gm destination + ordered duplicate-index fallback + +vmi.expand_load: + semantics: + k = 0 + for lane in logical order: + if mask[lane]: + result[lane] = memory[base + k] + k += 1 + else: + result[lane] = passthru[lane] + layout assignment: + result natural layout is contiguous + passthru use is requested as contiguous + mask use is requested as contiguous with granularity derived from result element width + current direct path: + static all-active path: + pto.vmi.create_mask with constant active_lanes >= logical lane count + dense all-true pto.vmi.constant_mask + in that case expand_load degenerates to ordinary vmi.load: + for each contiguous physical chunk i: + loaded_i = pto.vlds base[offset_i] + result_i = loaded_i + partial/tail chunks still require the same statically safe full-read proof as vmi.load. 
+ runtime-mask path: + source must be !pto.ptr + T must be a 32-bit element type + result / passthru / mask must be one contiguous full physical chunk + mask granularity must be b32 + base_i = pto.addptr source, offset + indices_i = pto.vusqz(zero_i32_carrier, mask_i) + loaded_i = pto.vgather2_bc base_i, indices_i, mask_i + result_i = pto.vsel loaded_i, passthru_i, mask_i + unsupported cases: + runtime masks across multiple physical chunks + runtime masks on non-32-bit element types + non-contiguous layouts + unsafe partial/tail read footprints + guarded load or scratch fallback + +vmi.masked_store: + semantics: + if mask[lane] is true, store value[lane] + if mask[lane] is false, no memory write occurs for that logical lane + current full-footprint path: + value and mask are requested as contiguous at the use site + mask granularity is derived from value element width + for each contiguous physical chunk i: + offset_i = base_offset + i * lanesPerPart + pto.vsts value_i, base[offset_i], mask_i + contiguous layout with a final partial physical chunk: + full chunks store with the user mask directly + the final chunk computes tail_valid_i with pto.plt_b8/b16/b32(valid_lanes) + store_mask_i = pto.pand user_mask_i, tail_valid_i, all_true_mask_i + pto.vsts tail_value_i, base[offset_i], store_mask_i + padding lanes and user-inactive lanes therefore both have no write effect. + If the incoming value/mask are deinterleaved, layout assignment inserts + ensure_layout/ensure_mask_layout or the vmi-to-vpto pattern materializes the same contiguous representation before + emitting stores. This preserves logical memory order and keeps inactive lanes write-free. + +non-full chunks: + vmi.store and vmi.masked_store support contiguous tail chunks by predicating the final pto.vsts with + a prefix valid mask. masked_store additionally ANDs the user mask with the tail-valid mask. 
+ deinterleaved=2/4 tail store/masked_store is supported only through explicit layout materialization to + contiguous chunks first. This requires every deinterleaved part to have the same physical chunk count, so the + materializer can build complete vintlv/pintlv groups. After materialization, each contiguous chunk is predicated by + the logical tail-valid mask; chunks whose active logical lane count is zero are not emitted as stores. Uneven + deinterleaved groups, such as 129xf32 with deinterleaved=2, remain unsupported until a padding/scratch plan can + assemble only the observable contiguous chunks. + vmi.load supports partial/tail chunks only when the direct full physical read is statically safe: + statically shaped memref source, constant non-negative offset, and enough elements for the + whole physical read footprint. Padding lanes must never become observable. Other partial/tail load cases still need + scratch/guarded/true-masked load planning. +``` + +Histogram lowering: + +```text +vmi.dhist semantics: + source lanes are ui8 samples + mask selects active source lanes + acc/result are complete logical 256-bin ui16 histograms + result[b] = acc[b] + count(active source lanes whose value equals b) + +layout assignment: + source layout = contiguous + mask layout = contiguous, granularity b8 + acc/result layout = contiguous !pto.vmi.vreg<256xui16> + +physicalization: + acc/result physical arity is 2 because 256xui16 is 512B + part0 represents logical bins 0..127 + part1 represents logical bins 128..255 +``` + +`vmi-to-vpto` lowering for `pto.vmi.dhist` is local and deterministic from the +op and assigned types: + +```text +lo = converted acc part0 +hi = converted acc part1 + +for each converted source physical chunk c in logical order: + chunk_mask = converted b8 mask chunk c + + if source chunk c contains padding lanes because N is not a multiple of 256: + valid = pto.pge/plt_b8 prefix mask for the valid logical lanes in this chunk + chunk_mask = pto.pand 
chunk_mask, valid + + lo = pto.dhistv2 lo, src_c, chunk_mask, #bin=0 + hi = pto.dhistv2 hi, src_c, chunk_mask, #bin=1 + +return physical result parts [lo, hi] +``` + +Required preflight: + +```text +acc/result element type is ui16 and logical lane count is exactly 256 +source element type is ui8 +source and mask logical lane counts match +source/mask are contiguous +mask granularity is b8 +source physical chunks are 256-lane ui8 chunks; final partial chunk is allowed +only when the lowering can construct the valid-lane prefix mask +``` + +Diagnostics: + +```text +VMI-UNSUPPORTED: pto.vmi.dhist requires contiguous ui8 source, b8 mask, and +contiguous 256xui16 accumulator/result + +VMI-UNSUPPORTED: pto.vmi.dhist final partial source chunk requires valid-lane +b8 mask materialization +``` + +`pto.vmi.chist` has the same verifier and assignment requirements, but final +lowering is capability-gated: + +```text +if CHISTv2 high-range semantics are verified as global cumulative: + replace the two pto.dhistv2 calls above with pto.chistv2 calls + +elif CHISTv2 high-range semantics are verified as range-local cumulative: + lower low/high pto.chistv2 and add the low-half total count to every high-half bin, + but only after low-total materialization and broadcast support is explicit + +else: + VMI-UNSUPPORTED: pto.vmi.chist requires a verified CHISTv2 range semantics contract +``` + +Do not classify histogram as `group_reduce`. Its result location is selected +by source values, not by lane/group position, and its low/high split is caused +by the physical `128xui16` VPTO result width. + +Final hard gate: + +```text +no pto.vmi op remains +no !pto.vmi.* type remains, including in function signatures +no UnrealizedConversionCastOp remains +physical arity matches helper for every lowered value +``` + +Slice 4 完成条件: + +```text +1. `f16 -> f32 -> add -> store` lowers with deinterleaved=2 and stores contiguous logical order. + Covered by vmi_to_vpto_e2e_widen_add_store.pto. +2. 
`f8 -> f32 -> add -> store` lowers with deinterleaved=4 and stores contiguous logical order. + Covered by vmi_to_vpto_e2e_widen_add_store.pto. +3. Non-full memory physical arity and valid lane map are tested. + Covered by vmi_to_vpto_load_nonfull.pto, vmi_to_vpto_load_nonfull_memref.pto, + vmi_to_vpto_store_deint_invalid.pto, + vmi_to_vpto_load_safe_tail_memref.pto, + vmi_to_vpto_load_safe_tail_memref_negative_offset.pto, + vmi_to_vpto_masked_load_safe_tail_memref.pto, + vmi_to_vpto_masked_load_safe_tail_memref_negative_offset_invalid.pto, + vmi_to_vpto_expand_load_all_active.pto, + vmi_to_vpto_expand_load_all_active_negative_offset_invalid.pto, and multi-chunk load/store layout tests. +4. Full-footprint load/store direct path lowers through pto.vlds/pto.vsts or deinterleaved=2 x2 dist + instructions with offset 0. + Covered by the load/store direct-path and layout-folding tests. +5. Internal func.call boundaries expand callee signatures, call operands/results, and returned VMI values together. + Covered by vmi_layout_assignment_call_boundary.pto, vmi_layout_assignment_indirect_call_invalid.pto, + and vmi_to_vpto_call_boundary.pto. +6. Structured control-flow carrying VMI values expands iter args, yields, results, masks, and returns together. + Covered by vmi_layout_assignment_cf_switch.pto, + vmi_layout_assignment_scf_execute_region.pto, + vmi_layout_assignment_scf_index_switch.pto, + vmi_layout_assignment_scf_while.pto, vmi_to_vpto_cf_branch.pto, + vmi_to_vpto_scf_for.pto, vmi_to_vpto_scf_if.pto, and the user-facing + vmi_ptoas_cli_control_flow.pto. +7. Final gate rejects residual VMI helper and unrealized casts. 
+ Covered by vmi_to_vpto_ensure_identity.pto, + vmi_to_vpto_ensure_layout_partial_invalid.pto, + vmi_to_vpto_truncf_fp8_128_contiguous_invalid.pto, + vmi_to_vpto_ensure_mask_layout_partial_invalid.pto, + vmi_to_vpto_unsupported_op_invalid.pto, + vmi_to_vpto_unrealized_cast_residual_invalid.pto, + vmi_to_vpto_type_attr_residual_invalid.pto, and per-feature unsupported + tests. +8. Same-family indirect memory ops reject unsupported direct-lowering shapes consistently. + Covered by vmi_to_vpto_gather_scatter_shape_invalid.pto together with the existing gather/scatter positive and + per-feature negative tests. +9. Same-family reduction ops reject unsupported direct-lowering shapes consistently. + Covered by vmi_to_vpto_reduce_shape_invalid.pto together with the existing reduce add/min/max positive and + per-feature tests, including vmi_to_vpto_reduce_addi_i16_invalid.pto for narrow integer rejection and + vmi_to_vpto_reduce_addf_f16.pto for f16 floating-point reduction lowering. +10. Target-specific element contracts are checked before OneToN rewriting for direct VPTO ops. + Covered by vmi_to_vpto_bf16_arith.pto, vmi_to_vpto_math_element_type_invalid.pto, + vmi_to_vpto_cmp_select.pto, vmi_to_vpto_cmp_element_type_invalid.pto, + vmi_to_vpto_fma.pto, vmi_to_vpto_fma_element_type_invalid.pto, and + vmi_to_vpto_unary_math.pto for negf/absf/absi/sqrt/exp/ln/relu, plus + vmi_to_vpto_relu_element_type_invalid.pto. +11. Same-family mask logic ops lower through the physical mask granularity instead of assuming b32 masks. + Covered by vmi_to_vpto_mask_logic.pto for mask_and/mask_or/mask_xor/mask_not on b32 masks produced by + cmpf and on direct b8/b16 mask operands. +12. `pto.vmi.dhist` lowers one logical 256-bin histogram into two VPTO low/high + bin-range histogram accumulator chains, and tail source chunks are masked + with a valid-lane b8 prefix. `pto.vmi.chist` is rejected until the target + CHISTv2 cumulative range semantics are classified. 
+ Covered by vmi_to_vpto_dhist.pto, vmi_to_vpto_dhist_tail_mask.pto, and + vmi_to_vpto_chist_semantics_invalid.pto. +``` + +## 7. Slice 5: Memory Padding + +The Slice 4 direct path lowers `pto.vmi.load` through plain `pto.vlds` when the +memory source itself is supported and the element type has a known physical lane +width. This includes non-full logical vectors; the operation is treated as a +direct full physical read of the selected VPTO chunk(s). Masked/expand/gather +read-like operations still use the richer access plan because their masks or +lane maps carry additional semantic constraints. + +Implement an internal `VMIMemoryAccessPlan`: + +```text +base +logical lane count +logical_shape +permutation_map +lane-to-address map in element units +validMask +paddingValue +safeReadProof +writeMask +target capability decision +fallback resource decision +``` + +Current implementation status: + +```text +lib/PTO/Transforms/VMIToVPTO.cpp + VMIMemoryAccessPlan + VMIMemorySafeReadProof + VMIMemoryLogicalShape + VMIMemoryLaneAddressMap + VMIMemoryFallbackDecision + +currently routed through the plan: + contiguous identity logical_shape/permutation/lane-to-address map in element units + explicit rejection of non-identity memref layouts until subview/affine lane maps are represented + covered by vmi_to_vpto_memref_layout_invalid.pto, including a memref.subview-produced strided view + subview diagnostics name the missing normalized base/offset/stride lane-to-address plan + target true masked/non-faulting load capability query + current result is missing capability because pto.vlds has no mask operand + covered by vmi_to_vpto_masked_load_nonfull_invalid.pto + stable gather masked-load option + covered by vmi_to_vpto_stable_gather_masked_load_todo_invalid.pto + currently emits a TODO diagnostic instead of lowering through VGATHER2 + direct pto.vmi.load source/layout capability check for full physical reads + pto.vmi.masked_load partial/tail safe full-read proof + 
pto.vmi.expand_load static all-active safe full-read proof + VMI-to-VPTO rewrite match guard for supported direct load sources + pto.vmi.store direct write target decision with all-true writeMask kind + pto.vmi.masked_store direct write target decision with explicit writeMask kind + unsafe masked/expand partial/tail read fallback decision as RequiredUnavailable diagnostic + covered by vmi_to_vpto_masked_load_nonfull_invalid.pto and + vmi_to_vpto_expand_load_all_active_negative_offset_invalid.pto + +currently not implemented by the plan: + paddingValue materialization (intentionally unsupported in the first implementation stage) + non-all-true validMask direct masked/non-faulting load lowering + scratch/guarded fallback lowering or allocation + lowering for non-identity logical_shape/permutation_map/lane-to-address maps, including subview or affine lane maps + writeMask fallback planning beyond the existing contiguous tail-store predicate path +``` + +Important first-stage contract: + +```text +VMI physical tail lanes and transfer paddingValue are different concepts. + +Physical tail lanes: + arise because pto.vreg is fixed at 256 bytes + are outside the logical VMI lane count + may be read/computed only when the extra lanes remain unobservable + +transfer_read-style paddingValue: + is an observable logical result for invalid/OOB transfer lanes + cannot be dropped or replaced by arbitrary physical tail contents + is not materialized by the first-stage VMI implementation + +Therefore any frontend path that still needs transfer_read paddingValue +semantics must stop before direct VMI-to-VPTO lowering with VMI-UNSUPPORTED, +unless it has already canonicalized to an all-valid load/masked_load subset +whose invalid lanes are proven absent. 
+``` + +Read-like memory decision tree: + +```text +safeReadProof full && validMask all true: + direct load + +safeReadProof full && validMask not all true: + first-stage: VMI-UNSUPPORTED because paddingValue materialization is not implemented + future: full load + padding materialization + select + +target true masked/non-faulting load: + first-stage: VMI-UNSUPPORTED because true masked/non-faulting load and paddingValue materialization are not implemented + future: masked load + padding materialization + +otherwise: + first-stage: VMI-UNSUPPORTED with the missing fallback reason + future: split safe regions, scratch fill/copy/load, guarded fallback, or diagnostic +``` + +Write-like memory decision tree: + +```text +writeMask all true && full footprint safe-writable: + direct store + +target true masked store: + masked store + +otherwise: + split/guarded/scatter-like fallback or diagnostic +``` + +Slice 5 完成条件: + +```text +1. Unsafe partial/tail read-like ops never lower to a potentially invalid full + read unless the physical footprint is statically proven safe. +2. PaddingValue materialization is not required in the first implementation + stage. Any path that would require paddingValue, true masked/non-faulting + load, scratch fill/copy/load, or guarded fallback must report + `VMI-UNSUPPORTED` with the missing fallback reason. +3. Non-identity logical_shape/permutation_map/lane-to-address maps, including + subview or affine lane maps, are explicitly rejected before lowering. +4. Store-like partial/tail writes are supported only by the existing + full-chunk or contiguous/deinterleaved tail-store predicate paths. Other + writeMask fallback paths must report `VMI-UNSUPPORTED`. +``` + +## 8. Target Capabilities And Layout Fact Helpers + +Keep target capabilities separate from layout assignment policy. 
The shared +helpers expose target support and small layout/materialization facts; they do +not select a global lowering plan and are not a shared lowering-plan registry +between assignment and VMI-to-VPTO. + +```text +supportsElementType(type, purpose) +getPreferredCastLayoutFact(sourceType, resultType) +getPreferredGroupReduceLayoutFact(sourceType, numGroups) +canMaterializeDataLayout(sourceType, resultType) +canMaterializeMaskLayout(sourceType, resultType) +supportsMaskGranularityConversion(srcG, dstG) +supportsMemoryAccessProof(proof) +supportsPrefixPopcount(maskType) +supportsReductionScanContract(op) +getScratchResource(plan) +``` + +Capability and materialization helpers return structured results: + +```text +supported +unsupported_missing_capability +unsupported_disabled_by_option +unsupported_resource +``` + +Diagnostics must expose that reason. A pass must not silently choose scalar fallback when fallback is disabled. + +Current implementation status: + +```text +include/PTO/Transforms/VMITargetCapabilities.h + VMITargetCapabilityRegistry + VMICapabilityResult { status, reason } + +currently routed through the registry: + element-type purpose checks for predicate-maskable vregs and direct elementwise/cmp/fma/relu VPTO lowering + reduction-family element-type contracts for reduce_addi/reduce_addf/reduce_maxf/reduce_minf + direct pto.vlds/vsts memory source/destination support + missing target true masked/non-faulting load capability for the current pto.vlds surface + pointer-only UB memory support for pto.vgather2_bc/pto.vscatter/pto.vstur based VMI paths + supported source/result layout conversion pairs + supported b8/b16/b32 mask granularity conversion pairs + pto.vmi.channel_split/channel_merge supported channel count + pto.vmi.dhist direct target support and pto.vmi.chist cumulative range semantics classification + +still legacy helper-based and should migrate into the registry as follow-up: + full layout materialization plans and padding-safety checks 
+ adjacent ppack/punpack mask granularity materialization plans + prefix popcount and full reduction/scan/contract shape capability checks +``` + +## 9. Diagnostics + +Centralize diagnostic codes in one header or utility file: + +```text +VMI-UNSUPPORTED +VMI-LAYOUT-CONTRACT +VMI-PASS-INVARIANT +VMI-RESIDUAL-OP +``` + +Current implementation defines these codes and their `": "` prefixes in `include/PTO/IR/VMIUtils.h`. Transform and +CLI code must reference those constants instead of spelling the diagnostic code strings locally; a source grep for the +four code strings should find only the central definitions. + +Every diagnostic should include: + +```text +source op +logical VMI type +producer natural layout, if any +consumer required layout, if any +missing capability or disabled option +available materialization paths, if known +``` + +## 10. Lit Test Layout + +Use a dedicated directory: + +```text +test/lit/vmi/ +``` + +Minimum test files: + +```text +vmi_type_attr_parse.mlir +vmi_type_attr_invalid.mlir +vmi_op_verifier_basic.mlir +vmi_producer_boundary.mlir +vmi_layout_assignment_widen.mlir +vmi_layout_assignment_cfg.mlir +vmi_layout_assignment_broadcast_remat.mlir +vmi_layout_assignment_iota_remat.mlir +vmi_layout_assignment_mask_remat.mlir +vmi_to_vpto_deinterleaved2.mlir +vmi_to_vpto_deinterleaved4.mlir +vmi_to_vpto_compaction_deint_invalid.mlir +vmi_to_vpto_load_safe_tail_memref.mlir +vmi_to_vpto_masked_load_safe_tail_memref.mlir +vmi_to_vpto_store_tail.mlir +vmi_to_vpto_dhist.mlir +vmi_to_vpto_dhist_tail_mask.mlir +vmi_to_vpto_chist_semantics_invalid.mlir +vmi_pipeline_hard_gates.mlir +``` + +Each pass test must use `FileCheck` to prove both positive output and negative absence: + +```text +CHECK: pto.vmi.addf +CHECK-NOT: pto.vadd +CHECK-NOT: unrealized_conversion_cast +``` + +Final lowering tests must check: + +```text +CHECK-NOT: pto.vmi. +CHECK-NOT: unrealized_conversion_cast +``` + +## 11. Implementation Order + +Recommended merge order: + +```text +1. 
VMI type/attr + helper + parse/verify tests. +2. Slice 1 op shells + verifier tests. +3. VMI producer boundary verifier. +4. layout assignment for straight-line code. +5. layout assignment for scf/cf/function boundaries. +6. vmi-to-vpto type conversion + pack/unpack/unpackable block args. +7. deinterleaved=2 f16 widen end-to-end. +8. deinterleaved=4 f8 widen end-to-end. +9. load/store padding-safe lowering. +10. remaining semantic op families. +``` + +Do not merge a pass that leaves hidden side tables as a required interpretation mechanism. Temporary internal +analysis structures are fine only if the pass materializes the final state into IR before returning. + +## 12. Review Checklist Before Coding Each Slice + +Before implementation: + +```text +1. Is the op/type syntax written in ODS and tested by parser round-trip? +2. Does every verifier rule have a negative test? +3. Does every pass have a post-pass hard gate? +4. Are CFG block arguments and function signatures covered? +5. Does any lowering rely on a defining op that block arguments do not have? +6. Does memory lowering prove safe footprint separately from valid lane mask? +7. Does mask granularity follow consumer element width? +8. Does final VPTO lowering leave zero VMI op/type/helper or unrealized-cast residuals? +``` + +If any answer is no, the slice is not ready to be treated as complete. + +## 13. Adding One VMI Op End To End + +新增一个 `pto.vmi.*` op 时,不要只补 ODS 和 lowering pattern。它必须穿过固定的七个落点, +否则很容易出现 verifier 能过、layout pass 不知道怎么约束、或控制流 physicalization 后残留 VMI type。 + +```text +1. ODS surface: + include/PTO/IR/VMIOps.td + +2. semantic verifier: + lib/PTO/IR/VMI.cpp + +3. layout assignment facts: + lib/PTO/Transforms/VMILayoutAssignment.cpp + +4. shared layout support, when the fact crosses stages: + include/PTO/Transforms/VMILayoutSupport.h + lib/PTO/Transforms/VMILayoutSupport.cpp + +5. vmi-to-vpto preflight: + lib/PTO/Transforms/VMIToVPTO.cpp::verifySupportedVMIToVPTOOps + +6. 
OneToN lowering pattern: + lib/PTO/Transforms/VMIToVPTO.cpp::populateVMIOneToNConversionPatterns + +7. focused lit tests: + test/lit/vmi/ +``` + +这七个落点的职责不同: + +```text +ODS: + 只定义 op 形状、operand/result type 类别、assembly format、interface 和 verifier hook。 + +VMI.cpp verifier: + 检查局部语义,例如元素类型、rank、lane count、predicate 字符串、source/result bit 数关系。 + 不能依赖 def-use 图,不能决定 layout。 + +LayoutAssignment: + 只收集 value-level layout/granularity 事实: + - producer natural layout + - operands that must share layout with result + - consumer required layout + - mask consumer required granularity + 不能在 collect 阶段改 IR。 + +VMILayoutSupport: + 只放跨 assignment、validation、optimization、lowering 中至少两个阶段共享的纯查询。 + 典型内容是 cast layout fact、group_reduce layout fact、ensure_* materialization support。 + 不能返回 VPTO instruction sequence、不能决定 clone/rematerialize、不能读取 producer/user context。 + 只有一个 lowering pattern 自己使用的判断不要抽到这里。 + +VMIToVPTO preflight: + 在 rewrite 前拒绝当前 lowering 不支持但语义合法的 case。 + 典型例子是 partial physical chunk、non-prefix mask constant、dynamic create_mask、unsupported shuffle。 + +OneToN pattern: + 从 adaptor 读取 physical parts,按已经确定的 layout 发 VPTO op。 + 不能重新推断 layout,也不能通过 defining op 找 physical parts。 + +lit: + 至少覆盖 parser/verify、layout assignment、positive lowering、negative unsupported diagnostic。 +``` + +### Layout Fact Template + +新增 op 时先给它归类,再写 layout 约束。不要从 VPTO 指令形态反推 VMI layout;layout 的来源必须是 +logical vector 语义和当前物理指令的天然限制。 + +```text +elementwise same-shape op: + examples: + addf/addi/subf/mulf/andi/shli/shrui/absf/absi/sqrt + layout rule: + all data operands and result are in one equivalence class + lowering rule: + emit one VPTO op per physical part + +compare op: + examples: + cmpf/cmpi + layout rule: + lhs/rhs data layout unified + result mask requested to the same data layout + result mask granularity comes from lhs/rhs element width + lowering rule: + emit one vcmp per data part, producing corresponding mask part + +mask logical op: + examples: + mask_and/mask_or/mask_xor/mask_not + 
layout rule: + all mask operands/results share layout and granularity + lowering rule: + emit one predicate op per physical mask part + +layout-changing producer: + examples: + extf f16->f32, extf f8->f32, truncf f32->f16, truncf f32->fp8-like + layout rule: + source/request side follows instruction input contract + result natural layout follows instruction output contract + lowering rule: + emit the instruction sequence that preserves logical lane order under that layout + +memory consumer/producer: + examples: + load/store/masked_load/masked_store + layout rule: + load result natural layout is chosen by memory dist capability + store value operand requests the layout that memory dist can consume + lowering rule: + direct path only when every physical chunk has no padding lane and footprint is safe + +structural boundary: + examples: + scf.if result/yield, scf.for iter args, cf.br successor operands, func.call + layout rule: + semantically identical incoming/outgoing values are unified + lowering rule: + handled by OneToN structural patterns, not by op semantic lowering +``` + +代码里 `LayoutSolver::addConstraints()` 应该只表达上面的事实。例如一个普通 elementwise binary op +只需要: + +```cpp +if (auto addf = dyn_cast<VMIAddFOp>(op)) { + if (failed(unite(addf.getLhs(), addf.getRhs(), op)) || + failed(unite(addf.getLhs(), addf.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); +} +``` + +一个 layout-changing op 不应该把 source/result 直接 `unite`,而是明确写 producer/consumer 合同: + +```cpp +if (auto extf = dyn_cast<VMIExtFOp>(op)) { + requestDataUse(extf.getSourceMutable(), getContiguousLayout()); + if (failed(setNaturalLayout(extf.getResult(), + VMILayoutAttr::getDeinterleaved(ctx, factor), + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); +} +``` + +### OneToN Pattern Template + +`vmi-to-vpto` pattern 的输入不再是 logical VMI value,而是 adaptor 里已经 flatten 好的 physical parts。 +pattern 只做三件事: + +```text +1. 从 adaptor 取每个 logical operand 的 physical part list。 +2. 
从 resultMapping 取每个 logical result 对应的 physical result type list。 +3. 按 part 顺序创建 VPTO op,并用 resultMapping replace 原 op。 +``` + +普通 elementwise binary op 的代码形态应该接近: + +```cpp +LogicalResult matchAndRewrite(VMIAddFOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange lhsParts = adaptor.getLhs(); + ValueRange rhsParts = adaptor.getRhs(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + + if (lhsParts.size() != rhsParts.size() || lhsParts.size() != resultTypes.size()) + return rewriter.notifyMatchFailure(op, "physical arity mismatch"); + + SmallVector results; + for (auto [lhs, rhs, resultType] : llvm::zip_equal(lhsParts, rhsParts, resultTypes)) + results.push_back(rewriter.create(op.getLoc(), resultType, lhs, rhs)); + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); +} +``` + +这里不能调用 `op.getLhs().getDefiningOp()` 去找物理寄存器。原因是 VMI value 可以来自: + +```text +function argument +block argument +scf.for iter arg +scf.if result +cf.br successor argument +func.call result +``` + +这些 value 很多没有 VMI defining op。physical parts 的唯一合法来源是 OneToN adaptor 和 +OneToNTypeMapping。 + +### Control-Flow Checklist + +每新增一个 op,不一定要写新的控制流 pattern;但必须检查它的结果或 operand 是否可能跨边界。 +如果只是普通 VMI value,那么已有 structural OneToN pattern 应该负责边界 physicalization: + +```text +func.func / func.call / func.return: + upstream func OneToN conversion + +scf.if / scf.for / scf.while / scf.yield: + upstream SCF OneToN structural conversion plus layout solver equivalence constraints + +cf.br / cf.cond_br / cf.switch: + project-local OneToN patterns flatten successor operands and rewrite destination block signatures + +scf.execute_region / scf.index_switch: + project-local OneToN patterns flatten region results +``` + +新增 op 的测试要至少放一个跨边界用例,证明 op 的 result 不是只在 straight-line IR 中工作: + +```mlir +%r = scf.if %cond -> !pto.vmi.vreg<128xf32> { + %x = pto.vmi.addf %a, %b : ... 
-> !pto.vmi.vreg<128xf32> + scf.yield %x : !pto.vmi.vreg<128xf32> +} else { + scf.yield %c : !pto.vmi.vreg<128xf32> +} +pto.vmi.store %r, %ptr, %off : ... +``` + +对应 lowering test 必须检查: + +```text +CHECK-NOT: pto.vmi. +CHECK-NOT: !pto.vmi. +CHECK-NOT: unrealized_conversion_cast +``` + +如果这个测试失败,通常不是该 op 的 VPTO pattern 本身错,而是 layout assignment 没有把 yield/result/consumer +约束统一,或者 OneToN structural pattern 漏了某种 region/control-flow op。 + +### Preflight Versus Pattern Failure + +语义合法但当前还没有物理实现的 case,应该在 `verifySupportedVMIToVPTOOps()` 里给稳定 diagnostic, +不要让 pattern 随机 `notifyMatchFailure()` 后落成 generic conversion failure。 + +```text +use verifier failure: + op 本身语义非法,任何 target 都不应该接受。 + examples: + absf on integer element + shrui on signed integer element + bitcast total bits mismatch + +use VMI-LAYOUT-CONTRACT: + 多个 producer/consumer/control-flow 约束互相冲突。 + examples: + one value simultaneously required as contiguous and deinterleaved=2 + one mask simultaneously required as b16 and b32 + +use VMI-UNSUPPORTED in preflight: + VMI semantics are valid, but current VPTO materialization is not implemented. + examples: + partial/tail memory access + pred-only constant mask without concrete b8/b16/b32 granularity + shuffle that requires vselr index-vector materialization + bitcast with mismatched layouts or per-chunk logical bit footprints + +use VMI-RESIDUAL-OP: + conversion framework finished but VMI op/type/helper/cast remains. + This is a pass bug or missing pattern, not a user semantic error. +``` + +Pattern-local `notifyMatchFailure()` is still useful for debugging competing patterns, but it must not be the only +user-visible explanation for a known unsupported VMI semantic case. 
diff --git a/docs/designs/vmi-introduction.md b/docs/designs/vmi-introduction.md new file mode 100644 index 0000000000..089120a4f8 --- /dev/null +++ b/docs/designs/vmi-introduction.md @@ -0,0 +1,1076 @@ +# VMI 介绍 + +本文介绍 VMI 的设计入口:VMI 解决什么问题,layout 有哪些,pass pipeline +如何分工,以及这些机制分别应对哪些典型场景。更完整的逐 case lowering 结果见 +`docs/designs/vmi-layout-lowering-cases.md`。 + +示例是设计级 IR,保留关键 type、layout、helper op 和 VPTO op 形状, +省略 module wrapper、完整 operand list 和不影响讨论的 SSA 细节。 + +## 1. VMI 表达什么 + +VMI 是 VPTO 之前的逻辑向量层。它让前端先表达“我要对 `NxT` 的逻辑向量做什么”, +再由 layout assignment 决定这个逻辑向量如何拆到 256B 物理 vector register 上。 +当 VPTO 指令因为物理 register 宽度只能暴露半宽接口时,VMI 也负责提供完整的 +逻辑语义。例如 `ui8` histogram 的完整结果是 `256xui16`,物理 VPTO histogram +一次只能返回 `128xui16`;VMI surface 应该表达完整 histogram,low/high bin +range 拆分属于 lowering 细节。 + +Surface VMI 类型不携带布局: + +```mlir +!pto.vmi.vreg<128xf32> +!pto.vmi.mask<128xpred> +``` + +Layout-assigned VMI 类型携带具体布局和 mask granularity: + +```mlir +!pto.vmi.vreg<128xf32, #pto.vmi.layout<contiguous>> +!pto.vmi.mask<128xb32, #pto.vmi.layout<contiguous>> +``` + +VMI 的核心约束是:`vmi-to-vpto` 只从当前 op 的 attrs、operands、types、 +layouts 和显式 helper ops 做 lowering,不读取隐藏 plan/recipe,也不通过 +defining op 或 sibling user 恢复上下文。 + +## 2. Layout 类型 + +### 2.1 `contiguous` + +```mlir +#pto.vmi.layout<contiguous> +``` + +含义:logical lane 按顺序落入物理 register list。 + +```text +logical lanes: 0 1 2 ... 63 | 64 65 ... 127 +physical part: p0 | p1 +``` + +典型场景: + +```text +dense load/store +普通 elementwise compute +一个 group 天然适配当前 reduce op 时的 reduction input +caller/callee 约定 dense order 时的 control-flow/function boundary +``` + +### 2.2 `deinterleaved = F, block_elems = B` + +```mlir +#pto.vmi.layout<deinterleaved = 2> +#pto.vmi.layout<deinterleaved = 4, block_elems = 8> +``` + +`block_elems` 缺省为 `1`。逻辑 lane 到物理 part 的映射是: + +```text +logical lane i +block q = i / B +in-block lane r = i % B +part p = q % F +part block t = q / F + +physical part p, physical lane t * B + r +``` + +`deinterleaved=2` 的直观例子: + +```text +logical lanes: 0 1 2 3 4 5 ... +physical part0: 0 2 4 ... +physical part1: 1 3 5 ... 
+``` + +`deinterleaved=4, block_elems=8` 的直观例子: + +```text +logical group S=32: + lanes 0.. 7 -> part0 lanes 0..7 + lanes 8..15 -> part1 lanes 0..7 + lanes 16..23 -> part2 lanes 0..7 + lanes 24..31 -> part3 lanes 0..7 +``` + +典型场景: + +```text +f16 -> f32: + vcvt 天然产生 even/odd 两个 f32 part,所以结果使用 deinterleaved=2。 + +f32 -> f16: + vcvt 需要 f32 source 先拆成 even/odd 两个 part,所以 source 使用 + deinterleaved=2。 + +S=32 group_reduce f32: + 一个 group 有 32 个 f32 element。高效 reduce path 消费四个 8-lane block, + 所以 source/mask 使用 deinterleaved=4, block_elems=8。 +``` + +`block_elems=8` 表示一种按 32B row fragment 组织的输入形态,不表示 +S=32 reduce 只能接受这一种形态。如果同一个 value 还要服务 narrow cast 等 +element-parity consumer,assignment 可以选择 `deinterleaved=4, block_elems=1` +作为共同 layout,再由 lowering 生成对应的物理指令序列。 + +`deinterleaved` 只描述最终物理 part 中有哪些 logical lane,不描述这个 layout +由哪条指令生成。不同 producer 可以用不同方式直接产生同一个 layout;如果不能 +直接产生,后续 lowering 再通过显式 materialization helper 把 source layout +转换成 consumer 需要的 layout。具体 lowering 形状见 case catalog。 + +### 2.3 `num_groups = G, slots = K` + +```mlir +#pto.vmi.layout +#pto.vmi.layout +#pto.vmi.layout +``` + +这是 group-slot result layout。它不表示全部 `N` 个 logical lane 都有语义值。 +只有 `G` 个 group 结果 slot 有语义值。 + +```text +slot_block(g) = g / K +slot_lane(g) = (g % K) * lane_stride + +physical part slot_block(g) 的 lane slot_lane(g) 保存 group g 的结果 +``` + +`lane_stride` 缺省为 1,单位是 logical element-sized physical slot。 +它描述 group result 在物理存储中的固定间距,不改变 VMI 的逻辑元素类型。 +例如 `ui8 lane_stride=4` 表示 group slot 存在 byte lane 0, 4, 8, ... 
+这种形态可以 lower 为 `PK4_B32` store,物理上使用 b32 carrier 的 low byte。 + +`num_groups=16, slots=8` 的例子: + +```text +part0 lane0..7 = group result 0..7 +part1 lane0..7 = group result 8..15 +other lanes = 对普通 dense consumer 来说未定义 +``` + +为什么 group 信息也要放进 layout: + +```text +group_reduce 自身有 num_groups,但它的结果可能继续跨过 truncf、 +group_broadcast、group_store、scf.if、scf.for、function call 或多个 consumer。 + +这些后续 op 不应该回看 producer attr。value layout 因此需要记录有多少个 +group result,以及这些 result 如何 packed 到 physical slot。 +``` + +典型场景: + +```text +group_reduce result +group_slot_load result +group_store input +group_broadcast input +group-slot control-flow/function boundary +部分 row-local cast 路径,通常使用 slots=1 +``` + +## 3. Pass Pipeline + +```text +pto-validate-vmi-ir + -> vmi-layout-assignment + -> canonicalize/cse + -> vmi-layout-fold + -> canonicalize/cse + -> vmi-layout-rematerialize + -> canonicalize/cse + -> vmi-layout-sink-materialization + -> canonicalize/cse + -> vmi-legalize-arith-select + -> pto-validate-vmi-layout-ir + -> vmi-to-vpto +``` + +### 3.1 `pto-validate-vmi-ir` + +检查 surface VMI 边界。 + +合法输入: + +```mlir +%x = pto.vmi.load %src[%off] + : !pto.ptr -> !pto.vmi.vreg<128xf16> +``` + +非法输入: + +```mlir +%x = pto.vmi.load %src[%off] + : !pto.ptr + -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> +``` + +原因:具体 layout 由 `vmi-layout-assignment` 产生,不应该由 surface frontend +提前写入。 + +### 3.2 `vmi-layout-assignment` + +这是硬合法化 pass。它选择具体 value layout、具体 mask granularity, +并在 layout 不匹配的 use site 插入显式 helper op。 + +这个 pass 的工作顺序是固定的: + +```text +1. 做少量 VMI 内部规整,让后续 layout 规则面对稳定形态。 +2. 为 data value 建 union-find 求解器,并收集 data 约束和 data use request。 +3. 把可采纳的 consumer request 提升为 producer/result 的最终 layout。 +4. 改写所有 data value type,让 !pto.vmi.vreg 携带具体 layout。 +5. 对仍不匹配的 data use 插入 pto.vmi.ensure_layout。 +6. 基于已经确定的 data layout 推导 mask layout 和 predicate granularity。 +7. 改写所有 mask type,并对不匹配的 mask use 插入 ensure_mask_*。 +8. 同步更新 function type、call boundary 和 block argument type。 +9. 
校验 layout-assigned VMI IR。 +``` + +Data 和 mask 分两轮求解。原因是 mask layout 通常依赖对应 data operand 或 result +的 layout;例如 `cmpf` 产生的 mask 跟比较输入的 data layout 对齐, +`select`/`reduce`/`masked_load` 消费的 mask 也要跟对应 data value 的 lane +layout 和元素 bitwidth 对齐。 + +Data 求解器为每个 `!pto.vmi.vreg` 建一个节点: + +```text +DataNode: + value = 对应 SSA value + original type = surface VMI type + parent = union-find parent + naturalLayout = 当前等价类选择的自然 layout,可能为空 +``` + +遍历 IR 时,每个 op 向 data 求解器贡献三类信息。 + +第一类是 layout 等价约束。它表示几个 value 必须使用同一个 physical layout, +也就是 union-find 中的同一个等价类。典型来源: + +```text +layout-transparent elementwise: + addf/addi/subf/subi/mulf/muli/fma/divf/minf/maxf/... + L(operands...) = L(result) + +unary elementwise: + negf/absf/absi/sqrt/exp/ln/relu/not + L(source) = L(result) + +select: + L(true_value) = L(false_value) = L(result) + +bitcast: + L(source) = L(result) + +structured control flow: + scf.if result = then/else yield operand + scf.for result = init operand = iter_arg = yield operand + scf.while result = init/before/condition/after/yield carried value + +cf branch: + branch operand = destination block argument + +function boundary: + call operand = callee argument + call result = callee return operand + multiple returns of the same function agree per result index +``` + +这一步只说明“这些 value 如果存在布局,就必须一致”。它不等价于把某个 +consumer 的 request 无条件推过所有 producer 或控制流。 + +等价类可以画成“同一个框里的 value 共用一个 layout 变量”。例如普通 +elementwise 链: + +```text +surface VMI: + + %x = pto.vmi.load ... + %k = pto.vmi.broadcast ... 
+ %y = pto.vmi.mulf %x, %k + %q = pto.vmi.truncf %y + +data layout 等价类: + + class C0 + +--------------------------------------+ + | %x %k %y | + | load broadcast mulf result | + +--------------------------------------+ + ^ + | + use request from truncf source: + wants deinterleaved=4 + +若 %y 的 producer chain 可采纳该 request,assignment 可以选择: + + L(C0) = deinterleaved=4 +``` + +控制流 join 也是等价类,但 request adoption 的含义不同: + +```text +surface VMI: + + %y = scf.if %c -> !pto.vmi.vreg<128xf32> { + scf.yield %a + } else { + scf.yield %b + } + %q = pto.vmi.truncf %y + +data layout 等价类: + + class C1 + +--------------------------------------+ + | %a %b %y | + | then yield else yield if result | + +--------------------------------------+ + ^ + | + use request from truncf source: + wants deinterleaved=4 + +scf.if result 不是 consumer-driven adoption 的可采纳 producer。 +若 C1 不能直接选择 deinterleaved=4,assignment 保持 C1 的布局, +并在 use site materialize: + + %y_for_q = pto.vmi.ensure_layout %y : L(C1) -> deinterleaved=4 + %q = pto.vmi.truncf %y_for_q +``` + +多 consumer 冲突时,等价类仍然只有一个 layout: + +```text +surface VMI: + + %y = pto.vmi.mulf %x, %k + pto.vmi.store %y, %out0 + %q = pto.vmi.truncf %y + +data layout 等价类: + + class C2 + +-----------------------------+ + | %x %k %y | + +-----------------------------+ + |\ + | \ use request from truncf: deinterleaved=4 + | + +--- use request from store: contiguous + +两个 use request 不一致时,不能让 %y 同时拥有两个 layout。 +baseline assignment 保留 C2 已有的 natural layout;若没有 natural layout, +则使用默认 contiguous。与该 layout 不匹配的 edge 会插 ensure_layout。 +``` + +第二类是 result 自然布局。某些 op 的结果本身有目标相关的自然布局: + +```text +普通 reduce / compress / shuffle: + result 通常是 contiguous。 + +group_reduce: + source 需要适配 group reduce 指令形态; + result 使用 group_slots(num_groups, slots) 描述 group-slot result。 + +cast: + widening/narrowing 根据 cast support 决定 source request 和 result layout。 + +group_load / group_slot_load / group_broadcast_load: + result 根据 group size、row stride 和目标能力选择 contiguous、deinterleaved + 或 
group_slots。group_broadcast_load 表达“每个 logical group load 一个值并 + 广播到组内 lanes”的逻辑语义;E2B 只是兼容 layout 下的一种 lowering。 + +stride_load: + result 是 contiguous。block/repeat stride 只描述 memory address map, + 不改变 register 内 logical lane order。 + +active_prefix_index: + result 使用 contiguous。 +``` + +若同一个等价类已经有自然布局,再设置不同自然布局会报 layout contract 冲突。 + +第三类是 operand 使用请求。consumer 不直接修改 operand 的 type,而是记录 +“这个 use site 希望 operand 是什么 layout”: + +```text +store / masked_store value: + wants contiguous + +ordinary reduce source/init: + wants contiguous + +group_reduce source: + wants preferred group-reduce source layout + +group_store value: + wants preferred group result layout + +stride_store value: + wants contiguous。block/repeat stride 只描述 memory write address map, + 不表示 source vreg 是 lane-strided 或 NZ layout。 + +truncf/trunci/extf/extsi/extui source: + wants cast support 给出的 source layout + +channel_split / channel_merge / shuffle: + wants 各自 lowering 需要的 source/input layout +``` + +收集完这些信息后,assignment 才尝试做 consumer-driven adoption。它逐个查看 +use request:如果 operand 的 producer 可以直接用 consumer 需要的 layout 产生 +同一个逻辑向量,并且多 use 时所有 use 都请求同一个 layout,那么这个 request +会被提升为该 value 所在 data 等价类的最终 layout。 + +可采纳 producer 是受限集合: + +```text +load +broadcast / constant / iota +layout-transparent elementwise +select +bitcast +``` + +这就是 request 看起来能穿过 elemwise 的原因: + +```mlir +%x = pto.vmi.load ... +%k = pto.vmi.broadcast ... 
+%y = pto.vmi.mulf %x, %k +%q = pto.vmi.truncf %y +``` + +`mulf` 先把 `%x`、`%k`、`%y` 合成同一个 data 等价类。`truncf` 对 `%y` +的 source use 请求 `deinterleaved=4` 时,这个 request 作用到 `%y` 所在等价类; +因为 `mulf` 是可采纳 producer,assignment 可以把整个等价类选成 +`deinterleaved=4`,从而让 load/broadcast/mulf 直接在这个 layout 下产生数据。 + +控制流边界也会形成等价类,但它不是任意 request 的自动传播通道: + +```mlir +%y = scf.if %c -> !pto.vmi.vreg<128xf32> { + scf.yield %a +} else { + scf.yield %b +} +%q = pto.vmi.truncf %y +``` + +`%y`、`%a`、`%b` 的 layout 必须一致;但 `scf.if` result 本身不是 +consumer-driven adoption 的可采纳 producer。若 `%q` 需要的 layout 无法成为 +这个等价类的最终布局,assignment 会在 `%q` 的 use site 插 +`pto.vmi.ensure_layout`,而不是隐式重写两个 branch 的内部计算。 + +Data layout 确定后,pass 会把每个 `!pto.vmi.vreg` 改写成 +`!pto.vmi.vreg`。如果某个记录过的 use request 仍然和 operand +当前 layout 不一致,pass 在该 consumer 前插显式 materialization: + +```mlir +%x_req = pto.vmi.ensure_layout %x + : !pto.vmi.vreg + -> !pto.vmi.vreg +consumer %x_req +``` + +这个规则也处理多 consumer 冲突: + +```mlir +%y = pto.vmi.mulf %x, %k +pto.vmi.store %y, %out0 // wants contiguous +%q = pto.vmi.truncf %y // wants deinterleaved=4 source +``` + +一个 SSA value 只能属于一个 data layout 等价类。若两个 use 不能共同满足, +baseline assignment 保留一个等价类 layout,并在不匹配 use 前插 +`ensure_layout`。后续 `vmi-layout-fold`、`vmi-layout-rematerialize` +和 `vmi-layout-sink-materialization` 可以在显式 helper op 上做优化,但 +`vmi-to-vpto` 不读取隐藏 plan 或 sibling user。 + +Mask 求解发生在 data type 改写之后。它同样维护 union-find 等价类,但节点记录 +两件事: + +```text +mask layout +predicate granularity: b8 / b16 / b32 +``` + +mask request 从已经带 layout 的 data value 推导: + +```text +cmpf/cmpi result: + mask layout = lhs data layout + granularity = lhs element bitwidth 对应的 predicate 粒度 + +select mask: + mask layout = result data layout + granularity = result element bitwidth 对应的 predicate 粒度 + +reduce / group_reduce / masked_load / expand_load mask: + mask layout = source/result data layout + granularity = 对应 data element bitwidth 的 predicate 粒度 +``` + +若 mask use 的 layout 或 granularity 不匹配,pass 显式插 
+`pto.vmi.ensure_mask_layout` 或 `pto.vmi.ensure_mask_granularity`。 + +完成 data/mask 改写和 helper 插入后,pass 会同步更新 function type。直接 +internal call 会把 call operand/result 与 callee argument/return operand 合成 +同一布局约束;带 VMI type 的 external declaration 或 indirect call 没有可见 +body,当前需要显式 ABI materialization 设计,因此 layout assignment 会拒绝。 +这个阶段之后,IR 不再依赖隐藏 plan;后续 pass 和 `vmi-to-vpto` 都只读取 type +上的 layout 和显式 `ensure_*` helper。 + +### 3.3 `vmi-layout-fold` + +当 consumer 可以直接保持同样的外部效果时,把显式 materialization 折进 +consumer。 + +变换前: + +```mlir +%dense = pto.vmi.ensure_layout %x + : !pto.vmi.vreg<128xf32, #pto.vmi.layout<deinterleaved = 2>> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout<contiguous>> +pto.vmi.store %dense, %dst[%off] +``` + +变换后: + +```mlir +pto.vmi.store %x, %dst[%off] + : !pto.vmi.vreg<128xf32, #pto.vmi.layout<deinterleaved = 2>>, !pto.ptr +``` + +可能的 VPTO 形状: + +```text +fold 前:vintlv + vsts + vsts +fold 后:vstsx2,使用交错 store mode +``` + +### 3.4 `vmi-layout-rematerialize` + +通过 clone 低成本、layout-polymorphic 的 producer 来替换 `ensure_*`。 + +变换前: + +```mlir +%s = pto.vmi.broadcast %scale + : f32 -> !pto.vmi.vreg<128xf32, #pto.vmi.layout<contiguous>> +%s_split = pto.vmi.ensure_layout %s + : !pto.vmi.vreg<128xf32, #pto.vmi.layout<contiguous>> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout<deinterleaved = 2>> +``` + +变换后: + +```mlir +%s_split = pto.vmi.broadcast %scale + : f32 -> !pto.vmi.vreg<128xf32, #pto.vmi.layout<deinterleaved = 2>> +``` + +预期可 rematerialize 的 producer: + +```text +splat constant +broadcast +iota +create_mask +create_group_mask +constant_mask +``` + +这个 pass 不 rematerialize: + +```text +load / masked_load / group_load / group_slot_load / group_broadcast_load +stride_load +reduce / group_reduce +control-flow results +``` + +### 3.5 `vmi-layout-sink-materialization` + +把匹配的 layout 转换跨过 layout-transparent elementwise op。 + +变换前: + +```mlir +%a_dense = pto.vmi.ensure_layout %a : deinterleaved=2 -> contiguous +%b_dense = pto.vmi.ensure_layout %b : deinterleaved=2 -> contiguous +%y_dense = pto.vmi.addf %a_dense, %b_dense : contiguous +``` + +变换后: + +```mlir +%y_split = 
pto.vmi.addf %a, %b : deinterleaved=2 +%y_dense = pto.vmi.ensure_layout %y_split : deinterleaved=2 -> contiguous +``` + +效果: + +```text +两个 input materialization -> 一个 result materialization +``` + +这个 pass 不会 sink 穿过 cast、load、store、reduce、group_broadcast 或 +control-flow op。 + +### 3.6 `vmi-legalize-arith-select` + +Canonicalization 可能把简单的 `scf.if` 折成 `arith.select`。VMI 希望把 +control-flow lowering 保持在结构化控制流里,所以这个 pass 会把 VMI value 上的 +`arith.select` 改回 `scf.if`。 + +```mlir +%r = arith.select %cond, %a, %b + : !pto.vmi.vreg<128xf32, #pto.vmi.layout<contiguous>> +``` + +改成: + +```mlir +%r = scf.if %cond + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout<contiguous>> { + scf.yield %a : !pto.vmi.vreg<128xf32, #pto.vmi.layout<contiguous>> +} else { + scf.yield %b : !pto.vmi.vreg<128xf32, #pto.vmi.layout<contiguous>> +} +``` + +### 3.7 `pto-validate-vmi-layout-ir` + +检查 post-assignment gate: + +```text +每个 VMI 数据值都有 concrete layout +每个 VMI mask 都有 concrete granularity 和 layout +helper op 有支持的 materialization path +semantic op/layout 组合有支持的 local lowering +vmi-to-vpto 之前没有物理 VPTO value 泄漏到 VMI IR 中 +``` + +非法例子: + +```mlir +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : ... -> !pto.vmi.vreg<8xf32, #pto.vmi.layout<num_groups = 8, slots = 8>> + +pto.vmi.store %sum, %dst[%off] + : !pto.vmi.vreg<8xf32, #pto.vmi.layout<num_groups = 8, slots = 8>>, + !pto.ptr +``` + +原因: + +```text +dense store 不能把 group_slots 当 dense vector 读取。 +应使用 group_store、group_broadcast 或显式支持的 group-to-dense op。 +``` + +### 3.8 `vmi-to-vpto` + +把 layout-assigned VMI value 转换成有序物理 VPTO value 列表,并对每个 +VMI op 做 local lowering。 + +例子: + +```text +!pto.vmi.vreg<128xf32, #pto.vmi.layout<contiguous>> + -> 两个 physical !pto.vreg<64xf32> part + +!pto.vmi.vreg<128xf32, #pto.vmi.layout<deinterleaved = 2>> + -> 两个 physical !pto.vreg<64xf32> part + part0 携带 even lanes,part1 携带 odd lanes + +!pto.vmi.vreg<32xf32, #pto.vmi.layout<num_groups = 32, slots = 8>> + -> 四个 physical part + part0 携带 group 0..7,part1 携带 group 8..15,... 
+``` + +`VMILayoutSupport` 不是 pass。它是 assignment、validation、optimization 和 +lowering 共享的查询库,用来避免重复实现 layout fact 和 supported +materialization 检查。 + +## 4. 典型场景 + +### 4.1 Dense Cast 与 Store + +```text +surface: + load f16,语义上连续 + extf 到 f32 + dense store f32 + +assignment: + load result = contiguous + extf result = deinterleaved=2 + store use = ensure_layout(deinterleaved=2 -> contiguous) + +baseline VPTO: + vlds + vcvt even / vcvt odd + vintlv + vsts + vsts + +fold-consumers 后的优化 VPTO: + vlds + vcvt even / vcvt odd + vstsx2,使用 interleaving store +``` + +这个场景说明为什么需要 `deinterleaved=2`,以及为什么 store-consumer folding +有价值。 + +### 4.2 Narrow Cast 与 Store + +```text +surface: + load f32 + truncf 到 f16 + dense store f16 + +assignment: + load result = deinterleaved=2 + truncf result = contiguous + +VPTO: + vldsx2 deinterleaving load + vcvt even / vcvt odd + vor + vsts +``` + +这个场景说明 memory op 可以直接产生 consumer 需要的 layout,但不需要保存隐藏 +plan。 + +### 4.3 一个 Producer 同时服务 Dense 和 Group Consumer + +```mlir +%x32 = pto.vmi.extf %x16 +%sum = pto.vmi.group_reduce_addf %x32, %mask {num_groups = 8, reassoc} +pto.vmi.group_store %sum, %sum_out[%off], %c1 {num_groups = 8} +pto.vmi.store %x32, %dense_out[%off] +``` + +Assignment 形状: + +```text +%x32 layout = deinterleaved=2 +group_reduce 直接消费 %x32 +dense store 获得 ensure_layout(%x32 -> contiguous) +``` + +VPTO 形状: + +```text +vcvt even/odd +vcgadd + vcgadd + vadd -> group_store result +vintlv + dense stores -> 产生 dense store 结果 +``` + +这个场景说明为什么需要 use-site materialization。producer 不需要选择一个能同时 +满足所有 consumer 的唯一 layout。 + +### 4.4 按 Group Size 区分的 Group Reduce + +对于 `N` 个 f32 lane 和 `G = num_groups`,group size 是 `S = N / G`。 + +```text +S=8: + input layout 可以是 contiguous。 + group_reduce result 通常使用 layout。 + +S=16: + 如果 input 来自 f16->f32 vcvt,layout 可以是 deinterleaved=2。 + 如果 input 从 dense 拆出,layout 可以是 deinterleaved=2, block_elems=8。 + result 通常使用 layout。 + +S=32: + input layout 使用 deinterleaved=4, block_elems=8。 + VPTO 形状是四个部分 group reduction 
后接 add tree。 + result 通常使用 layout。 + +S=64: + row-local path 在可行时让每个 group 使用一条 physical row。 + result 可以使用 layout,避免 unsupported packing。 +``` + +S=32 例子: + +```text +assignment: + source/mask = deinterleaved=4, block_elems=8 + result = group_slots(num_groups=8, slots=8) + +VPTO: + vdintlv / pdintlv_b32 + vcgadd x4 + 使用 PAT_VL8 做 vadd tree + 通过一次 PAT_VL8 store 完成 group_store +``` + +这个场景说明为什么需要 `block_elems`。 + +### 4.5 Group Result 继续作为 Dense Rows 使用 + +Surface 意图: + +```mlir +%sum32 = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} +%rows32 = pto.vmi.group_broadcast %sum32 {num_groups = 8} +%rows16 = pto.vmi.truncf %rows32 +pto.vmi.store %rows16, %dst[%off] +``` + +支持的 assignment 形状: + +```mlir +%sum32 = pto.vmi.group_reduce_addf ... + -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> + +%rows32 = pto.vmi.group_broadcast %sum32 {num_groups = 8} + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%rows32_split = pto.vmi.ensure_layout %rows32 + : contiguous -> deinterleaved=2 + +%rows16 = pto.vmi.truncf %rows32_split + : deinterleaved=2 -> contiguous + +pto.vmi.store %rows16, %dst[%off] +``` + +VPTO 形状: + +```text +group_reduce: + vcgadd partials + vadd tree + +group_broadcast: + vselr 风格 selection,把 group slots 展开到 dense row lanes + +truncf: + vcvt even/odd + merge + +store: + vsts +``` + +这个场景说明为什么 group 结果 layout 必须挂在 value 上:reduce 之后, +cast 和 broadcast 必须知道 group 结果在哪里,而不能回看 producer。 + +### 4.6 通过 Mask 表达 Tail + +VMI 通过 mask 表达 tail,不通过 padding 表达 tail。 + +```mlir +%mask = pto.vmi.create_mask %active_lanes +%x = pto.vmi.masked_load %src[%off], %mask +%y = pto.vmi.mulf %x, %scale +pto.vmi.masked_store %y, %dst[%off], %mask +``` + +Grouped tail: + +```mlir +%gmask = pto.vmi.create_group_mask %active_elems_per_group + {num_groups = 8, group_size = 32} +%sum = pto.vmi.group_reduce_addf %x, %gmask {num_groups = 8, reassoc} +``` + +同一个 semantic mask 面对 f8/f16/f32 user 时,可能需要不同 concrete +granularity。Assignment 会通过 mask helper op 显式表达这些转换。 + +### 4.7 控制流和函数边界 + 
+Concrete layout 必须显式跨过 CFG 和内部 function boundary。 + +```mlir +%r = scf.if %cond + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout<contiguous>> { + %a_dense = pto.vmi.ensure_layout %a : deinterleaved=2 -> contiguous + scf.yield %a_dense +} else { + %b_dense = pto.vmi.ensure_layout %b : deinterleaved=2 -> contiguous + scf.yield %b_dense +} +``` + +`vmi-to-vpto` 之后,region result 会变成多个物理 VPTO value: + +```text +scf.if -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) +``` + +这个场景说明为什么 layout 应该是 type 的一部分,而不是依赖 defining op。 + +### 4.8 完整 Histogram 语义 + +VPTO 的 histogram 指令一次读取 `256xui8` source,但结果只能写 +`128xui16` accumulator。完整 `ui8` histogram 有 256 个 bin,因此物理 VPTO +接口需要通过 `#bin = 0/1` 分两次统计低半区和高半区。 + +VMI surface 不暴露这个物理 split: + +```mlir +%hist = pto.vmi.dhist %acc, %src, %mask + : !pto.vmi.vreg<256xui16>, + !pto.vmi.vreg<Nxui8>, + !pto.vmi.mask<Nxpred> + -> !pto.vmi.vreg<256xui16> +``` + +语义是完整 256-bin distribution histogram: + +```text +for b = 0..255: + hist[b] = acc[b] + count(i where mask[i] && src[i] == b) +``` + +Assignment 形状: + +```text +src/mask = contiguous, b8 mask granularity +acc/result = contiguous 256xui16 logical value +``` + +VPTO 形状: + +```text +acc/result part0 = bins 0..127 +acc/result part1 = bins 128..255 + +for each 256-lane source chunk: + part0 = dhistv2(part0, src_chunk, mask_chunk, #bin=0) + part1 = dhistv2(part1, src_chunk, mask_chunk, #bin=1) +``` + +这说明 VMI 的易用性不只来自 layout assignment。对于这种 value-indexed +accumulation,VMI 还应该隐藏 VPTO 为了物理 vreg 宽度暴露出来的 range +selector、lo/hi accumulator 和多条物理指令。 + +`pto.vmi.chist` 可以使用相同 surface 形状,但当前必须先验证 VPTO `CHISTv2` +在 high range 上返回的是全局累计还是 range-local 累计。这个差异会影响是否需要 +额外给 high half 加上 low half 的总计数,因此不能只按 op 名字猜 lowering。 + +### 4.9 Block-Strided UB Staging + +有些 CCE kernel 并不是在 register 内做任意 byte shuffle,而是先把结果写到 +UB scratch,再用 block-strided vector load/store materialize 目标 UB layout。 +`quant_minimum` 的 MXFP8 NZ case 是典型例子: + +```text +compute: + row-major ND FP8 scratch + +row-wise staging: + for row in 0..31: + q8_row = vmi.stride_load(nd 
+ row * 64, + block_stride=1, repeat_stride=1) + vmi.stride_store(q8_row, nz + row * 32, + block_stride=33, repeat_stride=1) + +copy-out: + 2D MTE copies two 1024B NZ planes from UB to GM +``` + +这里 `q8_row` 的 VMI value 仍然是 contiguous `64xf8` 逻辑向量: + +```mlir +%q8_row = pto.vmi.stride_load %nd[%nd_off], %c1_i16, %c1_i16, %mask + : !pto.ptr, i16, i16, !pto.vmi.mask<64xpred> + -> !pto.vmi.vreg<64xf8E4M3FN> + +pto.vmi.stride_store %q8_row, %nz[%nz_off], %c33_i16, %c1_i16, %mask + : !pto.vmi.vreg<64xf8E4M3FN>, !pto.ptr, i16, i16, + !pto.vmi.mask<64xpred> +``` + +Assignment 形状: + +```text +stride_load result = contiguous +stride_load mask = contiguous, granularity follows result element width +stride_store value = contiguous +stride_store mask = contiguous, granularity follows value element width +``` + +VPTO 形状: + +```text +base_in = pto.addptr nd, nd_off +q8_row = pto.vsldb base_in, block_stride=1, repeat_stride=1, mask + +base_out = pto.addptr nz, nz_off +updated = pto.vsstb q8_row, base_out, block_stride=33, repeat_stride=1, mask + -> updated_base +``` + +这个场景说明:memory layout transformation 不一定要变成 VMI data layout。 +只要 VMI op 的语义是“从哪些地址读/写哪些 logical lane”,register value +仍然可以保持 contiguous,`vmi-to-vpto` 也仍然是 local lowering。 + +## 5. 
当前边界 + +当前设计方向: + +```text +surface VMI: + 描述不带 layout 的逻辑向量语义。 + +layout assignment: + 选择 layout、mask granularity 和显式 materialization helper。 + +optimization: + 只在结果 IR 仍然可以 local lowering 时改写显式 helper。 + +vmi-to-vpto: + 严格 lower 它看到的 assigned/optimized IR。 +``` + +暂不支持或有意收紧的范围: + +```text +group_slots value 的普通 dense store: + 非法,除非先经过 group_broadcast 或其他显式 group-to-dense op。 + +packed group_slots f32->f16 cast: + 非法,除非 assignment 能把它 commute 到 group_broadcast 之后,或者使用 + 支持的 row-local slots=1 path。 + +FP4 packed input/output: + packed FP4 不属于当前 VMI surface。PTO/VPTO 已有 !pto.f4E1M2x2 + 和 !pto.f4E2M1x2 packed 物理类型,且这些类型的 shape 语义是 + packed pair/byte 数,不是 logical FP4 lane 数。在 VMI 中直接写 + vreg<Nx!pto.f4E1M2x2> 会让 N 表示物理 packed byte 还是逻辑 FP4 元素 + 产生歧义,因此 verifier 会直接拒绝 + vmi.vreg<...x!pto.f4E1M2x2/!pto.f4E2M1x2>。 + + 当前 VMI surface 不包含专用 FP4 packed-memory op。FP4 packed IO + 需要先作为独立语义重新设计,不能进入当前 dialect surface。 + +extract: + 暂不作为支持的 VMI surface。 + +padding transfer_read: + 当前 tail 设计不需要;tail 使用 mask。 + +scan / contract / compress / active_prefix_index: + dialect surface 中可以存在,但除非补充具体 case,否则不属于第一阶段聚焦的 + layout/lowering 实现集合。 + +gather / scatter: + 当前只覆盖 UB pointer、contiguous layout 和已明确支持的 element/index 宽度。 + `ui16` gather 可承接 E8M0 byte-pair reorder;它不是通用 byte shuffle。 +``` + +设计目标是优先保证语义完整:只要 VMI 接受某个 case,所需的 layout 沟通就必须 +在 IR 中显式表达,并且能被 `vmi-to-vpto` local lowering。 diff --git a/docs/designs/vmi-lane-stride-generalization-design.md b/docs/designs/vmi-lane-stride-generalization-design.md new file mode 100644 index 0000000000..e5134bede3 --- /dev/null +++ b/docs/designs/vmi-lane-stride-generalization-design.md @@ -0,0 +1,903 @@ +# VMI Lane-Stride Layout Generalization Design + +本文定义 `lane_stride` 从 group-slot 专用属性泛化为 VMI layout 的通用 +物理 lane 映射轴。目标不是只优化 `64xf16 -> 64xf32`,而是给 dense +value、group-slot value、类型转换、broadcast materialization 和 load/store +rematerialization 提供统一表达。 + +## 1.
Problem + +当前文档对 `lane_stride` 的语义是: + +```text +logical lane-sized physical slot 之间有固定间距 +``` + +但实现只允许它出现在: + +```text +#pto.vmi.layout +``` + +并且现有 helper 会把 `ui8 lane_stride=4` 这类 group-slot lowering 映射为 +b32 carrier。这导致两个问题: + +1. dense value 无法表达“64 个 f16 logical lanes 放在一个 128xf16 物理向量 + 的偶数 lane 上”。 +2. `lane_stride` 的 layout 语义和 group-slot carrier lowering 被混在一起。 + +泛化后必须保持以下边界: + +```text +lane_stride: + layout lane map, does not change logical element type + +carrier packing: + one lowering strategy for selected group-slot integer stores +``` + +## 2. Semantic Model + +### 2.1 Dense Layout + +Dense layout 仍然表示每个 logical lane 都有语义值。第一阶段只增加 +`lane_stride` 一个新轴: + +```text +deinterleave factor F +block elems B +lane stride LS +``` + +建议 surface spelling: + +```text +#pto.vmi.layout +#pto.vmi.layout + +#pto.vmi.layout +#pto.vmi.layout +``` + +Defaults: + +```text +F = 1 for contiguous +B = 1 +LS = 1 +``` + +Dense lane map: + +```text +logical lane i + +block q = i / B +in-block lane r = i % B +part p = q % F +part block t = q / F + +dense lane index in part = t * B + r +physical part p, physical lane dense lane index * LS +``` + +The current stage intentionally describes only phase-zero strided dense layouts. +For `lane_stride = 2`, that means semantic lanes occupy even physical lanes. + +An optional future `lane_offset` or `lane_phase` field is useful only after the +IR has a concrete zero-copy view or producer whose logical lane `i` is +intentionally represented at physical lane `2 * i + 1` or another non-zero +phase. The current stage has no such producer. The field should +not be added just because the target has a `vcvt ODD` instruction. + +`vcvt ODD` is needed in two different situations: + +```text +1. Full conversion of a packed contiguous source. + Example: contiguous f16 -> deinterleaved=2 f32 uses EVEN and ODD. + This is not an odd-phase dense source layout; it is the normal multi-part + lowering of a packed source. + +2. 
Single-part conversion of a future zero-copy odd-lane view. + Example: if a logical deinterleave/extract result were represented as + f16 lane_stride=2, lane_offset=1 instead of being compacted, then converting + that view to f32 contiguous would use ODD. This requires an explicit VMI + producer or consumer contract; current-stage dense stride does not + create such values. +``` + +The current design implements case 1 with the existing conversion lowering. Case +2 is a non-goal for this stage and is listed only as a possible future +extension. The useful dense-stride optimization in this +stage uses phase-zero layout and therefore selects `EVEN` for `W=2`. + +### 2.2 Deinterleaved vs Lane Stride + +Use `deinterleaved` when multiple semantic residue classes or physical parts of +the same dense logical value are all present. + +Use dense `lane_stride` when one semantic stream is stored sparsely inside each +physical part and the skipped lanes have no semantic value for this VMI value. + +Decision rule: + +```text +all residue classes are semantic: + use deinterleaved + +only one phase-zero residue class is semantic: + use lane_stride + +multiple parts are semantic and each part is internally strided: + use deinterleaved + lane_stride +``` + +Examples: + +```text +contiguous f16 -> f32 full dense widen: + source lanes 0,1,2,3,... are all semantic + result naturally has even/odd conversion parts + use result deinterleaved=2 + +64xf16 -> 64xf32 where the f32 consumer wants contiguous: + the vcvt layout support may request source lane_stride=2 + if the source producer/rematerialization can satisfy that request, source + lanes 0,2,4,... become semantic and lanes 1,3,5,...
are holes for this value + extf result can then be contiguous through one EVEN conversion + +group-reduce or dense consumer that needs two/four logical fragments: + the fragments are semantic parts of the same dense value + use deinterleaved=2/4, not lane_stride +``` + +Do not use `lane_stride` to describe a full packed value that happens to need an +ODD conversion part. Do not use `deinterleaved` to describe holes inside one +physical part. + +Important distinction: + +```text +one hardware vcvt output: + always one contiguous VPTO output register + +VMI ext result layout: + describes how one or more hardware output registers map back to logical lane + order +``` + +For `W=2`, with logical f16 lanes named by their logical indices: + +```text +source contiguous: + physical lanes: 0, 1, 2, 3, 4, 5, ... + vcvt EVEN output carries logical lanes 0, 2, 4, ... + vcvt ODD output carries logical lanes 1, 3, 5, ... + VMI result layout is deinterleaved=2 unless another materialization + interleaves the two outputs. + +source lane_stride=2: + physical lanes: 0, _, 1, _, 2, _, ... + vcvt EVEN output carries logical lanes 0, 1, 2, ... + VMI result layout is contiguous. +``` + +So "vcvt output is contiguous" does not by itself mean the VMI `extf` result is +contiguous. The result layout depends on the logical lane mapping of the source +layout and the selected conversion parts. + +### 2.3 Group-Slot Layout + +Group-slot layout remains non-dense. Only `G` group result slots have semantic +values: + +```text +#pto.vmi.layout +#pto.vmi.layout +``` + +Existing mapping is preserved: + +```text +slot_block(g) = g / K +slot_lane(g) = (g % K) * LS +``` + +This remains a group-slot placement property. It does not make non-slot lanes +semantic. Existing `ui8 lane_stride=4` to b32 carrier lowering is still legal, +but it is not the definition of `lane_stride`. + +Group-slot `lane_offset` is not needed in the current stage. 
It should remain +out of scope unless a real group-slot producer needs non-zero phase. + +## 3. Physical Capacity + +`lane_stride` increases the number of physical lane slots needed by a dense +part, but it does not change the VMI logical element type. + +For one dense physical part in the current stage: + +```text +logical lanes in this part = M +required physical lanes = (M - 1) * LS + 1 +``` + +The number of VPTO physical registers for each part is: + +```text +ceil(required physical lanes / lanes_per_vpto_register(T)) +``` + +Total physical arity: + +```text +deinterleave factor F * registers per part +``` + +Example: + +```text +!vmi.vreg<64xf16, contiguous, lane_stride=2> + +lanes_per_vpto_register(f16) = 128 +required physical lanes = 63 * 2 + 1 = 127 +physical arity = 1 +``` + +The 64 logical f16 lanes occupy physical f16 lanes `0, 2, 4, ... 126` of one +`!pto.vreg<128xf16>`. The other lanes are undefined unless another layout +value gives them semantics. + +Some lowerings represent the same lane map with wider carrier slots instead of +logical-element lanes. For example, a b16 value with `lane_stride=2` may be +lowered as the low b16 element of each b32 carrier slot when using +`UNPK_B16`/`PK_B32` or register pack/unpack materialization. This does not +change the VMI logical element type; it is a VPTO lowering representation choice. + +## 4. Type And Operation Generalization + +The design is element-type agnostic. Dense `lane_stride` applies to any VMI +element type whose physical VPTO lane count is known: + +```text +f8, f16, bf16, f32 +i8, ui8, i16, ui16, i32, ui32 +pred masks at an explicit predicate granularity +``` + +An op may support a strided dense layout only when its VPTO lowering can +preserve the lane map. Unsupported combinations are rejected by layout support +queries, not silently repaired in `vmi-to-vpto`. + +### 4.1 VPTO Pack/Unpack Support Boundary + +Dense `lane_stride` is not a generic VPTO load/store operand. 
It is supported +only when the lane map matches a concrete VPTO distribution or register +materializer. + +Direct compact memory support: + +| Dense lane_stride | compact load | compact store | +|---:|---|---| +| 2, b8 | `vlds UNPK_B8` | `vsts PK_B16` | +| 2, b16 | `vlds UNPK_B16` | `vsts PK_B32` | +| 2, b32 | `vlds UNPK_B32` | `vsts PK_B64` | +| 4, b8 | `vlds UNPK4` | `vsts PK4_B32` | +| 4, b16/b32 | no direct dist | no direct dist | + +Direct scalar broadcast load target capability: + +```text +lane_stride=2/4, b8/b16/b32: + vlds BRC_B8/B16/B32 +``` + +The current stage does not add a VMI scalar broadcast-load op. BRC is therefore +a target capability for a separate scalar broadcast-load semantic, not part of +the current `vmi.load -> ensure_layout` compact-stream fold. + +Register fallback between contiguous and dense `lane_stride` should use the +register-side counterpart of these distributions: + +```text +contiguous -> lane_stride: + vsunpack/vzunpack-style placement into wider slots + +lane_stride -> contiguous: + vpack-style extraction from wider slots +``` + +`vintlv`/`vdintlv` remain the materializers for two-stream +interleave/deinterleave layouts; they are not the primary fallback for dense +`lane_stride`. + +### 4.2 Layout-Transparent Dense Ops + +Layout-transparent dense ops include ordinary elementwise arithmetic and +select-like ops when every dense data operand/result has the same layout: + +```text +add/mul/fma/min/max/select: + operands and result require identical dense layout key + key includes F, B, and LS +``` + +No physical shuffle is implied by these ops. 
+ +### 4.3 Widening Conversion + +Let a widening conversion increase element storage width by ratio `W`: + +```text +f16 -> f32: W = 2 +bf16 -> f32: W = 2 +i16 -> i32: W = 2 +ui16 -> ui32: W = 2 +f8 -> f32: W = 4 +i8 -> i32: W = 4 +ui8 -> ui32: W = 4 +ui8 -> ui16: W = 2 +``` + +For a phase-zero source dense layout with `lane_stride = LS`, a single hardware +conversion part is sufficient when: + +```text +LS % W == 0 +``` + +The selected hardware part in the current stage is: + +```text +part = 0 +``` + +The result layout after conversion is: + +```text +result lane_stride = LS / W +``` + +For a future phase-aware layout with `lane_offset = O`, the generic relation is: + +```text +part = O % W +result lane_stride = LS / W +result lane_offset = (O - part) / W +``` + +That future relation should be enabled only when a real odd/non-zero-phase VMI +producer or consumer exists. + +Examples: + +```text +f16 source: contiguous, lane_stride=2 +extf to f32: + use vcvt EVEN + result contiguous + +f16 source: contiguous, lane_stride=4 +extf to f32: + use vcvt EVEN + result contiguous, lane_stride=2 +``` + +If `LS < W` or `LS % W != 0`, the conversion may need multiple hardware parts +and may naturally produce a deinterleaved result. The current contiguous source +case is the common example: + +```text +f16 source: contiguous, lane_stride=1 +extf to f32: + use vcvt EVEN and vcvt ODD + result deinterleaved=2 +``` + +Assignment chooses one preferred fact for the op before lowering. Consumer +requests are handled by the existing use-site materialization path after the +op's assigned result layout is fixed. + +The preferred direction for this optimization is not "notice the input is +already strided". The conversion op can be the layout-entry point and compute a +single preferred layout fact for the current op instance. The choice must be +arity-driven, not special-cased by a spelling such as `64xf32`. 
+ +For source/result logical lane count `N`, let: + +```text +natural result layout: + source dense factor F, lane_stride 1 + result dense factor F * W, lane_stride 1 + +compact result layout: + result keeps source dense factor F and uses lane_stride 1 + source uses lane_stride W inside the same dense factor F +``` + +The self-preferred widening rule is: + +```text +if physical_arity(compact result) < physical_arity(natural result) + and target supports the required source lane_stride relation: + choose compact result and request source lane_stride=W +else: + choose natural result deinterleaved by W +``` + +For ordinary contiguous `f16 -> f32` this gives: + +```text +64xf32: + compact arity = 1 + natural deinterleaved=2 arity = 2 + choose source lane_stride=2, result contiguous + +128xf32: + compact arity = 2 + natural deinterleaved=2 arity = 2 + choose natural result deinterleaved=2 + +256xf32: + compact arity = 4 + natural deinterleaved=2 arity = 4 + choose natural result deinterleaved=2 +``` + +If the source is already deinterleaved by `F`, the natural result factor is +`F * W`. For example, `deinterleaved=2 f16 -> f32` naturally produces +`deinterleaved=4 f32`. + +The same arity rule applies to other widening ratios and types. For example, +`ui8 -> ui32` has `W=4`; a lane-stride source is preferred only when the +contiguous result has fewer physical chunks than the natural +`deinterleaved=4` result and the target supports the `lane_stride=4` relation. + +The two layout facts are therefore: + +```text +baseline fact: + source contiguous + result deinterleaved=W + arity: physical_arity(result deinterleaved=W) + +lane-stride fact: + source lane_stride=W + result contiguous + arity: physical_arity(result contiguous) + source layout request: explicit +``` + +In the current single-preference framework, `ext` should publish one preferred +fact. 
The lane-stride fact is an op-local preference: assignment records the +required source/result relation in the IR and inserts `ensure_layout` at the +source use if the producer is not already in that layout. Later +rematerialization or fold passes may remove that helper when a concrete producer +rewrite exists; otherwise the helper is either lowered by a registered +contiguous/lane-stride materializer or rejected before `vmi-to-vpto`. + +This keeps the optimization in layout assignment/rematerialization, not in a +late `vmi-to-vpto` peephole, and stays within the existing single-preference +assignment model. + +### 4.4 Narrowing Conversion + +Narrowing uses the same arity-driven idea in the opposite direction. If source +element width is `W` times the result element width, a single hardware narrowing +part can produce a phase-zero strided result when: + +```text +result lane_stride = source lane_stride * W +part = 0 +``` + +This covers more than f32-to-f16. The same relation applies to: + +```text +f32 -> f16/bf16 +i32 -> i16/i8 +ui32 -> ui16/ui8 +ui16 -> ui8 +``` + +The natural narrowing relation is the inverse of natural widening: + +```text +source dense factor F * W, lane_stride 1 +result dense factor F, lane_stride 1 +``` + +The compact-store-oriented relation is: + +```text +source keeps dense factor F and lane_stride 1 +result keeps dense factor F and uses lane_stride W +``` + +Narrowing has the same candidate family as widening. The arity comparison is +made on the source side, because the compact relation keeps the source +contiguous while the natural relation may require a deinterleaved source. 
+ +The self-preferred narrowing rule is: + +```text +if physical_arity(compact contiguous source) + < physical_arity(natural deinterleaved source) + and physical_arity(compact source) == physical_arity(strided result) + and target supports the source-contiguous/result-lane_stride relation: + choose source contiguous, result lane_stride=W +else: + choose natural deinterleaved-source to contiguous-result relation +``` + +Use-site requests may still select the strided relation when a later consumer +can directly consume it: + +```text +if a consumer requests result lane_stride=W + and target supports source-contiguous/result-lane_stride narrowing: + request source contiguous + set or rematerialize result lane_stride=W +``` + +For ordinary `f32 -> f16`: + +```text +64xf32 -> 64xf16: + natural source deinterleaved=2 arity = 2 + compact source contiguous arity = 1 + choose source contiguous, result lane_stride=2 + +128xf32 -> 128xf16: + natural source deinterleaved=2 arity = 2 + compact source contiguous arity = 2 + choose natural source deinterleaved=2, result contiguous +``` + +So trunc should not blindly create a lane-stride result for every narrowing. +It should apply the same arity/support checks as ext. A consumer may still +request a strided result when that layout is useful, such as an unmasked compact +store lowered with `PK`/`PK4`. For masked stores, the value and mask must share +the same lane map before a direct packed masked store is legal. + +The exact supported parts are target-op dependent. The layout assignment layer +should ask the op support interface whether a given source/result layout pair is +legal, rather than encoding type-specific shortcuts. + +### 4.5 Broadcast Materialization + +Broadcast remains a logical operation. `lane_stride` only describes the chosen +materialized layout. 
+ +Scalar or group broadcast can materialize to a dense layout only when the +broadcast lowering or rematerialization support query accepts that lane map: + +```text +logical broadcast: + lane i gets value group(i) + +materialized layout: + lane i is stored at physical lane map(i) +``` + +This keeps E2B-style optimizations in the layout/rematerialization layer. A +group broadcast load may choose a dense strided layout when that layout directly +matches a consumer or a target instruction. If another consumer needs a +different layout, rematerialization may clone the broadcast or insert +`ensure_layout`. + +`group_broadcast_load` is also a VMI semantic, not an E2B semantic. It means: + +```text +for each logical group g: + load one scalar from source[offset + g * source_group_stride] + broadcast that scalar to all lanes in group g +``` + +E2B is a target lowering choice for the subset where that logical memory pattern, +the group size, the element width, and the assigned result layout match the E2B +packet semantics. Other lowering strategies may implement the same VMI +operation, so support queries should report "E2B is applicable" instead of +rewriting the VMI meaning to "this op is E2B". + +### 4.6 Masked Lane-Stride Stores + +Masks are logical predicates. A `masked_store` mask bit denotes whether a +logical element participates in the store; it is not automatically a predicate +for the physical lane slot that happens to carry that element after layout +assignment. + +For dense `lane_stride`, this distinction matters. With `lane_stride=2`, +logical lane `i` is carried in physical lane `2*i`. A packed store then +compacts those even physical lanes into a contiguous memory stream. A user mask +that is still contiguous cannot be passed directly to that packed store, because +the packed-store predicate is interpreted after the value lanes have been +compacted. 
+ +A direct masked compact store is therefore legal only when the compiler has +assigned the value and mask the same lane map. That may happen because the mask +producer can directly produce the requested lane map, because assignment inserts +a mask `ensure_layout`, or because rematerialization rebuilds the mask producer +for that lane map. Without that compiler-derived proof, assignment should keep +a layout that the existing masked-store path can lower, even if the corresponding +unmasked store could use a dense lane-stride `PK` instruction. + +## 5. Assignment And Optimization Boundary + +The assignment pipeline should keep the existing responsibility split: + +```text +layout assignment: + collect consumer requests + ask producer/op support + assign explicit layout attrs + insert ensure_layout for use-local conflicts + +rematerialization: + clone cheap producers for incompatible use-site layouts + replace ensure_layout(producer) when producer can directly create target layout + +layout fold: + erase or fuse materialization helpers when the producer already has the + requested lane map + +vmi-to-vpto: + lower explicit assigned layouts only + no hidden layout selection policy +``` + +Dense `lane_stride` is therefore an assigned layout fact, not a lowering-side +pattern. An entry op such as `extf` may prefer it from the conversion ratio +alone; producer-specific rewrites are handled later by fold/rematerialization +passes over explicit helpers. The selected layout is fixed before +`vmi-to-vpto`, and `vmi-to-vpto` does not rediscover the preference. + +## 6. End-To-End Case Walkthroughs + +These cases are the intended test for the design. They show when dense +`lane_stride` is useful and when it should lose to the existing deinterleaved +plan. + +The logical programs in this section are pre-assignment VMI and do not carry +concrete layouts. 
Layouts shown under "baseline plan" or "lane-stride plan" are +possible assignment results, not layouts written in the input program. + +### 6.1 Contiguous Load, Ext, Contiguous Store + +Logical program: + +```text +%x16 = vmi.load %in : 64xf16 +%x32 = vmi.extf %x16 : 64xf16 -> 64xf32 +vmi.store %x32, %out : dense contiguous memory effect +``` + +Baseline plan: + +```text +load result: + contiguous f16 + +ext relation: + source contiguous f16 + result deinterleaved=2 f32 + lower: vcvt EVEN + vcvt ODD + +store: + needs contiguous f32 + requires result materialization deinterleaved=2 -> contiguous +``` + +Lane-stride plan: + +```text +load result: + lane_stride=2 f16 + +ext relation: + source lane_stride=2 f16 + result contiguous f32 + lower: vcvt EVEN + +store: + consumes contiguous f32 directly +``` + +Assignment chooses the lane-stride plan for this shape because the contiguous +`64xf32` result uses one physical chunk while the natural deinterleaved result +uses two physical chunks. This decision is made by the cast arity rule, not by +a pattern that names `64xf32` directly. + +The load side then has two concrete outcomes: + +```text +accepted direct load fold: + the original load has only the lane-stride use + compact load semantics match a supported UNPK dist + vmi-layout-fold changes the VMI load result layout in place + +no direct load fold: + keep the explicit source ensure_layout + lower it through register pack/unpack if that materialization is supported + otherwise validation rejects the unsupported assigned relation +``` + +This case proves that `extf` can be the layout-entry point, while `load` support +is still decided by the load/ensure fold or by the explicit materialization +helper. 
+ +### 6.2 Broadcast, Ext, Contiguous Store + +Logical program: + +```text +%b16 = vmi.broadcast %s : 1xf16 -> 64xf16 +%b32 = vmi.extf %b16 : 64xf16 -> 64xf32 +vmi.store %b32, %out +``` + +Baseline plan: + +```text +broadcast materializes contiguous f16 +ext produces deinterleaved=2 f32 through EVEN + ODD +store materializes deinterleaved=2 -> contiguous +``` + +Lane-stride plan: + +```text +broadcast rematerializes directly as lane_stride=2 f16 +ext produces contiguous f32 through one EVEN +store consumes contiguous f32 +``` + +Here the lane-stride plan is accepted because broadcast is a rematerializable +producer: it can be rebuilt with the requested physical lane map instead of +requiring a register layout conversion. This is the kind of producer where +`vcvt` should drive a source `lane_stride=2` request. + +### 6.3 Ext Feeding A Deinterleaved Consumer + +Logical program: + +```text +%x16 = producer : 128xf16 +%x32 = vmi.extf %x16 : 128xf16 -> 128xf32 +%r = vmi.group_reduce %x32 // requests deinterleaved=2 +``` + +Baseline plan: + +```text +source contiguous f16 +result deinterleaved=2 f32 +consumer consumes result directly +``` + +Lane-stride plan: + +```text +source lane_stride=2 f16 +result contiguous f32 +consumer then needs contiguous -> deinterleaved=2 materialization +``` + +The baseline plan should win. A lane-stride fact is not useful when it creates a +layout the consumer does not want. The cast arity rule also does not prefer +lane_stride here: `128xf32` contiguous and `128xf32 deinterleaved=2` both use +two physical chunks. 
+ +### 6.4 One Ext Result Feeding Store And Reduce + +Logical program: + +```text +%x16 = cheap_or_expensive_producer : 128xf16 +%x32 = vmi.extf %x16 : 128xf16 -> 128xf32 +vmi.store %x32, %out // requests contiguous +vmi.group_reduce %x32 // requests deinterleaved=2 +``` + +If `%x16` is not cheap to rematerialize: + +```text +assign ext result deinterleaved=2 for the reduce +insert ensure_layout at the store use +``` + +If `%x16` and `extf` are cheap to rematerialize: + +```text +shared path: + source contiguous -> ext result deinterleaved=2 -> reduce + +store-only remat path: + rematerialized source lane_stride=2 -> ext result contiguous -> store +``` + +This is a rematerialization decision, not a local `vcvt` peephole. + +### 6.5 Group Broadcast Load Feeding Ext + +Logical program: + +```text +%g16 = vmi.group_broadcast_load %scale : logical dense 64xf16 +%g32 = vmi.extf %g16 : 64xf16 -> 64xf32 +consumer requests contiguous %g32 +``` + +The lane-stride plan is accepted only if the group broadcast load lowering can +emit the requested lane map directly: + +```text +group broadcast load result lane_stride=2 f16 +ext result contiguous f32 +``` + +If the broadcast load can only produce contiguous or deinterleaved packets for +the target element width, assignment should keep those layouts and let later +materialization/rematerialization handle the consumer conflict. Dense +`lane_stride` is a requestable layout, not a guarantee that every producer can +create it. + +## 7. Compatibility Rules + +Two dense layouts are identical only if all lane-map fields match: + +```text +F, B, LS +``` + +Two dense layouts may be related by an explicit materialization only if a +registered relation can lower the map conversion. 
Examples: + +```text +contiguous <-> deinterleaved=2 +deinterleaved=2 <-> deinterleaved=4 when supported by existing intlv/dintlv +contiguous <-> contiguous, lane_stride=2 when pack/unpack materialization or +producer rematerialization supports it +``` + +The baseline assignment must not assume an arbitrary dense-to-dense +`ensure_layout` is free or legal. Unsupported materializations should fail in +verification or remain unselected by support queries. + +## 8. Non-Goals + +This design does not: + +1. Turn memory layout into strided memory semantics. Dense VMI `lane_stride` + describes register materialization, not GM/UB address stride. +2. Make non-slot lanes of group-slot layouts semantic. +3. Require every VPTO op to support every strided layout. +4. Encode `64xf16 -> 64xf32` as a one-off `vcvt EVEN` peephole. + +## 9. First Useful Optimization + +The motivating case becomes one instance of the generic rule: + +```text +source: + requested as !vmi.vreg<64xf16, contiguous, lane_stride=2> + +op: + extf f16 -> f32, W=2 + +result: + !vmi.vreg<64xf32, contiguous> + +lowering: + one vcvt EVEN +``` + +The same mechanism also covers: + +```text +bf16 -> f32 with phase-zero lane_stride=2 +ui8 -> ui16 with lane_stride=2 +ui8 -> ui32 with lane_stride=4 +f8 -> f32 with lane_stride=4 +narrowing conversions that intentionally produce phase-zero strided results +broadcast materialization into a consumer-required strided dense layout +``` diff --git a/docs/designs/vmi-lane-stride-generalization-implementation.md b/docs/designs/vmi-lane-stride-generalization-implementation.md new file mode 100644 index 0000000000..4a3b251a41 --- /dev/null +++ b/docs/designs/vmi-lane-stride-generalization-implementation.md @@ -0,0 +1,1690 @@ +# VMI Lane-Stride Layout Generalization Implementation Plan + +本文给出 `lane_stride` 泛化的实现路径。设计目标是把 lane-strided dense +layout 作为一等 layout fact 固化、传播、rematerialize 和 lower,而不是在 +`vmi-to-vpto` 中识别单个 `64xf16 -> 64xf32` pattern。 + +## 1. 
Implementation Principles + +1. `lane_stride` is a lane-map field. +2. Dense `lane_stride` does not change the VMI logical element type. +3. Group-slot carrier packing is a separate lowering helper. +4. Layout assignment decides layout before `vmi-to-vpto`. +5. `vmi-to-vpto` only lowers explicit assigned layout attrs. + +Baseline before this design: + +```text +dense contiguous/deinterleaved layouts: + did not carry lane_stride + +regular VMI load/store: + did not support dense lane_stride + supported contiguous and selected deinterleaved lowering/materialization paths + +VPTO load/store: + pto.vlds/pto.vsts have a dist string and the VPTO surface supports several + distribution families, but there is no generic lane_stride operand + +group-slot lane_stride: + already existed and was used by selected group-store packed-byte lowering +``` + +Any dense lane-stride load/store support must enter explicitly by mapping a VMI +lane-stride layout to a specific supported VPTO dist family or materialization +sequence. It must not be inferred in `vmi-to-vpto` from a one-off producer or +consumer pattern. + +Current stage status: + +| Area | Status | Notes | +|---|---|---| +| Dense layout attrs | Supported | Dense contiguous/deinterleaved layouts carry `lane_stride`; group-slot carrier layout remains separate. | +| Direct compact load/store | Supported for selected phase-zero maps | LS=2 b8/b16/b32 through `UNPK_B8/B16/B32` and `PK_B16/B32/B64`; LS=4 b8 through `UNPK4` and `PK4_B32`. | +| Load/store layout folds | Supported with one-load/one-store preservation | `load -> ensure_layout(lane_stride)` rewrites the original load layout when all uses agree; `ensure_layout(lane_stride -> contiguous) -> store` lets the VMI store consume the lane-stride value.
| +| Dense widening ext | Supported | `getPreferredCastLayoutFact` chooses the arity-reducing source `lane_stride=W` / result contiguous relation when it beats the natural deinterleaved result; otherwise it keeps the natural relation. | +| Dense narrowing trunc | Supported for dense natural paths | `getPreferredCastLayoutFact` uses the same arity rule in the inverse direction, so trunc keeps the natural deinterleaved-source / contiguous-result relation unless a compact relation actually reduces arity. | +| Masked compact store | Partially supported | Legal only when value and mask have the same lane map and the mask can be compacted for the selected store dist. | +| Masked trunc tail | Not optimized yet | Keep the existing legal path until mask lane-stride assignment/materialization is available. | +| Register fallback | Partially supported | Only same-physical-arity contiguous `<->` lane_stride paths with legal pack/unpack carriers. Arity-changing fallback is not in scope for this stage. | +| Group broadcast load | Supported only through specific strategies | `group_broadcast_load` remains a VMI semantic; E2B is one strategy with exact shape/layout constraints. | + +Remaining design/implementation work from this discussion is intentionally +limited to these areas: + +| Area | Work to settle | Required proof before enabling | +|---|---|---| +| Cast assignment | Keep `getPreferredCastLayoutFact` as the single op-local preferred relation helper, but make it shape-aware: compute the natural relation, compute the compact lane-stride relation, and select compact only when physical arity improves. | `64xf16 -> 64xf32` chooses source `lane_stride=2` and result contiguous; `128/256xf16 -> f32` keep natural `deinterleaved=2`; dense trunc keeps the natural relation unless compact arity wins. | +| Masked store | Let `masked_store` request the same lane map for value and mask, or keep the existing legal path when the mask cannot be assigned/rematerialized into that lane map. 
| No path may lower a lane-stride value with a stale contiguous user mask; lowering must compact the assigned mask into the packed-store predicate. | +| Group broadcast load | Keep `group_broadcast_load` as a VMI logical operation and make E2B only one support/lowering strategy selected by shape, element width, stride, and assigned result layout. | A failed E2B match must mean "this lowering strategy is unavailable", not "the VMI op is invalid" unless no fallback strategy is registered. | + +Known support boundaries that are not part of this discussion's remaining-work +queue: + +```text +b32 contiguous <-> lane_stride register fallback through generic vpack/vunpack +generic scalar broadcast-load VMI semantic for BRC +dense lane-stride masked_load +arity-changing register fallback +LS=4 b16/b32 direct compact load/store +LS > 4 direct compact load/store +non-zero lane_offset / lane_phase +ordinary load cloning/rematerialization without safe-read proof +global cost search across conflicting consumer layouts +partial-chunk dense lane-stride direct memory beyond the current full-chunk gate +``` + +### 1.1 VPTO Dist Capability Boundary + +VPTO already exposes fixed distribution families that can implement specific +layout-producing or layout-consuming memory operations: + +```text +vlds: + NORM + BRC_B8/B16/B32 + US_B8/B16 + DS_B8/B16 + UNPK_B8/B16/B32 + BRC_BLK + E2B_B16/B32 + UNPK4 + SPLT4CHN + SPLT2CHN_B8/B16 + +vldsx2: + BDINTLV + DINTLV_B8/B16/B32 + +vsts: + NORM_B8/B16/B32 + 1PT_B8/B16/B32 + PK_B16/B32/B64 + PK4_B32 + MRG4CHN_B8 + MRG2CHN_B8/B16 + +vstsx2: + INTLV_B8/B16/B32 +``` + +These are not equivalent to an arbitrary dense `lane_stride` operand: + +```text +DINTLV/INTLV: + two-stream deinterleave/interleave memory operation + maps naturally to VMI deinterleaved layouts, not to one sparse semantic stream + +US/DS: + fixed 2x upsample/downsample load families for b8/b16 + can serve selected lane-map producers when the semantic mapping matches exactly + 
+UNPK/PK/PK4: + fixed slot-pack/slot-unpack memory families + directly express selected dense lane_stride layouts such as b16 LS=2 and + b8 LS=4, but not arbitrary LS=N + +BRC/E2B/BRC_BLK: + fixed broadcast or group-expansion load families + useful when logical broadcast plus assigned layout matches the family + +MRG/SPLT: + fixed channel merge/split families + useful only for matching channel layouts +``` + +So VPTO has enough surface area to support selected dense lane-stride memory +optimizations, but VMI must model them as explicit support cases: + +```text +VMI layout fact + op semantics + element width + -> exact VPTO dist family + or materialization/rematerialization sequence + or unsupported +``` + +Concrete support matrix for dense phase-zero `lane_stride`: + +| Dense lane_stride | Compact stream load -> dense LS | Single-scalar broadcast load -> dense LS | Dense LS -> compact stream store | +|---:|---|---|---| +| 2 | direct for b8/b16/b32 through `vlds UNPK_B8/B16/B32` | target dist exists as `vlds BRC_B8/B16/B32`; needs a separate single-scalar broadcast-load VMI semantic | direct for b8 through `vsts PK_B16`, b16 through `vsts PK_B32`, and b32 through `vsts PK_B64` | +| 4 | direct for b8 through `vlds UNPK4` | target dist exists as `vlds BRC_B8/B16/B32`; needs a separate single-scalar broadcast-load VMI semantic | direct for b8 through `vsts PK4_B32` | + +| VMI memory semantic | Element width | VPTO op/dist | VMI result layout | Direct dense `lane_stride` support | +|---|---:|---|---|---| +| load one scalar and every logical lane uses it | b8 | `vlds BRC_B8` | any dense phase-zero lane map | target dist exists; needs a separate VMI scalar broadcast-load semantic | +| load one scalar and every logical lane uses it | b16 | `vlds BRC_B16` | any dense phase-zero lane map | target dist exists; needs a separate VMI scalar broadcast-load semantic | +| load one scalar and every logical lane uses it | b32 | `vlds BRC_B32` | any dense phase-zero lane map | target 
dist exists; needs a separate VMI scalar broadcast-load semantic | +| load compact stream `x[i]` into semantic lane `2*i` | b8 | `vlds UNPK_B8` | `contiguous, lane_stride=2` | yes | +| load compact stream `x[i]` into semantic lane `2*i` | b16 | `vlds UNPK_B16` | `contiguous, lane_stride=2` | yes | +| load compact stream `x[i]` into semantic lane `2*i` | b32 | `vlds UNPK_B32` | `contiguous, lane_stride=2` | yes | +| load compact stream `x[i]` into semantic lane `4*i` | b8 | `vlds UNPK4` | `contiguous, lane_stride=4` | yes | +| load compact stream `x[i]` into semantic lane `4*i` | b16/b32 | none | `contiguous, lane_stride=4` | no direct VPTO dist | +| load compact stream `x[i]` into semantic lane `K*i`, `K > 4` | any | none | `contiguous, lane_stride=K` | no direct VPTO dist | +| load memory `x[2*i]` into logical lane `i` | b8 | `vlds DS_B8` | contiguous | no; this is memory downsample | +| load memory `x[2*i]` into logical lane `i` | b16 | `vlds DS_B16` | contiguous | no; this is memory downsample | +| load alternating memory stream into even/odd logical streams | b8 | `vldsx2 DINTLV_B8` | two compact streams or deinterleaved=2 | no; not one sparse stream | +| load alternating memory stream into even/odd logical streams | b16 | `vldsx2 DINTLV_B16` | two compact streams or deinterleaved=2 | no; not one sparse stream | +| load alternating memory stream into even/odd logical streams | b32 | `vldsx2 DINTLV_B32` | two compact streams or deinterleaved=2 | no; not one sparse stream | +| store semantic lane `2*i` as compact memory `x[i]` | b8 | `vsts PK_B16` | source `contiguous, lane_stride=2` | yes | +| store semantic lane `2*i` as compact memory `x[i]` | b16 | `vsts PK_B32` | source `contiguous, lane_stride=2` | yes | +| store semantic lane `2*i` as compact memory `x[i]` | b32 | `vsts PK_B64` | source `contiguous, lane_stride=2` | yes | +| store semantic lane `4*i` as compact memory `x[i]` | b8 | `vsts PK4_B32` | source `contiguous, lane_stride=4` | yes | +| store 
semantic lane `4*i` as compact memory `x[i]` | b16/b32 | none | source `contiguous, lane_stride=4` | no direct VPTO dist | +| store semantic lane `K*i` as compact memory `x[i]`, `K > 4` | any | none | source `contiguous, lane_stride=K` | no direct VPTO dist | +| store two compact streams as alternating memory | b8 | `vstsx2 INTLV_B8` | two compact streams or deinterleaved=2 | no; not one sparse stream | +| store two compact streams as alternating memory | b16 | `vstsx2 INTLV_B16` | two compact streams or deinterleaved=2 | no; not one sparse stream | +| store two compact streams as alternating memory | b32 | `vstsx2 INTLV_B32` | two compact streams or deinterleaved=2 | no; not one sparse stream | + +Masked compact stores have an extra legality rule. The `vmi.masked_store` +predicate is a logical-lane predicate, while a VPTO packed store consumes a +predicate in the compacted store coordinate after the sparse lanes have been +packed. Therefore a lane-stride value cannot be paired with an unrelated +contiguous mask and lowered directly to `PK`/`PK4`. + +The direct masked-store path is legal only when all of these hold: + +```text +value source layout == mask source layout +value/mask physical arity matches +mask granularity matches the logical value element width before compaction +target has a predicate compaction path for the packed-store dist +``` + +For example, an f16 value with `lane_stride=2` places logical lanes in even +physical lanes. If a user mask remains contiguous, mask bit `i` still denotes +logical lane `i`, not physical lane `2*i`. Passing that mask directly to +`vsts PK_B32` would gate the wrong compact positions for tail or sparse masks. +The current legal path requires the mask to carry the same lane map as the value +and then compacts it with predicate unpack operations before emitting the +packed store. Ordinary unmasked `vmi.store` is different: lowering creates the +compact prefix predicate itself, so there is no user mask to reinterpret. 
+ +Until masked-store assignment can request and prove the same lane map for value +and mask, assignment must keep masked-tail narrowing on an existing legal path +instead of choosing a lane-stride trunc result solely because the store could +otherwise use `PK`. + +Current-stage implementation: + +```text +lib/PTO/Transforms/VMILayoutAssignment.cpp + +VMIMaskedStoreOp keeps the existing conservative request: + requestDataUse(value, contiguous) + requestMaskUse(mask, contiguous, elementGranularity) + +trunc assignment does not inspect masked_store users and does not preserve a +special masked-store guard. It records the source/result relation returned by +getPreferredCastLayoutFact. If that conflicts with a masked_store contiguous +request, normal assignment conflict handling inserts the required +ensure_layout. +``` + +Future lane-stride `masked_store` support must be added as an explicit +consumer-owned extension, not as a trunc special case. The future dataflow must +prove that value and mask share the same lane map before a packed masked store +is legal: + +```text +%n = vmi.trunc* %wide + : source contiguous -> result contiguous, lane_stride = W + +%m_ls = vmi.ensure_mask_layout %m + : mask contiguous -> mask contiguous, lane_stride = W + +vmi.masked_store %n, %dst[%off], %m_ls +``` + +That future extension would need the same local VMI proof before lowering can do +the mechanical predicate compaction: + +```text +vmi-layout-fold: + may fold ensure_layout(value) + ensure_mask_layout(mask) into masked_store + only through canFoldContiguousMaskedStoreMaterialization + +vmi-to-vpto: + sees valueLayout == maskLayout + calls createDenseLaneStrideStorePredicate + emits LOWER punpack on the mask + emits vsts PK_B16/PK_B32/PK4_B32 as selected by value element width/layout +``` + +Future negative tests should cover the fallback: + +```text +fallback: + mask cannot be assigned/materialized to the candidate lane_stride + CHECK masked_store keeps contiguous value/mask 
request + CHECK no PK/PK4 masked compact store is emitted with a stale contiguous mask +``` + +The VPTO dist tokens below are fixed-function operations, not generic +`lane_stride` operands: + +```text +UNPK_B8/B16/B32: + compact load into one element per 16/32/64-bit slot, giving lane_stride=2 for + b8/b16/b32 dense values + +UNPK4: + compact load into one b8 element per 32-bit slot, giving lane_stride=4 for b8 + +PK_B16/B32/B64 and PK4_B32: + compact store from one active low element per 16/32/64-bit slot. PK_B32 is + exactly the direct compact store for a b16 value with lane_stride=2, and + PK4_B32 is exactly the direct compact store for a b8 value with lane_stride=4 + +MRG4CHN_B8 and MRG2CHN_B8/B16: + fixed channel merge stores, not generic lane_stride stores + +SPLT4CHN and SPLT2CHN_B8/B16: + fixed channel split loads, not generic lane_stride loads + +BRC_BLK and E2B_B16/B32: + usable only after their exact block/group expansion semantic is modeled as a + VMI broadcast producer; do not count them as generic dense lane_stride loads +``` + +### 1.2 Contiguous/Lane-Stride Fallback Materialization + +Direct load/store support is preferred. When a value already lives in VPTO +registers and a consumer requires the other layout, `ensure_layout` provides the +fallback conversion between contiguous and dense phase-zero `lane_stride`. 
+ +For `contiguous -> lane_stride`, use register unpack placement when the VPTO +surface supports the required carrier type: + +```text +LS=2: + use vzunpack/vsunpack-style widening placement + b8 contiguous -> b16 slots with low b8 semantic + b16 contiguous -> b32 slots with low b16 semantic + b32 contiguous -> b64 slots with low b32 semantic + +LS=4: + for b8, apply two LS=2 unpack placements: + b8 contiguous -> b16 slots -> b32 slots with low b8 semantic +``` + +For `lane_stride -> contiguous`, use register pack when the VPTO surface +supports the required carrier type: + +```text +LS=2: + use vpack-style narrowing placement + low b8 from each b16 slot -> b8 contiguous + low b16 from each b32 slot -> b16 contiguous + low b32 from each b64 slot -> b32 contiguous + +LS=4: + for b8, apply two LS=2 pack placements: + low b8 from each b32 slot -> b16 slots -> b8 contiguous +``` + +This is the register-side counterpart of `UNPK`/`PK` memory distributions. Do +not use `vintlv`/`vdintlv` as the primary fallback for dense `lane_stride`; those +belong to two-stream interleave/deinterleave layouts. 
+ +Current checked-in VPTO coverage: + +```text +register pack: + vpack supports integer 32 -> u16 and integer 16 -> u8 + so b16 LS=2 -> contiguous and b8 LS=2/4 -> contiguous are directly covered + when the VMI source/result physical arity is the same + b32 LS=2 -> contiguous needs 64 -> 32 pack support or another materializer + +register unpack: + vsunpack/vzunpack support integer widening by 2x + so integer b8/b16 contiguous -> LS=2 and b8 contiguous -> LS=4 are covered + when the VMI source/result physical arity is the same + +floating-point lane_stride: + b8/b16 FloatType values use bit-preserving vbitcast to unsigned integer + carriers around the same pack/unpack sequence; non-FloatType low precision + types need a VPTO vbitcast contract before enabling this fallback + +arity-changing lane_stride materialization: + contiguous -> lane_stride can be expressed as multiple unpack parts, and + lane_stride -> contiguous needs an explicit multi-part merge/pack plan. + The current stage rejects those helpers instead of guessing a cross + physical-chunk materialization. +``` + +This fallback is a materialization cost, not a layout preference. Assignment may +insert the `ensure_layout`; later folding/rematerialization should remove it when +the producer or consumer has direct support: + +```text +load -> ensure_layout(lane_stride) + fold into a VMI load whose result has the requested lane_stride; vmi-to-vpto + later lowers that load to UNPK when the element width and stride match. + BRC remains the target dist for a separate scalar broadcast-load VMI semantic. 
+ +ensure_layout(lane_stride) -> store + fold into a VMI store that directly consumes the lane_stride value; vmi-to-vpto + later lowers that store to PK/PK4 when the element width and stride match + +ordinary producer -> ensure_layout(contiguous <-> lane_stride) + lower to register pack/unpack materialization when the element width is + supported +``` + +### 1.3 Pass Responsibilities + +Dense `lane_stride` should use the existing helper-driven layout pipeline. Do +not add a separate global candidate solver for the current stage. + +```text +pto-validate-vmi-ir: + verify surface syntax before assignment + reject malformed dense lane_stride attrs once the parser accepts them + keep lane_offset unavailable in the public attr + +vmi-layout-assignment: + assign explicit dense layouts, including lane_stride, on VMI value types + use op support queries to choose local cast relations: + widening compares natural deinterleaved result arity with compact + contiguous result arity; when compact wins, request source lane_stride=W + and set result contiguous + narrowing supports the inverse relation; when arity or a supported consumer + request chooses a strided result, request source contiguous and set result + lane_stride=W + keep unsupported or conflicting uses legal by inserting ensure_layout + serialize all decisions as type attrs or helper ops + do not clone producers, fold memory ops, or solve a global cost problem + +canonicalize/cse: + remove dead helpers and merge identical rematerialized values when normal MLIR + canonicalization can prove equivalence + no lane_stride-specific decision logic + +vmi-layout-rematerialize: + consume producer -> ensure_layout shapes + clone/rematerialize cheap producers directly in the requested lane_stride + layout when the producer support query says it can create that layout + examples: scalar broadcast, splat constants, iota, layout-transparent chains, + widening ext, and supported mask producers + do not rematerialize ordinary loads 
unless the load form has an explicit + safe-read proof and direct UNPK lowering support + +vmi-layout-fold: + consume helper-adjacent producer/consumer shapes + fold ensure_layout(lane_stride) feeding store into a VMI store that directly + consumes the lane_stride value when the support table has a direct compact + store lowering; this is still a VMI store, not a VPTO PK op + fold load -> ensure_layout when the load can directly produce the requested + lane map with UNPK and the rewrite preserves one load at the original + program point + fold identity lane-map conversions + leave unsupported conversions as explicit ensure_layout for validation or + vmi-to-vpto materialization + +vmi-layout-sink-materialization: + move ensure_layout across pure layout-transparent ops when all operands/results + can keep one identical dense lane map + reduce duplicated contiguous <-> lane_stride materializations + do not sink through cast, load, store, reduce, group_broadcast, or control flow + +pto-validate-vmi-layout-ir: + verify every dense value has a supported layout attr + verify ensure_layout has a supported materialization path: + identity + contiguous <-> lane_stride through register pack/unpack when supported + existing contiguous <-> deinterleaved relations + verify direct layout-aware load/store choices: + LS=2 b8/b16/b32 through UNPK/PK + LS=4 b8 through UNPK4/PK4 + BRC only after a scalar broadcast-load VMI semantic is modeled + reject unsupported direct cases such as LS=4 b16/b32 compact load/store + +vmi-to-vpto: + lower only from assigned type attrs, helper ops, and op attributes + emit direct vlds/vsts dist for UNPK/PK-supported memory cases + lower surviving contiguous <-> lane_stride ensure_layout through register + pack/unpack materialization when the VPTO verifier supports the carrier path + lower widening/narrowing casts according to the assigned source/result + lane_stride relation and concrete vcvt part + emit diagnostics instead of inventing hidden layout 
conversions +``` + +Implementation impact by pass/component: + +| Component or pass | Lane-stride implementation work | +|---|---| +| `VMILayoutAttr` ODS/C++ helpers | Yes. Add dense `laneStride` storage, parse/print, verifier, equality, lane-map helpers, and keep it separate from group-slot carrier packing. | +| VMI type physicalization helpers | Yes. Compute dense physical arity from `laneStride`; expose carrier-slot lowering helpers for pack/unpack paths without changing the VMI logical element type. | +| `VMILayoutSupport` / target capability helpers | Yes. Add support queries for dense `lane_stride` layouts, cast layout relations, direct UNPK/PK memory support, and contiguous `<->` lane-stride materialization support. BRC remains target capability for a separate scalar broadcast-load semantic. | +| `pto-validate-vmi-ir` | No lane-stride-specific pass algorithm. It relies on attr/op verifier updates; keep the existing surface-IR validation role. | +| `vmi-layout-assignment` | Yes. Assign dense lane-stride layouts when support queries choose them; insert `ensure_layout` for incompatible uses; serialize all decisions in types/helpers. | +| `canonicalize/cse` between VMI passes | No implementation. It remains ordinary cleanup for dead helpers and identical rematerialized producers. | +| `vmi-layout-rematerialize` | Yes. Teach producer rematerialization to create requested dense lane-stride layouts for cheap/safe producers. Do not add ordinary load remat without safe-read proof. | +| `vmi-layout-fold` | Yes. Fold `ensure_layout` into layout-aware VMI consumers, especially stores that can consume lane_stride and later lower to `PK/PK4`; fold `load -> ensure_layout` into a direct layout-aware load when it can preserve one load at the original program point; fold identity lane-map conversions. | +| `vmi-layout-sink-materialization` | Minimal generic update. 
It should compare dense layout keys including `laneStride` and reuse existing layout-transparent sinking; do not add cast/load/store/reduce-specific lane-stride patterns here. | +| `vmi-legalize-arith-select` | No implementation. Lane stride does not change scalar-condition select legalization. | +| `pto-validate-vmi-layout-ir` | Yes. Reject unsupported assigned layouts/helpers before lowering, including unsupported LS=4 b16/b32 compact load/store and unsupported register pack/unpack materializations. | +| `vmi-to-vpto` | Yes. Lower assigned dense lane-stride layouts, direct `UNPK/PK` memory cases, register pack/unpack `ensure_layout`, and lane-stride-aware ext/trunc lowering. | +| VPTO op verifier/emitter | Only if needed by the selected support matrix. Existing `vlds/vsts` dist tokens are already present; extending register fallback to b32 or floating-point carriers requires verifier/emitter support for the corresponding pack/unpack or bitcast form. | +| Lower VPTO/backend passes after `vmi-to-vpto` | No lane-stride-specific implementation. They see ordinary VPTO ops and existing dist tokens. | + +Any pass not listed above should not implement lane-stride-specific logic in the +current stage. New behavior must enter through the explicit layout attr, +support queries, helper ops, validation, or `vmi-to-vpto` lowering. + +Current-stage component checklist: + +This checklist records the components that participate in the current-stage +lane-stride implementation. It is not the remaining-work queue; remaining work +is limited to the masked-store and group-broadcast-load items above. 
+ +```text +include/PTO/IR/VMIAttrs.td +lib/PTO/IR/VMI.cpp + add laneStride storage for dense contiguous/deinterleaved layouts + keep group-slot laneStride parse/print compatibility + add getContiguous(ctx, laneStride) and getDeinterleaved(..., laneStride) + split helpers into isDenseLaneStrided(), isGroupSlotLaneStrided(), + getLaneStride(), and exact dense lane-map equality helpers + update attr verifier so laneStride > 0 and lane_offset is not accepted + +lib/PTO/IR/VMI.cpp +lib/PTO/Transforms/VMIToVPTO.cpp + replace the current "hasLaneStride implies unsigned carrier widening" helper + with: + logical-element physicalization for ordinary dense VPTO values + selected carrier-slot physicalization for pack/unpack materializations + existing group-slot packed-byte carrier lowering + +include/PTO/Transforms/VMILayoutSupport.h +lib/PTO/Transforms/VMILayoutSupport.cpp + extend VMIContiguousStoreSupportKind with dense lane-stride PK/PK4 cases + extend VMILayoutMaterializationSupportKind with: + ContiguousToLaneStrideViaUnpack + LaneStrideToContiguousViaPack + LaneStrideToLaneStrideViaContiguous, only if needed + update getPreferredCastLayoutFact: + keep an internal baseline natural relation for dense widening/narrowing + compute the compact lane-stride relation from the same conversion ratio + select compact only when source/result physical arities match and the + relevant arity is strictly smaller than the baseline relation + use the returned source/result layouts for both ext and trunc assignment + update getWidenSourceLayoutForResultLayout for dense lane_stride result/source + update getContiguousStoreSupport and canFoldContiguousStoreMaterialization for + LS=2 b8/b16/b32 -> PK_B16/B32/B64 + LS=4 b8 -> PK4_B32 + update canMaterializeDataLayout for contiguous <-> dense lane_stride through + register pack/unpack when the element/carrier path is supported + +lib/PTO/Transforms/VMILayoutAssignment.cpp + teach natural/preferred layout collection to accept dense 
lane_stride facts + from VMILayoutSupport + keep conflict handling unchanged: insert ensure_layout at mismatched uses + do not add producer cloning, memory folding, or global cost selection here + +lib/PTO/Transforms/VMILayoutRematerialize.cpp + allow cheap producers to be cloned with dense lane_stride result types when + VMILayoutSupport says the producer can directly create that lane map + keep ordinary load/group_load/masked_load cloning blocked until a safe-read + proof is added for the specific rematerialized memory operation + +lib/PTO/Transforms/VMILayoutFold.cpp + add producer-side fold for load -> ensure_layout: + replace the load result layout with the ensure target layout when the load + has no other incompatible uses and VMILayoutSupport has direct UNPK + support + erase the ensure_layout and keep a single load at the original program point + do not clone ordinary loads in this fold + add fold for ensure_layout(lane_stride -> contiguous) feeding pto.vmi.store or + pto.vmi.masked_store into a VMI store that consumes the lane_stride source + directly; this pass does not emit or model VPTO PK. 
VMIToVPTO later selects + the corresponding PK/PK4 store lowering from the assigned VMI store contract + masked_store direct fold additionally requires the mask to carry the same + dense lane_stride layout and a compactable element-width granularity: + LS=2 b8/b16 and LS=4 b8 are supported through LOWER punpack mask compaction + LS=2 b32 is left as explicit materialization until b32 lane-stride mask + compaction is specified and implemented + a contiguous user mask is not enough, even if the value layout can be + compact-stored; assignment/rematerialization must first derive the same + lane map for the mask + fold exact dense lane-map identity helpers + do not fold unsupported LS=4 b16/b32 cases + +lib/PTO/Transforms/VMILayoutSinkMaterialization.cpp + include laneStride in dense layout equality/support checks + reuse existing layout-transparent sinking logic + do not add lane-stride-specific sinking through casts or memory ops + +lib/PTO/Transforms/PTOValidateVMIIR.cpp + no new lane-stride algorithm + validation changes should come from attr/op verifier and VMILayoutSupport + diagnostics at the layout gate + +lib/PTO/Transforms/VMIToVPTO.cpp + update OneToN physical type conversion for dense laneStride and carrier slots + lower direct compact loads: + LS=2 b8/b16/b32 -> vlds UNPK_B8/B16/B32 + LS=4 b8 -> vlds UNPK4 + lower direct compact stores: + LS=2 b8/b16/b32 -> vsts PK_B16/B32/B64 + LS=4 b8 -> vsts PK4_B32 + lower direct compact masked_stores: + LS=2 b8/b16 -> LOWER punpack mask compaction + vsts PK_B16/B32 + LS=4 b8 -> two LOWER punpack steps + vsts PK4_B32 + LS=2 b32 -> no direct masked compact store until b32 lane-stride mask + compaction is specified and implemented + lower surviving ensure_layout contiguous <-> lane_stride through vpack and + vsunpack/vzunpack when the carrier path is legal + lower lane-stride-aware ext by selecting the concrete vcvt part from + the assigned source/result relation + lower lane-stride-aware trunc by selecting the concrete 
vcvt part from + the assigned source/result relation + +lib/PTO/IR/VPTO.cpp +lib/PTO/Transforms/VPTOLLVMEmitter.cpp +lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp + no change for existing vlds/vsts dist tokens + extend vpack/vsunpack/vzunpack verifier/emitter only if the first implemented + fallback needs currently unsupported b64->b32 or floating-point carrier paths + +test/lit/vmi + add parser/verifier tests for dense laneStride attrs + add assignment tests for ext lane-stride facts + add fold/remat/sink tests for helper-driven rewrites + add vmi-to-vpto checks for UNPK/PK and vpack/unpack fallback + add negative tests for unsupported LS=4 b16/b32 compact load/store +``` + +Load/`ensure_layout` fold algorithm: + +```text +input shape: + %x0 = pto.vmi.load ... : !vmi.vreg + %x1 = pto.vmi.ensure_layout %x0 + : !vmi.vreg + -> !vmi.vreg + +preconditions: + the load result has no other use that requires the old layout + the load semantics are a compact logical stream + VMILayoutSupport says the target layout has a direct load lowering: + compact stream: + LS=2 b8/b16/b32 -> UNPK_B8/B16/B32 + LS=4 b8 -> UNPK4 + masks/passthroughs, if present, already have compatible assigned layouts + +rewrite: + replace the original load op in place, or create the replacement load at the + same insertion point and erase the old load + the replacement load result type is the ensure target type + all ensure users use the replacement load result + erase the ensure_layout + +output shape: + %x = pto.vmi.load ... : !vmi.vreg + +lowering: + vmi-to-vpto emits the corresponding vlds UNPK dist +``` + +This fold changes the assigned result layout of the existing load; it does not +clone the load at the helper use-site. If the original load has both contiguous +and lane-stride consumers, the fold must leave the helper in place unless a +separate rematerialization step has a safe-read proof to clone the load. + +### 1.4 Scenario Ownership + +Each optimization scenario has exactly one owning pass. 
Other passes may verify +or lower the resulting explicit IR, but should not solve the same rewrite again. + +| Scenario | Example shape | Owning pass | Non-owners | +|---|---|---|---| +| Assign a layout request | `ext -> store` where store wants contiguous | `vmi-layout-assignment` inserts explicit layouts/helpers | Assignment does not clone, fold, or lower | +| Direct load produces requested lane map | `load(contiguous) -> ensure_layout(lane_stride=2)` | `vmi-layout-fold` rewrites the original load result layout when UNPK support exists | Remat must not clone this load without safe-read proof | +| Direct store consumes lane map | `ensure_layout(lane_stride -> contiguous) -> store` | `vmi-layout-fold` rewrites the VMI store to consume the lane_stride source directly when direct compact-store support exists | `vmi-to-vpto` emits the actual `vsts PK/PK4` | +| Cheap producer can produce target layout | `broadcast -> ensure_layout(lane_stride=2)` | `vmi-layout-rematerialize` rebuilds broadcast with lane-stride result | Fold does not rebuild arbitrary producers | +| Cast chooses arity-reducing relation | `64xf16 -> 64xf32` or a supported narrowing with smaller strided result | `vmi-layout-assignment` chooses the cast source/result layout relation | Remat only handles later use-site requests; fold only handles adjacent load/store helpers | +| Cast can move materialization to cheap source | `ext/trunc -> ensure_layout(requested layout)` with source broadcast/load-fold case | `vmi-layout-rematerialize` rebuilds the cast with the requested relation | Assignment may already choose the self-preferred relation; fold only handles the load/store subcase | +| Layout-transparent op has ensured operands | `ensure(a), ensure(b) -> add` | `vmi-layout-sink-materialization` sinks matching helpers to the result | Remat handles the opposite shape `add -> ensure` | +| Surviving supported helper | `ensure_layout(contiguous <-> lane_stride)` after optimizations | `vmi-to-vpto` lowers to 
register pack/unpack | Earlier passes are allowed to leave it explicit | +| Unsupported helper or layout | `lane_stride=4 b16 compact store` | `pto-validate-vmi-layout-ir` rejects before lowering | `vmi-to-vpto` should not invent a repair | +| Multi-consumer value with incompatible layouts | one load feeds contiguous user and lane-stride user | baseline keeps helper; optional remat only with safe-read proof | Fold must not silently duplicate memory effects | + +Examples: + +```text +load fold, owned by vmi-layout-fold: + before: + %x0 = pto.vmi.load ... : contiguous + %x1 = pto.vmi.ensure_layout %x0 : contiguous -> lane_stride=2 + after: + %x1 = pto.vmi.load ... : lane_stride=2 + vmi-to-vpto: + %x1 = pto.vlds ... {dist = "UNPK_B8/B16/B32 or UNPK4"} + +store fold, owned by vmi-layout-fold: + before: + %x_c = pto.vmi.ensure_layout %x_ls : lane_stride=2 -> contiguous + pto.vmi.store %x_c, %dst + after: + pto.vmi.store %x_ls, %dst // VMI store consumes lane_stride source + vmi-to-vpto: + pto.vsts %x_ls, %dst {dist = "PK_B16/B32/B64 or PK4_B32"} + +broadcast remat, owned by vmi-layout-rematerialize: + before: + %b0 = pto.vmi.broadcast %s : contiguous + %b1 = pto.vmi.ensure_layout %b0 : contiguous -> lane_stride=2 + after: + %b1 = pto.vmi.broadcast %s : lane_stride=2 + +elementwise sink, owned by vmi-layout-sink-materialization: + before: + %a1 = ensure_layout %a0 -> lane_stride=2 + %b1 = ensure_layout %b0 -> lane_stride=2 + %c1 = pto.vmi.addf %a1, %b1 + after: + %c0 = pto.vmi.addf %a0, %b0 + %c1 = ensure_layout %c0 -> lane_stride=2 +``` + +## 2. IR Attribute Changes + +### 2.1 Extend `VMILayoutAttr` + +Current storage reuses `blockElems` as group-slot `lane_stride`. 
Generalization +should first split lane stride from block elems: + +```c++ +kind +factor +blockElems +slots +laneStride +``` + +Meaning by kind: + +```text +contiguous: + factor = 1 + blockElems = 1 + slots = 0 + laneStride >= 1 + +deinterleaved: + factor = F + blockElems = B + slots = 0 + laneStride >= 1 + +num_groups: + factor = G + blockElems = 1 + slots = K + laneStride >= 1 +``` + +Do not add a public `laneOffset` field in the current stage. The targeted +optimization only needs phase-zero strided dense layouts. A future +phase field is justified only when there is a concrete VMI value whose logical +lane map is intentionally non-zero-phase, for example a zero-copy +deinterleave/extract view that keeps the odd lanes in place or a narrowing +conversion whose consumer explicitly requires an odd-lane result. + +Recommended helpers: + +```c++ +bool isDense() const; +bool hasDenseLaneStride() const; +bool hasGroupSlotLaneStride() const; +int64_t getLaneStride() const; +VMILayoutAttr withLaneStride(int64_t stride) const; +``` + +Keep old constructor defaults source-compatible where possible: + +```c++ +getContiguous(ctx) +getDeinterleaved(ctx, factor, blockElems = 1, + laneStride = 1) +getGroupSlots(ctx, numGroups, slots = 0, laneStride = 1) +``` + +### 2.2 Parser And Printer + +Accepted dense spellings: + +```text +#pto.vmi.layout<contiguous> +#pto.vmi.layout<contiguous, lane_stride = LS> + +#pto.vmi.layout<deinterleaved = F> +#pto.vmi.layout<deinterleaved = F, block_elems = B> +#pto.vmi.layout<deinterleaved = F, block_elems = B, lane_stride = LS> +``` + +Existing group-slot spellings remain valid: + +```text +#pto.vmi.layout<num_groups = G, slots = K> +#pto.vmi.layout<num_groups = G, slots = K, lane_stride = LS> +``` + +Printing omits defaults: + +```text +lane_stride = 1 is omitted +``` + +### 2.3 Verifier + +Verifier rules: + +```text +all layouts: + laneStride > 0 + +contiguous: + factor == 1 + blockElems == 1 + slots == 0 + +deinterleaved: + factor in supported dense factors + blockElems > 0 + slots == 0 + +num_groups: + factor > 0 + slots >= 0 + blockElems == 1 +``` + +The verifier should not require every strided layout to fit one VPTO register. 
+Fit depends on the VMI type shape and element type, so it belongs in type +physicalization and op support checks. + +## 3. Physicalization Helpers + +### 3.1 Separate Element Carrier From Lane Map + +Replace the current shared helper shape: + +```c++ +getVMIPhysicalElementType(type) +``` + +with two concepts: + +```c++ +getVMILogicalStorageElementType(type) +getVMIPhysicalCarrierElementType(type, loweringKind) +``` + +Dense lane-strided values keep the VMI logical element type. The lowering may +represent the same lane map either as logical-element lanes or as wider carrier +slots when the selected VPTO instruction is a pack/unpack family. + +Logical-element lane representation: + +```text +!vmi.vreg<64xf16, contiguous, lane_stride=2> + -> !pto.vreg<128xf16> physical register +``` + +Carrier-slot representation for pack/unpack lowering: + +```text +!vmi.vreg<Nxui16, contiguous, lane_stride=2> + -> low ui16 in each ui32 slot for vpack/PK_B32-style lowering + +!vmi.vreg<Nxui8, contiguous, lane_stride=4> + -> low ui8 in each ui32 slot for PK4_B32-style lowering +``` + +Group-slot packed stores also request a wider carrier in the specific lowering +path: + +```text +!vmi.vreg<Nxui8, num_groups=G, lane_stride=4> + group_store -> b32 carrier + PK4_B32 +``` + +Do not let dense `hasLaneStride()` imply unsigned-integer carrier widening +globally. Carrier widening is a property of a selected materialization or +load/store lowering, not of the VMI logical type itself. 
+ +### 3.2 Physical Arity + +Add a dense lane-map helper: + +```c++ +struct DenseLaneMap { + int64_t deinterleaveFactor; + int64_t blockElems; + int64_t laneStride; +}; + +int64_t getPhysicalLaneForDenseLogicalLane(DenseLaneMap map, + int64_t logicalLane); +``` + +For a VMI vreg type: + +```text +N = logical element count, F = deinterleaveFactor, B = blockElems, LS = laneStride +O = phase offset of the first occupied lane; always 0 in the current stage (no laneOffset field) +lanesPerVPTO = getVPTOPhysicalLanes(elementType) +lanesInDensePart = ceil(N / F) with block-aware distribution +requiredLanes = O + (lanesInDensePart - 1) * LS + 1 +registersPerDensePart = ceil(requiredLanes / lanesPerVPTO) +physicalArity = F * registersPerDensePart +``` + +For the current stage, require full block divisibility for dense +deinterleaved strided layouts, matching existing direct lowering restrictions: + +```text +N % (F * B) == 0 +``` + +Relaxing tail handling is outside the current stage and should be enabled only +with an explicit materialization/lowering proof. + +## 4. Layout Support Interface + +Extend support queries to include dense strided layouts: + +```text +supportsResultLayout(op, resultIndex, layout) +supportsOperandLayout(op, operandIndex, layout) +supportsLayoutRelation(op, operandLayouts, resultLayouts) +``` + +The important change is relation support. Some ops are not independently +described by "operand supports layout X" and "result supports layout Y"; they +support specific pairs. + +Examples: + +```text +elementwise: + all dense operands/results must use identical dense layout key + +extf/extui/extsi: + source/result layouts must satisfy a widening relation. Assignment chooses + between the natural deinterleaved relation and the compact-result + lane-stride-source relation by comparing physical arity, not by matching a + concrete lane count such as 64. + +truncf/trunci: + source/result layouts must satisfy a narrowing relation. Assignment uses the + inverse relation conservatively: keep the natural deinterleaved-source to + contiguous-result relation unless arity or a supported consumer request + selects a strided result relation. 
Masked-store consumers may only use the + strided result relation when the value and mask can be assigned/materialized + to the same lane map. + +broadcast/group_broadcast: + result may use a dense layout only when the materialization lowering has an + explicit support case for that lane map + +load: + default result contiguous + producer rematerialization may create selected strided layouts if a direct + load/mask sequence can produce that lane map + +store: + memory effect is contiguous unless the op is an explicit logical interleave + store; a strided source requires store lowering support or ensure_layout +``` + +Assignment should still insert `ensure_layout` for incompatible use-local +requests. Rematerialization/fold can later remove it. + +### 4.1 Cast Relation Helper Shape + +Keep `getPreferredCastLayoutFact` as the assignment entry point for dense +widening and narrowing casts, but make the helper return the actual preferred +source/result relation for the current shape. Internally it first builds the +natural relation: + +```text +widen: + source contiguous + result deinterleaved=W + +narrow: + source deinterleaved=W + result contiguous +``` + +Then it computes the compact relation: + +```text +widen: + source contiguous, lane_stride=W + result contiguous + +narrow: + source contiguous + result contiguous, lane_stride=W +``` + +The compact relation is selected only when its source/result physical arities +match and it strictly reduces the relevant baseline arity: + +```text +widen: + physical_arity(compact result) < physical_arity(natural result) + +narrow: + physical_arity(compact source) < physical_arity(natural source) +``` + +If the compact relation does not win, the helper returns the natural relation. +`vmi-layout-assignment` calls this helper for `extf/extui/extsi` and +`truncf/trunci`, requests the returned source layout, and records the returned +result layout. 
+ +The support query must validate the returned pair before assignment commits it: + +```text +supportsExtRelation(sourceTypeWithLayout, resultTypeWithLayout) +supportsTruncRelation(sourceTypeWithLayout, resultTypeWithLayout) +``` + +The validation step is a legality check, not a second optimizer. + +### 4.2 Current Framework Fit + +The existing assignment pass already has use-site requests. For example, +`pto.vmi.store` requests a contiguous source operand, and assignment can insert +`ensure_layout` when the stored value is assigned another layout. + +The dense-stride `ext` optimization should keep the same model: the cast op is +the layout-entry point and stores one preferred source/result relation. The old +preferred relation was: + +```text +extf: + request source contiguous + set result deinterleaved=W +``` + +The current stage keeps the existing single-preference framework and lets +`ext` choose one fact for the current op: + +```text +baseline fact: + source contiguous + result deinterleaved=W + +lane-stride fact: + source lane_stride=W + result contiguous +``` + +The `ext` support query chooses between these facts from op-local information: + +```text +conversion ratio W +target support for one selected hardware conversion part +physical arity of the natural result layout +physical arity of the compact contiguous result layout +requested result layout when a consumer materialization/remat path provides one +``` + +It does not inspect the defining source producer. If compact result arity is +strictly smaller than natural result arity and the target supports the +single-part relation, it selects the lane-stride fact. If it selects the +lane-stride fact and the source is not already in that layout, assignment +inserts an explicit source `ensure_layout`. Later passes either discharge that +helper by rematerializing/folding a concrete producer, lower it with a +registered pack/unpack materializer, or let validation reject the unsupported +relation. + +## 5. 
Widening Conversion Lowering + +Let: + +```text +W = result element storage bits / source element storage bits +``` + +For a dense source layout: + +```text +source lane_stride = LS +``` + +Single-part lowering is legal when: + +```text +LS % W == 0 +``` + +Then: + +```text +hardware part = 0 +result lane_stride = LS / W +``` + +The current stage only emits the zero-phase single-part conversion. +`vcvt ODD` remains necessary for full packed contiguous conversion, but that is +handled by the existing multi-part relation: + +```text +source contiguous, lane_stride=1 +result deinterleaved=W +``` + +Do not add a phase field merely to name that existing ODD instruction. Add a +phase field only when an assigned VMI layout needs to represent a concrete +zero-copy value/view already resident in odd/non-zero-phase lanes. + +The support query for the conversion should accept the pair only when the +requested result layout equals this computed result lane map, including +deinterleave/block fields. + +The support query should expose helpers for both legal facts, but assignment +chooses one immediately: + +```text +baseline fact: + source contiguous + result deinterleaved=W + natural result arity = physical_arity(result deinterleaved=W) + +lane-stride fact: + result contiguous + source same dense shape with lane_stride = W + compact result arity = physical_arity(result contiguous) +``` + +Assignment uses this deterministic rule: + +```text +if compact result arity < natural result arity + and the lane-stride fact is supported: + choose lane-stride fact +else: + choose baseline fact +``` + +For example, for `f16 -> f32`, the `extf` op chooses +`source lane_stride=2 -> result contiguous` for `64xf32`, because the compact +result has one physical chunk while the natural `deinterleaved=2` result has two +physical chunks. For `128xf32` and `256xf32`, both layouts have the same result +arity, so assignment chooses the natural `deinterleaved=2` result. 
The source +producer is handled by the explicit source `ensure_layout` and later +fold/rematerialization; it is not part of the cast support query. + +Current contiguous widening remains a separate legal relation: + +```text +source dense contiguous, lane_stride=1 +result deinterleaved=W, lane_stride=1 +``` + +Implementation steps: + +1. Factor conversion ratio calculation by storage bit width. +2. Add helper that computes the natural result layout and its physical arity. +3. Add helper that computes the compact result layout, required source + lane-stride layout, and compact result physical arity. +4. Teach `VMIToVPTO` conversion lowering to emit only the selected hardware + part when the relation is single-part. +5. Keep existing multi-part lowering for contiguous-to-deinterleaved cases. +6. Add diagnostics when an assigned conversion layout pair has no lowering. + +Hardware part names should be abstracted: + +```text +W=2: + part 0 -> EVEN + part 1 -> ODD + +W=4: + part 0..3 -> target-specific conversion part names or the existing sequence +``` + +Do not special-case f16/f32 in the matcher. The type only determines `W` and +the concrete VPTO conversion opcode. + +## 6. Narrowing Conversion Lowering + +Let: + +```text +W = source element storage bits / result element storage bits +``` + +For a single-part narrowing relation: + +```text +result lane_stride = source lane_stride * W +hardwarePart = 0 for the current stage +``` + +The narrowing assignment relation is the inverse of widening, but it must not +blindly choose a lane-stride result. 
Build two facts: + +```text +baseline fact: + source deinterleaved=W + result contiguous + natural source arity = physical_arity(source deinterleaved=W) + +lane-stride fact: + source contiguous + result contiguous, lane_stride=W + compact source arity = physical_arity(source contiguous) +``` + +Then choose a strided result only when it is justified: + +```text +if compact source arity < natural source arity + and the lane-stride fact is supported: + choose lane-stride fact +else if a consumer/requested result layout is the strided result + and the lane-stride fact is supported: + choose or rematerialize lane-stride fact +else: + choose baseline fact +``` + +This keeps trunc symmetric with ext while avoiding the earlier mistake of +producing lane_stride solely because the operation is a narrowing cast. A +consumer may still request or preserve a strided result. For example, an +ordinary store with direct `PK` support can consume a supported lane-stride +result, and rematerialization/fold may keep that relation. A masked store may +do so only when the mask can be assigned/materialized to the same lane map. + +Implementation steps: + +1. Share ratio, dense-factor, lane-map, and physical-arity helpers with + widening. +2. Add helper that computes the natural source/result relation and source + arity. +3. Add helper that computes the strided-result relation and source arity. +4. Add support query for valid narrowing layout pairs. +5. Teach assignment/rematerialization to select the strided fact for explicit + result requests, direct compact-store consumers, or true arity reductions. +6. Lower single-part narrowing directly when the target has a part-selecting + narrow instruction. +7. Preserve existing deinterleaved-to-contiguous narrowing for the packed full + result case. 
+ +This is the same family as the recently discussed `d4 -> c -> d2 -> vcvt -> c` +optimization: if a cast op has a direct source/result layout relation, +assignment/rematerialization should expose that relation before lowering. + +## 7. Ensure-Layout And Rematerialization + +### 7.1 `ensure_layout` + +`ensure_layout` remains the explicit use-site materialization op. + +Verifier/lowering policy: + +```text +same source and target dense lane map: + fold away + +known dense relation: + lower contiguous <-> lane_stride through register pack/unpack when supported + lower contiguous/deinterleaved relations through existing intlv/dintlv paths + +producer can rematerialize target layout: + rematerialization should replace ensure_layout(producer) + +unknown relation: + reject before vmi-to-vpto +``` + +Avoid adding a generic "any dense layout to any dense layout" promise unless the +target really has a lowering for it. + +### 7.2 Rematerialization + +The current checked-in `vmi-layout-rematerialize` cheap producers are: + +```text +data: + VMIExtFOp / VMIExtSIOp / VMIExtUIOp when the source layout can be + materialized for the requested result relation + VMIFmaOp + binary layout-transparent ops: + addf/addi/subf/subi/mulf/muli/divf/minf/maxf/andi/ori/xori/shli/shrui + unary layout-transparent ops: + negf/absf/absi/sqrt/exp/ln/relu/not + VMIConstantOp only when the DenseElementsAttr is a splat + VMIBroadcastOp + VMIIotaOp + +mask: + VMICreateMaskOp + VMICreateGroupMaskOp + VMIConstantMaskOp + +special rewrite: + selected VMITruncFOp / VMITruncIOp through source/result ensure_layout when + the cast relation is a supported narrowing relation +``` + +Not included as cheap producers in the current pass: + +```text +load / masked_load / group_load / group_slot_load / group_broadcast / +group_broadcast_load / store / reduce / control-flow ops +``` + +Loads need a separate policy. 
`load -> ensure_layout` should be folded in +`vmi-layout-fold` when one original load can directly produce the requested +layout. A normal load should not be cloned/rematerialized unless a later safe-read +proof explicitly permits that clone. + +Relationship between cheap producers and dense `lane_stride`: + +```text +assignment: + creates the target layout request explicitly, usually as ensure_layout(... -> + lane_stride) or as a cast source/result relation. For casts, assignment may + itself choose the arity-reducing lane-stride relation; remat only reacts to + later use-site layout requests. + +rematerialize: + does not choose lane_stride as a preference + only consumes the explicit helper/request and rebuilds the producer with the + requested lane_stride result type when the producer is cheap and locally legal +``` + +Required rematerialize changes for dense `lane_stride`: + +```text +materializeDataLayout: + no special producer logic, but canMaterializeDataLayout must understand + contiguous <-> lane_stride through register pack/unpack + +splat constant / broadcast / iota: + rebuild the op with the requested lane_stride result type + lowering later materializes that layout directly or through ensure_layout + +layout-transparent unary/binary/fma: + rebuild the op with the requested lane_stride result type + materialize each operand to the same lane_stride layout before rebuilding + this relies on canMaterializeDataLayout for operand conversions + +widening ext: + update getWidenSourceLayoutForResultLayout so a requested result layout derives + the required source lane_stride: + result contiguous, W=2 -> source lane_stride=2 + result lane_stride=R, W=2 -> source lane_stride=2*R + remat then inserts/uses source ensure_layout and rebuilds ext with the + requested result layout + +narrowing trunc: + add getNarrowSourceLayoutForResultLayout or an equivalent relation helper. 
+ For a requested result lane_stride=R and narrowing ratio W, derive the source + layout that can produce that result with a selected hardware part: + result lane_stride=W, W=2 -> source contiguous + result lane_stride=R, W=2 -> source lane_stride=R/W when divisible + remat then inserts/uses source ensure_layout and rebuilds trunc with the + requested result layout + +trunc source-ensure rewrite: + extend the existing source-ensure rewrite to recognize lane_stride narrowing + relations for VMITruncFOp and VMITruncIOp, not only deinterleaved narrowing + relations + +mask producers: + only participate after mask layout support defines the corresponding + lane-stride or predicate-granularity relation; otherwise unchanged +``` + +Example: + +```text +before remat: + %b0 = pto.vmi.broadcast %s : !vmi.vreg<64xf16, contiguous> + %b1 = pto.vmi.ensure_layout %b0 + : contiguous -> contiguous, lane_stride=2 + %y = pto.vmi.extf %b1 : f16 -> f32 + +after remat: + %b1 = pto.vmi.broadcast %s + : !vmi.vreg<64xf16, contiguous, lane_stride=2> + %y = pto.vmi.extf %b1 : f16 -> f32 +``` + +This removes a register layout materialization and lets `vmi-to-vpto` lower the +ext as the single selected conversion part. It is still driven by the explicit +layout request; remat does not inspect sibling consumers or choose lane_stride by +itself. 
+ +Do lane-stride cast rematerialization only in these cases: + +```text +required shape: + cast result is followed by ensure_layout to a requested dense result layout + widening or narrowing ratio W > 1 + the requested source/result layout pair is accepted by the cast relation + helper + the cast with that source/result layout can lower as one selected conversion + part or the existing multi-part relation + +acceptance/safety gate: + the source-side lane_stride request must be discharged by a concrete local + rewrite, not merely moved from result side to source side + accepted cases: + the source already has the required lane_stride + the source producer is in the checked-in cheap producer list and can be + rebuilt with the required lane_stride + the source is load -> ensure_layout and vmi-layout-fold can replace it with + a single original-position layout-aware VMI load + a layout-transparent chain can be sunk/rematerialized until one of the above + concrete producer cases is reached + +do not apply: + result consumer already accepts the natural cast layout + requested cast layout relation is unsupported + source is an ordinary load with other incompatible consumers and no safe-read + proof to clone it + the rewrite only moves an expensive materialization from result side to source + side without exposing a direct lowering +``` + +Typical accepted shapes: + +```text +broadcast -> ext -> ensure_layout(contiguous) -> store + remat broadcast as lane_stride=W + ext lowers with one conversion part + no source-side ensure_layout remains + +load -> ensure_layout(lane_stride=W) -> ext -> store + fold load into a layout-aware VMI load + vmi-to-vpto later emits the matching UNPK dist + ext lowers with one conversion part + no extra load is cloned + +elementwise cheap chain -> ext -> ensure_layout(contiguous) + remat/sink the chain to lane_stride=W only when the chain reaches a concrete + cheap producer or direct load-fold case + +trunc -> ensure_layout(lane_stride=W) -> 
compact store + remat/rebuild trunc with the requested lane_stride result when the source + layout relation is supported + store fold may then consume the lane_stride result directly + +trunc -> ensure_layout(lane_stride=W) -> masked_store + only accepted after mask layout assignment can provide the same lane map for + the predicate; otherwise keep the conservative contiguous masked-store path +``` + +## 8. Broadcast And E2B Interaction + +Do not encode E2B in `lane_stride`, and do not define +`vmi.group_broadcast_load` in terms of E2B. The VMI operation is a logical +fused memory operation: + +```text +for each logical group g: + scalar = source[offset + g * source_group_stride] + for each lane i in group g: + result[i] = scalar +``` + +The result layout is assigned separately. It may be contiguous, +deinterleaved, or dense lane-strided if the consumer asks for that lane map and +the target support table accepts it. E2B is only one VPTO lowering strategy for +a restricted subset of this logical operation. 
+ +The layering should be: + +```text +logical group broadcast load + -> assigned dense layout, possibly lane-strided + -> support query chooses a lowering strategy + -> selected VPTO dist, if any +``` + +For the current E2B strategy, the support query checks: + +```text +source is direct memory +source_group_stride is constant 1 +num_groups is a multiple of 8 +element storage width +logical group size derived from num_groups +assigned result layout: + contiguous for the direct packet size + or deinterleaved=2, block_elems=1 for the split packet size +``` + +Then it may choose an E2B packet: + +```text +b16 contiguous: direct 1 -> 16 packet +b16 deinterleaved=2: two logical halves / 1 -> 32 reuse +b16 dense lane_stride=2: direct phase-zero strided consumer packet +b32 contiguous or strided: target-specific packet size +``` + +If those conditions do not hold, the operation is still a valid VMI semantic if +some other lowering exists, such as `group_slot_load + group_broadcast`, scalar +loads plus broadcast, or future target-specific broadcast-load support. The +failure is only "this E2B lowering strategy is not applicable", not "the VMI +operation means E2B". + +Concrete implementation plan for lane-stride `group_broadcast_load`: + +```text +include/PTO/Transforms/VMILayoutSupport.h +lib/PTO/Transforms/VMILayoutSupport.cpp + +1. 
Split semantic support from E2B strategy checks: + + getGroupBroadcastLoadSupport(capabilities, op) + try getE2BGroupBroadcastLoadSupport(capabilities, op) + if success: + return {kind = E2BVlds} + return failure("no registered group_broadcast_load lowering strategy; " + "E2B rejected because ...") + + getE2BGroupBroadcastLoadSupport(capabilities, op) + contains the current E2B constraints: + source is !pto.ptr direct memory + element width is b16 or b32 + source_group_stride is constant 1 + num_groups is a multiple of 8 + group size matches direct or split E2B packet size + result layout is contiguous or deinterleaved=2/block_elems=1 + result has full physical chunks + +2. Keep VMIGroupBroadcastLoadSupportKind strategy-specific: + E2BVlds means "lower this VMI semantic using E2B" + It must not be used as the definition of the VMI op. +``` + +```text +lib/PTO/Transforms/VMILayoutAssignment.cpp + +3. Rename strategy helpers so the direction is clear: + + isE2BGroupBroadcastLoadCandidate + -> isE2BGroupBroadcastLoadStrategyApplicable + + getPreferredGroupBroadcastLoadLayout + -> getPreferredE2BGroupBroadcastLoadLayout + +4. Fusion from group_slot_load + group_broadcast to group_broadcast_load remains + guarded by E2B applicability. If E2B is not applicable, do not create a + fused group_broadcast_load merely because the VMI semantic would be valid. + That avoids producing an op with no registered lowering strategy. + +5. Layout assignment for an explicit group_broadcast_load uses the support + query: + if E2B strategy applies: + assign the E2B-preferred result layout + else: + leave the op to validation unless a fallback strategy is added +``` + +```text +lib/PTO/Transforms/VMIToVPTO.cpp + +6. 
Replace duplicated local E2B legality checks with: + support = getGroupBroadcastLoadSupport(capabilities, op) + switch support.kind: + E2BVlds: + emit the existing E2B packet sequence + + The E2B lowering code may still assert/recheck structural invariants needed + for indexing, but user-facing diagnostics should come from the support query. + +7. Diagnostics must name the strategy: + good: "group_broadcast_load has no registered lowering strategy; E2B + rejected because source_group_stride is not constant 1" + bad: "group_broadcast_load requires constant unit source_group_stride" + + The second form is only valid inside an E2B-specific diagnostic. +``` + +Required group-broadcast-load tests: + +```text +E2B positive: + explicit group_broadcast_load with b16/b32, stride=1, matching group size, + and assigned contiguous/deinterleaved result layout + CHECK vmi-to-vpto emits E2B_B16/E2B_B32 + +E2B strategy rejection: + source_group_stride != 1, wrong group size, or unsupported element width + CHECK validation/lowering diagnostic says no registered lowering strategy and + reports E2B as the rejected strategy + +fusion guard: + group_slot_load + group_broadcast shape that is not E2B-applicable + CHECK assignment does not fuse it into group_broadcast_load + +semantic boundary: + explicit group_broadcast_load that is not E2B-applicable + CHECK failure wording does not redefine the op as E2B and does not imply the + logical VMI semantic itself is E2B +``` + +This keeps broadcast optimization generic across type width and layout, instead +of hardcoding one `ComputeY1ToFP8` scale pattern. + +## 9. Tests + +Use the following as the coverage matrix for current-stage support plus the +masked-store and group-broadcast-load follow-up items. It is not a separate +list of all remaining implementation work. 
+ +Parser/verifier: + +```text +parse/print contiguous lane_stride +parse/print deinterleaved + block_elems + lane_stride +``` + +Physicalization: + +```text +64xf16 contiguous lane_stride=2 has one physical 128xf16 part +ui16 contiguous lane_stride=2 may lower through low ui16 in ui32 carrier slots +when the selected materialization is vpack/PK_B32 +ui8 contiguous lane_stride=4 may lower through low ui8 in ui32 carrier slots +when the selected materialization is PK4_B32 +65xf16 contiguous lane_stride=2 is rejected by direct full-chunk-only paths, or +covered only by an arity-changing materialization test outside this discussion +group-slot ui8 lane_stride=4 keeps existing carrier lowering behavior +``` + +Conversion lowering: + +```text +f16 lane_stride=2 -> f32 contiguous emits one EVEN conversion +bf16 lane_stride=2 -> f32 contiguous follows the same relation +ui8 lane_stride=2 -> ui16 contiguous follows W=2 +ui8 lane_stride=4 -> ui32 contiguous follows W=4 when target supports it +contiguous f16 -> deinterleaved=2 f32 still emits EVEN + ODD +f32 contiguous -> f16 lane_stride=2 emits the selected narrowing part when the +assigned relation is supported +f32 deinterleaved=2 -> f16 contiguous keeps the existing packed full-result +narrowing relation +ui16 lane_stride=2 -> contiguous can materialize with vpack 32->16 carrier path +ui8 lane_stride=4 -> contiguous can materialize with two vpack stages +``` + +Assignment/rematerialization: + +```text +extf records a strided dense source relation when compact result arity is +smaller than natural result arity +extf 64xf16 -> 64xf32 chooses source lane_stride=2, result contiguous +extf 128xf16 -> 128xf32 chooses result deinterleaved=2 +extf 256xf16 -> 256xf32 chooses result deinterleaved=2 +truncf records a strided result relation only when the conservative +self-preference/support rule or a supported consumer request selects it; it +does not choose lane_stride solely because the op narrows +layout-transparent op 
propagates the same strided layout through operands/result +ensure_layout is folded when source and target lane maps match +rematerialization clones a cheap broadcast for two different dense layouts +``` + +End-to-end assignment cases: + +```text +contiguous load -> ext -> contiguous store: + uses lane_stride only when the source ensure_layout can be folded to the + original load, rematerialized from a cheap producer, or lowered by a supported + register materializer + +cheap broadcast -> ext -> contiguous store: + rematerializes broadcast as lane_stride=2 and lowers ext with one EVEN part + +producer -> ext -> deinterleaved reduce: + keeps source contiguous and result deinterleaved=2 + +cheap producer -> ext feeding both store and reduce: + keeps shared deinterleaved path for reduce and rematerializes a contiguous + result path for store only through the checked cheap-producer remat path + +group_broadcast_load -> ext -> contiguous consumer: + chooses lane_stride only if group_broadcast_load supports that lane map +``` + +Negative tests: + +```text +assigned ext layout pair where LS % W != 0 and no multi-part relation exists +assigned trunc layout pair where result lane_stride is not compatible with the +narrowing ratio +ordinary dense op with mismatched lane_stride operands +store consuming strided dense layout without a supported store/materialization +masked_store consuming lane_stride value with a stale contiguous user mask is +rejected or kept on the conservative contiguous path +``` + +## 10. Suggested Patch Order + +1. Add attr fields, parser/printer, verifier, and round-trip tests. +2. Split dense lane-map physicalization from group-slot carrier packing. +3. Update physical arity/unpack helpers for dense lane stride. +4. Extend support queries and assignment layout keys. +5. Implement widening arity-driven self-preference, single-part relation, and + tests. +6. Implement narrowing inverse relation support, consumer-request handling, and + tests. +7. 
Teach rematerialization/fold about exact dense lane-map equality. +8. Add broadcast/E2B recognition improvements that consume assigned lane maps. + +Each step should keep existing group-slot `lane_stride` tests passing. The first +functional optimization can be the `f16/bf16 lane_stride=2 -> f32 contiguous` +single-part conversion, but the IR and helper changes should already be generic +over type width and lane-map fields. diff --git a/docs/designs/vmi-layout-assignment-implementation.md b/docs/designs/vmi-layout-assignment-implementation.md new file mode 100644 index 0000000000..865adf413a --- /dev/null +++ b/docs/designs/vmi-layout-assignment-implementation.md @@ -0,0 +1,2320 @@ +# VMI Layout Assignment Implementation Plan + +本文是 `vmi-layout-assignment` 和 `vmi-to-vpto` 的实现计划。它配套 +`vmi-layout-assignment-lowering-design.md`,并以 +`vmi-layout-lowering-cases.md` 为测试和验收来源。 + +不使用早期 VMI 草稿作为设计输入。 + +## 1. Pipeline + +Recommended pass pipeline: + +```text +pto-validate-vmi-ir + -> vmi-layout-assignment // hard legalization baseline + -> canonicalize/cse + -> vmi-layout-rematerialize // optional optimization + -> canonicalize/cse + -> vmi-layout-fold // optional optimization over remat-exposed helpers + -> canonicalize/cse + -> vmi-layout-sink-materialization // optional optimization + -> canonicalize/cse + -> vmi-legalize-arith-select + -> pto-validate-vmi-layout-ir + -> vmi-to-vpto + -> canonicalize/cse + -> existing VPTO lowering/codegen +``` + +Only `vmi-layout-assignment` is required for the first legal implementation. +The optimization passes may be introduced one by one. Their contract is that +they consume legal layout-assigned VMI IR and produce legal layout-assigned VMI +IR; they never move a hidden decision into `vmi-to-vpto`. 
+ +Pass responsibilities: + +```text +pto-validate-vmi-ir: + verify surface VMI has no physical VPTO layout dependency + reject public/external VMI ABI unless explicitly enabled + +vmi-layout-assignment: + solve hard value layout constraints + choose explicit layouts visible in IR + insert ensure_layout / ensure_mask_layout / ensure_mask_granularity helpers + make internal function boundary layouts explicit + rewrite VMI types with layout attrs + +canonicalize/cse: + remove dead helpers and merge identical cloned producers where MLIR legality + permits + +vmi-layout-fold: + fold use-site materialization into consumers that can directly consume the + source layout while preserving the same logical effect + example: ensure_layout(deinterleaved=2 -> contiguous) feeding store may become + a store of deinterleaved=2 when the store has a layout-aware vstsx2 INTLV + lowering + current implementation: pto.vmi.store and the value operand of + pto.vmi.masked_store when the existing mask arity matches, fed by + ensure_layout from deinterleaved=2/4, block_elems=1 to contiguous. factor=2 + uses the store's vstsx2 INTLV lowering; factor=4 is still store-local, but it + materializes through physical interleave before vsts. + +vmi-layout-rematerialize: + replace explicit ensure_* helpers with cloned cheap layout-polymorphic + producers when the clone directly creates the requested result type + current implementation: splat pto.vmi.constant, pto.vmi.broadcast, + pto.vmi.iota, selected layout-transparent data ops, widening + pto.vmi.ext{f,si,ui}, pto.vmi.create_mask, pto.vmi.create_group_mask, and + pto.vmi.constant_mask. Relation-aware remat rewrites result-side + ensure_layout through layout-transparent producers and widening ext + producers, leaving any newly exposed producer-side helpers for the following + vmi-layout-fold. 
+ not included in the first implementation: load, group_load, masked_load, + group_slot_load, and group_broadcast; those require separate memory, + execution-count, or source-layout proof before they can be rematerialized + +vmi-layout-sink-materialization: + move ensure_layout across pure layout-transparent elementwise chains when the + rewritten IR reduces materialization overhead and keeps every op locally legal + current implementation: sink two identical operand ensure_layout helpers + across binary add/sub/mul/div/min/max/and/or/xor/shl/shru VMI ops, three + identical operand ensure_layout helpers across fma, or one source + ensure_layout across unary neg/abs/sqrt/exp/ln/relu/not VMI ops, producing + one result ensure_layout. It also sinks compare data helpers to one result + ensure_mask_layout, and sinks select only when both selected values and the + mask carry matching explicit helpers. Matching ensure_mask_layout or + ensure_mask_granularity helpers are sunk across mask_and/mask_or/mask_xor/ + mask_not, producing one result mask helper. It does not sink through cast, + load, store, reduce, group_broadcast, or control-flow ops. + +vmi-legalize-arith-select: + restore scalar-condition arith.select with VMI result type back to scf.if + after canonicalize; canonicalize may fold simple scf.if into arith.select, + but VMI values must not cross non-VMI semantic ops before vmi-to-vpto + +pto-validate-vmi-layout-ir: + verify every VMI data/mask value has layout + verify every VMI value has an assigned layout and every non-local lowering + choice has been serialized explicitly + verify helper ops have supported materialization paths. Current + implementation checks `ensure_layout`, `ensure_mask_layout`, and + `ensure_mask_granularity` at the layout gate, so unsupported helper + materializations fail before `vmi-to-vpto`. 
It also checks the first + semantic local lowering families, non-contiguous + `pto.vmi.store`, block8 + `pto.vmi.group_load`, `pto.vmi.group_slot_load`, group_slots + `pto.vmi.group_store`, group_slots `pto.vmi.group_reduce_add{f|i}`, + explicit-slots `pto.vmi.group_broadcast`, `pto.vmi.truncf`, + `pto.vmi.extf`, `pto.vmi.bitcast`, and histogram family ops at the layout gate. + +vmi-to-vpto: + use OneToN type conversion + lower only from current-op attrs/operands, operand/result layouts, and helper + ops + emit VPTO or precise unsupported diagnostic +``` + +### 1.1 Hard Constraints Versus Optimizations + +Hard legalization answers "can this program be lowered correctly?" It is +allowed to be conservative: + +```text +%w = pto.vmi.extf %a // natural layout deinterleaved=2 +%t1 = pto.vmi.mulf %w, %k1 // layout-transparent, stays deinterleaved=2 +%t1_c = pto.vmi.ensure_layout %t1 // hard store contract wants contiguous +pto.vmi.store %t1_c, %OUT1 +%w_c = pto.vmi.ensure_layout %w +pto.vmi.store %w_c, %OUT2 +``` + +This is a correct legal shape. The contiguous action is explicit at each store +use, and `vmi-to-vpto` lowers the helper with register materialization such as +`vintlv` before ordinary `vsts`. + +Optimization answers "can the same external effect be cheaper?" A fold pass +may rewrite the two store uses to consume the deinterleaved values directly: + +```text +pto.vmi.store %t1, %OUT1 // value type still says deinterleaved=2 +pto.vmi.store %w, %OUT2 +``` + +This optimized shape is legal only because `pto.vmi.store` has enough local +information to lower a `deinterleaved=2` f32 value to row-major memory, for +example with `vstsx2 INTLV_B32`. The optimization does not require +`vmi-to-vpto` to inspect `%w`'s producer or the sibling store. 
+ +The split gives later passes room to improve layout choices: + +```text +hard pass: + guarantee legality with explicit ensure_* helpers + +optimization passes: + remove, fold, clone, or sink helpers when the optimized IR is still locally + deterministic + +vmi-to-vpto: + physicalize exactly the IR it sees, with no global planning +``` + +## 2. Files To Add Or Update + +Expected implementation files: + +```text +include/PTO/IR/VMITypes.td +include/PTO/IR/VMIOps.td +include/PTO/IR/VMIAttrs.td +lib/PTO/IR/VMI.cpp + +include/PTO/Transforms/Passes.td +lib/PTO/Transforms/PTOValidateVMIIR.cpp +lib/PTO/Transforms/VMILayoutAssignment.cpp +lib/PTO/Transforms/VMIToVPTO.cpp +small layout fact/materialization helpers under lib/PTO/Transforms + +test/lit/vmi/vmi_layout_assignment_*.pto +test/lit/vmi/vmi_to_vpto_*.pto +test/vpto/cases/vmi/*/ +``` + +Exact names may follow project conventions, but the layering should remain: + +```text +IR definitions + -> validation + -> assignment + -> OneToN lowering + -> lit and sim tests +``` + +## 3. IR Types And Attributes + +### 3.1 Layout Attribute + +Represent layout as a closed attribute family: + +```text +#pto.vmi.layout +#pto.vmi.layout +#pto.vmi.layout +#pto.vmi.layout +``` + +C++ form: + +```c++ +enum class VMILayoutKind { + Contiguous, + Deinterleaved, + GroupSlots, +}; + +struct VMILayoutKey { + VMILayoutKind kind; + int64_t deinterleaveFactor = 1; + int64_t blockElems = 1; + int64_t numGroups = 0; + int64_t slots = 0; + int64_t laneStride = 1; +}; +``` + +Verifier rules: + +```text +contiguous: + no extra parameters + +deinterleaved: + F > 1 + B > 0 + direct full-chunk lowerings require N % (F * B) == 0 + +group_slots: + G > 0 + K > 0 + G % K == 0 + K fits in one physical vreg for element type + LS > 0 +``` + +Parser compatibility during migration: + +```text +#pto.vmi.layout +``` + +is the lowering contract for group-slot values. 
The parser still accepts +`#pto.vmi.layout` as a legacy spelling for the pre-design +implicit group layout, but `vmi-to-vpto` support queries require explicit slots. +New `vmi-layout-assignment` output must print one of: + +```text +#pto.vmi.layout +#pto.vmi.layout +#pto.vmi.layout +``` + +so `vmi-to-vpto` can lower from the assigned type without reconstructing group +slot placement from producer or consumer context. + +`lane_stride` is counted in logical element-sized physical slots and records a +regular gap between stored group slots. It is used for carrier-style packed +stores such as `ui8` group slots lowered through b32 `PK4_B32`. + +The current implementation treats this as a group-slot property. The dense +generalization is tracked separately in +`vmi-lane-stride-generalization-implementation.md`; it requires splitting dense +lane-map stride from group-slot carrier packing before `lane_stride` can be used +on `contiguous` or `deinterleaved` layouts. + +### 3.2 VMI Types + +Surface: + +```text +!pto.vmi.vreg +!pto.vmi.mask +``` + +Layout-assigned: + +```text +!pto.vmi.vreg> +!pto.vmi.mask> +``` + +Surface VMI types are legal before assignment. Layout-assigned VMI types are +required after assignment. + +### 3.3 Explicit Lowering Carriers + +Lowering decisions are carried by the current op and its types, not by a +separate lowering-plan string. The allowed carriers are: + +```text +op attrs and operands +operand/result VMI layouts +mask granularity and mask layouts +helper ops such as ensure_layout / ensure_mask_layout +cloned or rematerialized producers +diagnostics for unsupported shapes +``` + +If assignment made a non-local choice by inspecting producers, users, sibling +users, control flow, callees, or memory context, it must rewrite the IR so that +the final choice is visible through those carriers before `vmi-to-vpto`. 
+ +Local-decision table for the current implementation: + +```text +op local decision inputs +group_load result layout, num_groups, row_stride, source type +group_slot_load result group_slots layout and source_group_stride +group_reduce_add{f|i} source/mask/result layouts, num_groups, typed reduce semantics +group_broadcast source/result layouts and num_groups +truncf source/result layouts and element widths +dhist/chist acc/source/mask/result layouts and target capability +ensure_layout always carries source/result layouts +ensure_mask_layout always carries source/result layouts +ensure_mask_granularity always carries source/result granularities +``` + +Layout/attr-only decisions today: + +```text +load result layout plus full chunk or shaped memref proof +group_store source group_slots layout plus explicit output stride +masked_load explicit passthrough, mask layout, and memory proof +masked_store/select operand/result layouts plus mask granularity +dense extf/truncf source/result layouts and element widths +``` + +Implementation rule: + +```text +validate-assigned-vmi validates assigned layouts, mask granularity, boundaries, +and helper placement. +vmi-to-vpto emits VMI-LAYOUT-CONTRACT for missing local proof. +If a layout/attr-only op later gains a second legal lowering that cannot be +distinguished from current-op information, that lowering must be represented by a +new attr, helper op, or rematerialized op before vmi-to-vpto can emit it. +Unsupported shapes that have no explicit materialization/lowering path still +diagnose through their specific capability check rather than failing with a generic +missing-lowering +error. +``` + +Examples of forbidden recovery in `vmi-to-vpto`: + +```text +group_reduce_add{f|i} cannot walk to a load/group_load producer to choose + two-vlane parity versus block8. +group_store cannot inspect the group_reduce producer; it consumes only the + assigned source layout and explicit stride. 
+group_broadcast cannot inspect sibling users to decide whether to rematerialize. +masked_load cannot inspect the mask producer to prove memory safety. +func.call cannot inspect the callee body to decide physical function layout. +``` + +## 4. VMI Surface Ops Required By Cases + +Initial op set from the case catalog: + +```text +load +group_load +group_slot_load +store +masked_store + +create_mask +create_group_mask + +extf +truncf +extsi +extui +trunci +addf +addi +mulf +select +broadcast + +group_reduce_addf +group_reduce_addi +group_broadcast +group_store +dhist +chist + +ensure_layout // internal +ensure_mask_layout // internal +ensure_mask_granularity // internal +``` + +Type policy before lowering: + +```text +storage / memory boundary: + f8-like, i8, f16, i16, f32, i32 may appear as load/store element types when + the target memory instruction supports the physical width. + +cast boundary: + f8-like may appear as extf/truncf source or destination. + i8 may appear as extsi/extui/trunci source or destination. Signedness is an + op semantic, not a VMI type spelling. + Current VPTO lowering supports 32-bit integer narrowing to unsigned i8 + storage, matching the available VCVTII s32/u32 -> u8 forms; signed i8 + narrowing needs a separate target lowering. + +compute / accumulator: + floating compute baseline: f16/f32, with reassoc required for reductions + that lower through pair-wise VPTO reductions. + integer compute baseline: i32 for grouped reduction; i8/i16 storage must + first cast to i32 because integer reduction instructions widen narrow inputs. + f8/i8 are not baseline accumulator/compute types. Supporting direct 8-bit + compute requires a target capability entry and a separate lowering family. +``` + +Important semantic split: + +```text +load: + pointer sources must load full physical chunks directly. Partial logical + loads require a shaped memref proof or a future guarded/scratch fallback. 
+ +group_load: + loads group_size data elements per group + +group_slot_load: + loads one scalar per group and produces group_slots +``` + +## 5. Layout Fact Helpers And Ensure-Based Optimization Hooks + +Do not implement a target-aware lowering-plan registry shared by assignment and +lowering. The shared contract is the IR itself: assigned VMI layouts, explicit +`ensure_layout` / `ensure_mask_layout` / `ensure_mask_granularity` helpers, +semantic op attrs/operands, and target capability diagnostics. + +Small pure helpers are still useful when they remove duplicated layout math. +They must return semantic layout facts, not VPTO instruction plans, costs, +clone decisions, or multi-user plans. + +Keep the support layer small. A query belongs in `VMILayoutSupport` only when +at least two stages need the same fact and a mismatch would create an +assignment-vs-lowering bug. Current valid shared facts are: + +```text +cast layout fact: + shared by layout assignment, layout validation, and vmi-to-vpto. + Example: f32->f8 must see deinterleaved=4 source and contiguous result in + every stage. + +group_reduce layout fact: + shared by layout assignment, layout validation, and vmi-to-vpto. + Example: S=2*VLaneElems means deinterleaved=2 source/mask and + group_slots(G, slots=8) result in every stage. + +histogram layout fact: + shared by layout assignment, layout validation, and vmi-to-vpto. + Example: dhist requires contiguous Nxui8 source, contiguous b8 mask, and + contiguous 256xui16 acc/result. chist uses the same layout fact but also + requires a target capability that classifies CHISTv2 cumulative range + semantics. + +layout materialization support: + shared by layout validation, vmi-to-vpto, and helper-based optimizations. + Example: ensure_layout from deinterleaved=2 f32 to contiguous f32 is the same + materialization whether it survives to lowering or is folded into a store. 
+ +contiguous store support: + shared by fold-consumers and vmi-to-vpto because both must preserve the same + row-major memory effect when consuming a non-contiguous value. +``` + +Do not add a support query for a single private branch such as "this exact op +uses this exact VPTO mnemonic". Keep that branch in the lowering pattern until +another stage needs the same semantic fact. This prevents `VMILayoutSupport` +from becoming a second copy of the lowering pass. + +```c++ +struct VMICastLayoutFact { + VMICastLayoutKind kind; + VMILayoutAttr sourceLayout; + VMILayoutAttr resultLayout; + int64_t factor; +}; + +struct VMIGroupReduceLayoutFact { + VMILayoutAttr sourceLayout; + VMILayoutAttr maskLayout; + VMILayoutAttr resultLayout; + int64_t groupSize; + int64_t vlaneElems; +}; + +FailureOr<VMICastLayoutFact> +getPreferredCastLayoutFact(VMIVRegType sourceType, VMIVRegType resultType); + +FailureOr<VMIGroupReduceLayoutFact> +getPreferredGroupReduceLayoutFact(VMIVRegType sourceType, int64_t numGroups); + +LogicalResult canMaterializeDataLayout(VMIVRegType sourceType, + VMIVRegType resultType, + std::string *reason); +``` + +Baseline assignment uses these helpers only to produce assigned layouts and +use-site helpers. It does not clone producers, rematerialize cheap ops, choose +memory-fused layouts by cost, or specialize private function signatures for +performance. + +Optimization passes are deliberately helper-driven: + +```text +fold-consumers: + input shape: ensure_layout feeding a layout-aware consumer. + support query: can this consumer preserve the same logical memory effect from + the source layout? + output shape: the consumer directly uses the source value. + +rematerialize: + input shape: cheap producer feeding ensure_layout / ensure_mask_layout. + support query: can the cloned producer directly create the requested type? + output shape: a cloned producer at the use. + +sink-materialization: + input shape: pure elementwise op whose operands are matching ensure_* helpers. 
+ support query: can the result helper be materialized if it remains? + output shape: the op runs in the source layout and one helper remains on the + result. +``` + +These passes may improve multi-consumer cases without asking assignment to solve +a global cost problem. Assignment guarantees a legal baseline with helpers; +optimization removes or moves those helpers locally when the rewritten IR still +contains enough information for `vmi-to-vpto`. + +Implementation-relevant layout facts: + +```text +dense store: + requests contiguous source. If the value is assigned deinterleaved, + assignment inserts ensure_layout at the store use. A later optimization may + fold ensure_layout + store into a layout-aware VMI store. `vmi-to-vpto` + later lowers that explicit store contract. + +data/mask helper materialization: + identity conversions are always legal. + contiguous <-> deinterleaved=2/4 is legal only when source/result physical + arity and physical chunk shapes make the same logical value materializable. + unsupported conversions remain explicit diagnostics. + +group_slot_load: + assigned result layout is group_slots(G, slots=8) for packed slots or + group_slots(G, slots=1) for row-local slots. Because the result type is + `GxT`, assignment does not derive this choice from result lane count. A + constant unit `source_group_stride` selects slots=8; non-unit or dynamic + stride selects slots=1 first, then the support query rejects dynamic or + unaligned row-local lowering when the target cannot materialize it. + +block8 group_load: + assigned result layout is deinterleaved=2/4 with block_elems=8 only when the + op carries the required constant stride and memory-safety proof. + +group_store: + consumes group_slots(G,K). Explicit output stride attrs/operands decide + whether slots=8 packed or slots=1 row-local stores are legal. + +group_reduce_add{f|i}: + define E = sizeof(accumulator T), VLaneElems = 32B / E, L = 256B / E, + S = N / G. 
T is the accumulator/reduce element type after any required + storage cast. + S=VLaneElems uses contiguous source/mask and group_slots(G, slots=8). + S=2*VLaneElems uses deinterleaved=2 source/mask and group_slots(G, slots=8). + S=4*VLaneElems uses deinterleaved=4 source/mask and group_slots(G, slots=8). + S>=L && S%L==0 uses contiguous source/mask and group_slots(G, slots=1). + +group_broadcast: + consumes group_slots(G,K) and produces one assigned dense layout. If another + consumer wants a different dense layout, assignment inserts ensure_layout. + Optimization may clone/rematerialize group_broadcast per use. + +extf/truncf: + contiguous f16/bf16 -> deinterleaved=2 f32 + contiguous f8-like -> deinterleaved=4 f32 + deinterleaved=2 f32 -> contiguous f16 + deinterleaved=4 f32 -> contiguous f8-like + group_slots(G, slots=1) f32 -> f16 remains a slot-preserving transform. + +extsi/extui/trunci: + contiguous i8/i16 -> deinterleaved i32 according to widening factor. + deinterleaved i32 -> contiguous i8/i16 according to narrowing factor. + packed group_slots integer width-changing cast is unsupported until a + slot-wise transform is represented explicitly. + +bitcast: + per-part vbitcast is valid when source/result layouts match, physical arity + matches, and every physical chunk carries the same logical bit footprint. + This includes contiguous, deinterleaved, and identical group_slots layouts. +``` + +`vmi-layout-fold`, rematerialization, sink/hoist, and private +function specialization passes consume explicit helper IR. They may replace +helpers with cheaper equivalent IR, but they must not introduce hidden lowering +plans that `vmi-to-vpto` has to rediscover from producer/user context. + +## 6. 
Layout Assignment Data Model + +### 6.1 Solver State + +```c++ +struct ValueLayoutState { + Value value; + Type logicalType; + std::optional<VMILayoutKey> chosen; + std::optional<VMILayoutKey> naturalLayout; + SmallVector<UseRequest> useRequests; +}; + +struct UseRequest { + OpOperand *operand; + VMILayoutKey requestedLayout; + Operation *requestingOp; + bool hard; +}; +``` + +### 6.2 Collection Phase + +Walk the module and collect: + +```text +1. every VMI value +2. every VMI block argument +3. every VMI function argument/result +4. every VMI op with natural producer layouts or use-site layout requests +5. every branch/yield/call/return edge carrying VMI +``` + +Build SCCs over: + +```text +dataflow uses +region yields +loop iter_args +function call graph for private/internal functions +``` + +Public/external VMI function boundaries are rejected unless +`enablePublicVMIABI` is explicitly supported. + +Block arguments are first-class layout variables. Assignment must write the +chosen layout into the block argument type or specialized function signature. +`vmi-to-vpto` must never recover a block argument layout by walking to an +incoming branch, yield, or call operand. 
+ +### 6.3 Constraint Generation + +Examples: + +```text +truncf f32->f16: + source request deinterleaved=2, block_elems=1 + result contiguous + +group_reduce S=16: + source request deinterleaved=2, block_elems=1 + result group_slots(G, slots=8) + +group_reduce S=32: + source request deinterleaved=4, block_elems=1 + result group_slots(G, slots=8) + +group_reduce S=64: + source request contiguous + result group_slots(G, slots=1) + +group_broadcast: + source request group_slots(G,K) + result receives one assigned dense layout + incompatible dense uses are represented by ensure_layout + +ordinary dense add/mul/select: + operands/results same dense layout + +group-slot add/mul: + operands/results same group_slots(G,K) + +ordinary store: + dense source required + group_slots source is illegal + +group_store: + source request group_slots(G,K) + +dhist: + acc/result request contiguous 256xui16 + source request contiguous Nxui8 + mask request contiguous b8 + +chist: + same layout requests as dhist + diagnostic unless CHISTv2 cumulative range semantics are classified +``` + +Baseline assignment does not perform consumer-driven adoption for performance. +It records natural producer layouts and hard use-site requests. If a request +does not match the assigned layout, the pass inserts an explicit helper at that +use. 
+ +```text +natural layout producer: + extf/truncf, group_reduce, group_slot_load, group_load, dhist/chist when the + op itself carries a layout-producing contract + +layout equality producer: + dense add/mul/select and CFG-carried values tie operands/results but do not + pick a cheaper layout by cost +``` + +Memory legality constraints: + +```text +S=32 tail fast load: + requires full_footprint_readable + otherwise require gather fallback or diagnose + +compact S=12 logical S=16: + requires compact-row gather materialization + diagnose if gather fallback is disabled/missing +``` + +### 6.3.1 Request Builders + +Implement request generation as small per-op builders. The builders produce +natural layouts, use-site requests, equality constraints, and diagnostics; they +do not choose optimization plans. + +```text +buildStoreRequests: + ordinary store -> dense contiguous request + group_store -> group_slots(G,K) request plus stride/alignment capability + checks + +buildCastRequests: + extf f16->f32 -> source contiguous, result deinterleaved=2 + extf f8->f32 -> source contiguous, result deinterleaved=4 + truncf f32->f16 -> source deinterleaved=2/block_elems=1, result contiguous + truncf f32->f8 -> source deinterleaved=4/block_elems=1, result contiguous + group_slots slots=1 f32->f16 -> explicit slot-preserving transform + group_slots slots=8 width-changing cast -> diagnostic unless a packed + transform is explicitly represented + +buildGroupReduceRequests: + derive E = sizeof(accumulator type), VLaneElems = 32B / E, + L = 256B / E, and S = logical_lanes / num_groups + S=VLaneElems -> contiguous source, group_slots(G,8) result + S=2*VLaneElems -> deinterleaved=2/block_elems=1 source, + group_slots(G,8) result + S=4*VLaneElems -> deinterleaved=4/block_elems=1 source, + group_slots(G,8) result + S>=L && S%L==0 -> contiguous source, group_slots(G,1) result + 8-bit storage must be cast to an accumulator type before this request builder + other S -> diagnostic unless an 
explicit fallback op/helper is enabled + +buildGroupMemoryRequests: + group_load S=16/S=32 with aligned constant stride -> natural block_elems=8 + group_load row-local full chunks -> natural contiguous + group_slot_load unit stride -> group_slots(G,8) + group_slot_load aligned row-local stride -> group_slots(G,1) + unsupported dynamic/unaligned grouped memory -> diagnostic + +buildElementwiseRequests: + dense add/mul/fma/min/max/select -> all dense operands/results share one + dense layout + group-slot add/mul/select -> all operands/results share one group_slots(G,K) + dense/group_slots mixing -> diagnostic unless an explicit group_broadcast or + group_store boundary exists + +buildMaskRequests: + mask layout follows each consuming data layout + predicate granularity follows each consuming element type + create_mask/create_group_mask produce one assigned mask layout and use + ensure_mask_layout / ensure_mask_granularity for incompatible uses + masked_store requests source layout, mask layout, and store predicate + granularity explicitly + +buildHistogramRequests: + dhist -> acc/result contiguous 256xui16, source contiguous Nxui8, + mask contiguous b8 + chist -> same layout requests, plus target capability diagnostic until + CHISTv2 high-range semantics are classified + do not create group_slots or group_reduce requests; histogram result bins are + selected by source values, not by lane/group position + +buildControlFlowRequests: + region yields, branch operands, loop iter_args, call operands, and returns + create equality requests on the carried VMI layout variable + +buildFunctionBoundaryRequests: + private/internal function argument/result layouts are materialized with + callee-entry/return-site helpers in baseline assignment; signature + specialization is an optimization pass + public/external VMI arguments/results diagnose unless enablePublicVMIABI has + a real ABI contract +``` + +Request builders must record the requesting op. 
Diagnostics and inserted +helpers are use-site operations, so the user can see which consumer forced a +layout. + +### 6.3.2 Optimization Producer Classes + +Baseline assignment does not use producer classes to solve conflicts. It +inserts helpers. Later optimization passes may classify producers to replace +helpers with cheaper equivalent IR. + +```text +cheap rematerializable producers: + load when address operands dominate the clone site, no intervening may-alias + write exists, and any shaped memory proof is preserved + broadcast + create_mask + create_group_mask + group_broadcast + group_slot_load when the same address/no-alias/proof conditions as load hold + and the memory access remains legal at the clone site + +layout-transparent producers: + add/sub/mul/fma/min/max/neg/abs + select + bitcast + integer bitwise and shift ops + +fixed-layout producers: + extf/truncf physical conversion layouts + group_load block-fragment layouts + group_reduce result group_slots + dhist/chist result contiguous 256xui16 and source/mask contiguous b8 contract + masked_load when the physical memory-safety proof fixes a full-read lowering +``` + +Optimization conflict policy: + +```text +cheap producer: + clone for each incompatible request when cloning does not duplicate a + side-effect, cross an aliasing write, or duplicate an illegal memory read + +layout-transparent producer: + merge into the consumer-requested equivalence class; insert materialization + only at incompatible uses + +fixed-layout producer: + use explicit helper materialization only; otherwise diagnose +``` + +These classes are not assignment constraints. They are rewrite preconditions +for passes that consume `ensure_layout` and decide whether the helper can be +folded, sunk, hoisted, or replaced by rematerialization. + +### 6.4 Solving And Rewriting + +Algorithm: + +```text +1. Collect natural layouts, use-site requests, equality constraints, and + memory-safety proofs. +2. 
Propagate equality constraints through SCCs. +3. Choose one deterministic assigned layout per value/equivalence class: + explicit user layout, then unique producer natural layout, then hard + non-contiguous layout, then contiguous. +4. For conflicting uses, insert ensure_layout / ensure_mask_layout / + ensure_mask_granularity at the use. +5. Emit diagnostics for unsupported semantic constraints or missing explicit + materialization/memory-safety proof. +6. Rewrite VMI result/block/function types with chosen layouts. +7. Insert helper ops with source/result layout attrs. +``` + +Rewrite invariants: + +```text +No VMI data/mask value after assignment has a null layout. +Any non-local choice is represented by op attrs, operand/result layouts, a +helper op, or an explicit diagnostic. Cloned/rematerialized producers may +appear only after later layout optimization passes. +Every ensure_* helper has an explicit supported materialization path or a +diagnostic. +Every function/call boundary carrying VMI is materialized, kept in an explicit +ABI contract, or diagnosed. +``` + +### 6.5 Rewrite Artifacts + +Assignment rewrites the IR so that later lowering has no hidden choices. 
+ +```text +type rewrite: + every VMI data/mask result and block argument receives a layout attr + +ensure rewrite: + mismatched uses get pto.vmi.ensure_layout or ensure_mask_layout at the use + site, with source and target layouts visible in the types + +granularity rewrite: + one semantic mask used by f32 and f16 consumers gets + ensure_mask_granularity at the use site + +control-flow rewrite: + scf.if/scf.for yields and block arguments are rewritten to one agreed layout; + materialization is inserted before yield when branches differ + +function rewrite: + baseline private VMI functions get callee-entry/return-site ensure_layout; + signature specialization is an optimization pass + public/external VMI functions are diagnosed +``` + +Canonical assigned IR shape for a conflicting load: + +```text +%x = pto.vmi.load ... + : ... -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +%x_dense = pto.vmi.ensure_layout %x + : !pto.vmi.vreg<256xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +pto.vmi.store %x_dense, ... +``` + +Optional future optimized IR shape for a cloned load with an explicit +safe-read/execution proof: + +```text +%x_s16 = pto.vmi.load ... + : ... -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +%x_s32 = pto.vmi.load ... + : ... -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +``` + +Canonical assigned IR shape for `group_broadcast` multi-use: + +```text +%b = pto.vmi.group_broadcast %slots + : !pto.vmi.vreg<8xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +%b_c = pto.vmi.ensure_layout %b + : !pto.vmi.vreg<256xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +``` + +If the assigned IR does not have one of these explicit shapes, `vmi-to-vpto` +must reject it instead of attempting to recover the missing decision. + +### 6.6 Case-To-Implementation Closure Matrix + +The current case catalog is sufficient for the first implementation. No new +layout kind is justified by the supported endpoints. 
The implementation work +should instead close the following finite matrix. Each row names the request +builder that owns the decision, the assignment artifact that must appear in IR, +and the `vmi-to-vpto` contract. + +```text +case family builder / owner assignment artifact +3.1, 3.2, 3.3 dense casts buildCastRequests dense layout on each cast result +3.29 mask width split buildMaskRequests per-use mask granularity helper +3.31, 3.32 dense fanout conflict resolver cloned load or ensure_layout + +vmi-to-vpto contract: + consume only the assigned dense layouts. It may emit VCVT and dense + materialization, but it must not choose deinterleaved=2/4 by inspecting a + later truncf, store, or group_reduce user. +``` + +```text +case family builder / owner assignment artifact +3.4 32-bit S=8 reduce buildGroupReduceRequests one_vlane contiguous lowering +3.5 32-bit S=16 reduce buildGroupReduceRequests two_vlane parity/block8 layout +3.6 32-bit S=32 reduce buildGroupReduceRequests four_vlane dintlv4/block8 layout +3.7 32-bit S=64 reduce buildGroupReduceRequests full_chunk row_local lowering +3.11.1 S=64 active-row tail buildMaskRequests active-row store/reduce masks +3.19.1 S=16 block_elems choice buildGroupReduceRequests explicit block_elems layout +3.38 multi-tile S=32 reduce buildGroupReduceRequests multiple group_slots chunks +3.26 grouped tail buildMaskRequests split grouped masks +3.44, 3.45 grouped S=32 masks buildMaskRequests explicit deinterleaved mask values + +vmi-to-vpto contract: + lower each reduce from the current op's attrs, source/mask layout, and result + group_slots layout. It must not walk to the load/group_load producer to + decide parity versus block8, row-local versus packed slots, or static versus + dynamic mask generation. 
+``` + +```text +case family builder / owner assignment artifact +3.56 full distribution hist buildHistogramRequests contiguous src/mask/acc/result +3.57 cumulative hist boundary buildHistogramRequests capability diagnostic or classified path + +vmi-to-vpto contract: + lower dhist from the current op and assigned layouts by carrying two physical + accumulator parts for bins 0..127 and 128..255. It must not expose the VPTO + #bin range selector on the VMI surface and must not model histogram as + group_reduce. chist remains rejected until the target records whether the + high-range cumulative result is global or range-local and, for range-local + behavior, until low-total materialization is explicit. +``` + +```text +case family builder / owner assignment artifact +3.15.1 S=16 row stride 16 buildGroupMemoryRequests block_elems=8 group_load layout +3.15.2 S=16 row stride > 16 buildGroupMemoryRequests strided block_elems=8 plan +3.16.1 group_slot_load slots=8 buildGroupMemoryRequests unit-stride packed slots plan +3.16.2 group_slot_load slots=1 buildGroupMemoryRequests row-local aligned slots plan +3.27 strided group_load buildGroupMemoryRequests positive block_elems=8 plan +3.28 slots=1 non-unit load buildGroupMemoryRequests row-local group_slot_load layout +3.37 slots=1 strided store buildStoreRequests group_store stride/alignment proof +3.39 strided load fanout conflict resolver preserving layout or materialization + +vmi-to-vpto contract: + consume only explicit memory stride/alignment attrs, current op operands, + and layouts. It must not infer safe read/write placement from neighboring + compute ops. Unsupported dynamic, unaligned, or compact-row gather shapes + stay diagnostics until a gather fallback is explicit in the current op. 
+``` + +```text +case family builder / owner assignment artifact +3.8 reduce->truncf->broadcast conflict resolver slot cast plus dense materialization +3.10 non-load S=32 producer buildElementwiseRequests transparent deinterleaved chain +3.17 broadcast deint consumer conflict resolver use-site group_broadcast layout +3.18 dense + reduce users conflict resolver ensure_layout; optional remat/fold +3.23 broadcast multi-user conflict resolver per-op group_broadcast layout +3.33 S=16 + S=32 users conflict resolver use-site materialization; optional cloned load +3.34 S=64 slots=1 cast buildCastRequests group_slot_cast layout +3.35 slots fanout buildElementwiseRequests same group_slots layout on users +3.36 scalar slots=8/slots=1 conflict resolver explicit slots=8/slots=1 producers +3.40 scalar dense + grouped conflict resolver ensure_layout; optional broadcast remat +3.41 incompatible fixed value conflict resolver diagnostic or ensure_layout + +vmi-to-vpto contract: + each op instance is already single-plan. The lowering pass never scans + sibling users to decide whether to clone, pack, broadcast, or materialize. +``` + +```text +case family builder / owner assignment artifact +3.21 S=32 rounded tail mask buildMaskRequests rounded vector plus mask +3.24 mask/select/store buildMaskRequests explicit mask layout/granularity +3.12 scf.if before reduce buildControlFlowRequests common yielded layout +3.20 group_slots scf.if buildControlFlowRequests common group_slots layout +3.22 scf.for carried value buildControlFlowRequests fixed-point iter_arg layout +3.25 function boundary buildFunctionBoundary specialized/internal boundary +3.42 loop accumulator buildControlFlowRequests loop-carried group_slots layout +3.43 call argument materialize buildFunctionBoundary callee-entry/return helper + +vmi-to-vpto contract: + block argument, region result, call operand, and function result layouts are + visible in types or helper ops. 
It must not inspect branch bodies, loop + bodies, callers, or callees to discover a layout. +``` + +```text +diagnostic family builder / owner required failure +3.7.4 slots=1 unit-stride store buildStoreRequests no aligned row-local store path +3.9 dense store of group slots buildStoreRequests use group_store/group_broadcast +3.11.2 S=32 unsafe tail buildMaskRequests missing full_footprint_readable/gather +3.13 slots=8 width cast buildCastRequests no packed slot cast transform +3.14 unsupported group size buildGroupReduceRequests no supported reduce layout/lowering +3.15.3 compact S=12 buildGroupMemoryRequests no compact gather plan +3.16.1 slots=8 non-unit load buildGroupMemoryRequests no packed strided slot load path +3.16.2 slots=1 bad stride buildGroupMemoryRequests no dynamic/unaligned row-local plan +3.19.2 invalid block_elems use conflict resolver no preserving materialization +3.25.2 public/external ABI buildFunctionBoundary no stable public VMI ABI +3.27 unaligned group_load buildGroupMemoryRequests no gather/block fallback path +3.30 masked_load unsafe tail buildMaskRequests no padding/gather fallback + +vmi-to-vpto contract: + these cases must fail before or at the layout contract boundary with the + requesting op named. They must not be accepted by falling back to a generic + dense load, dense store, or producer/user inspection. +``` + +Additional cases are needed only when the scope changes: + +```text +stable gather fallback enabled: + add compact S=12 positive lowering and masked_load unsafe-tail positive + lowering before accepting either path. + +pack-to-slots=8 or unaligned row-local stores enabled: + add positive S=64 unit-stride group_store and reduce->pack->dense store cases. + +public VMI ABI enabled: + add public call/return ABI cases before removing the public-boundary + diagnostic. + +packed group-slot width cast enabled: + add slots=8 f32->f16 cast and downstream group_store/broadcast cases. +``` + +## 7. 
OneToN Type Conversion + +`vmi-to-vpto` should use OneToN conversion for VMI values. + +Conversion rules: + +```text +contiguous: + ceil(N / lanesPerVReg(T)) physical vregs + +deinterleaved=F: + F * ceil((N / F) / lanesPerVReg(T)) physical vregs + ordering: part-major, then chunk + +group_slots(G,K): + ceil(G / K) physical vregs + each vreg has logical slot lanes 0..K-1 live +``` + +Mask conversion: + +```text +mask layout follows data layout +mask granularity is selected from consumer element width: + f32/i32 -> b32 + f16/i16 -> b16 + f8/i8 -> b8 +``` + +If one logical mask is used by multiple widths, assignment inserts +`ensure_mask_granularity` or rematerializes the mask producer. + +## 8. VMI-to-VPTO Pattern Rules + +Each pattern uses: + +```text +op +op attrs and operand values +operand/result layouts +adaptor physical values +``` + +Each pattern rejects: + +```text +missing current-op proof for an otherwise unsafe memory lowering +missing target capability +unexpected group_slots dense consumer +``` + +Target local lowering matrix: + +```text +load, lowering=dense_load_norm: + result layout contiguous + emits pto.vlds / pto.vsts NORM paths + covers dense store users and full-chunk row-local reduce input + +load, lowering=load_dintlv2: + result layout deinterleaved=2, block_elems=1 + emits vldsx2 DINTLV_B32 or normal load + vdintlv materialization + covers f32->f16, S=16 parity reduce, f16->f32 widened values + +load, lowering=load_dintlv4: + result layout deinterleaved=4, block_elems=1 + emits two vldsx2 DINTLV_B32 plus vdintlv + covers f32->f8, S=32 dintlv4 reduce + +group_load, lowering=s16_group_load_block8_unit_stride: + result layout deinterleaved=2, block_elems=8 + emits vldsx2/BDINTLV for 8 rows of 16xf32 + covers compact logical S=16 when source_group_stride == 16 + +group_load, lowering=s16_group_load_block8_stride: + result layout deinterleaved=2, block_elems=8 + emits two vsldb strided 32B block loads + requires source_group_stride % 8 == 0 + 
+group_load, lowering=s32_group_load_block8_stride: + result layout deinterleaved=4, block_elems=8 + emits four vsldb strided 32B block loads + requires source_group_stride % 8 == 0 + +group_load, lowering=group_load_contiguous_chunks: + result layout contiguous + emits one vlds per physical group chunk using row_stride address arithmetic + covers the currently implemented full-chunk row-local group_load path + +group_reduce_add{f|i}, lowering=one_vlane_reduce_contiguous: + consumes contiguous accumulator type T with group size VLaneElems(T) + produces group_slots(G, slots=8) + emits one vcgadd + +group_reduce_add{f|i}, lowering=two_vlane_reduce_deinterleaved: + consumes deinterleaved=2, block_elems=1 + produces group_slots(G, slots=8) + emits two vcgadd operations and one vadd + +group_reduce_add{f|i}, lowering=two_vlane_reduce_block8: + consumes deinterleaved=2, block_elems=8 + produces group_slots(G, slots=8) + emits two vcgadd operations and one vadd + +group_reduce_add{f|i}, lowering=four_vlane_reduce_dintlv4: + consumes deinterleaved=4, block_elems=1 + produces group_slots(G, slots=8) + emits four vcgadd operations and a vadd tree + +group_reduce_add{f|i}, lowering=four_vlane_reduce_block8_stride: + consumes deinterleaved=4, block_elems=8 + produces group_slots(G, slots=8) + emits four vcgadd operations and a vadd tree + +group_reduce_add{f|i}, lowering=full_chunk_reduce_row_local: + consumes contiguous accumulator type T with group size that is a multiple of + one physical chunk L(T) + produces group_slots(G, slots=1) + target lowering emits per-row vcgadd plus vcadd; the current prototype uses + the existing row-local VCADD/VADD/VSEL sequence while preserving the same + group_slots(G, slots=1) value contract + +dhist, lowering=full_256bin_histogram: + consumes contiguous Nxui8 source and contiguous b8 mask + consumes/produces contiguous 256xui16 accumulator/result + physical result parts are [bins 0..127, bins 128..255] + emits one low-range and one 
high-range histogram update for each 256-lane + source chunk + final partial source chunks require an explicit valid-lane b8 mask + +chist, lowering=capability_gated_cumulative_histogram: + uses the same layout shape as dhist + rejects until target capability classifies CHISTv2 high-range cumulative + semantics and any required low-total correction materialization is explicit + +group_slot_load, lowering=group_slot_load_slots8_unit_stride: + result group_slots(G, slots=8) + requires source_group_stride == 1 + emits one packed vsldb load + +group_slot_load, lowering=group_slot_load_slots1_row_local: + result group_slots(G, slots=1) + supports aligned non-unit source_group_stride + requires constant positive source_group_stride divisible by 256 / elementBits + emits one lane-0 vsldb per group + +group_broadcast, lowering=group_broadcast_slots8_vselr: + source group_slots(G, slots=8) + result dense layout selected per use + emits vselr using assigned result layout + +group_broadcast, lowering=group_broadcast_slots1_vselr: + source group_slots(G, slots=1) + result dense layout selected per use + emits vdup/vselr row-local materialization + +truncf, lowering=group_slot_cast_slots1_f32_to_f16: + source/result group_slots(G, slots=1) + emits one lane-0 vcvt per group slot block + rejects packed slots=8 unless slot-preserving cast support exists +``` + +The target matrix is the implementation contract. The staged status below +records how much of that contract the current prototype has already enforced. + +Current staged implementation status: + +```text +group_slot_load: + vmi-to-vpto lowers from #pto.vmi.layout + and source_group_stride. + +group_reduce_addf: + explicit slots=8 VCGADD lowering is selected from contiguous source/mask + layout, slots=8 result layout, num_groups, and reassoc. + S=16 block8 assignment emits source/mask + #pto.vmi.layout, result + #pto.vmi.layout; vmi-to-vpto lowers through two + VCGADDs plus a PAT_VL8 VADD per packed result block. 
+ S=32 block8 assignment emits source/mask + #pto.vmi.layout, result + #pto.vmi.layout; vmi-to-vpto lowers through four + VCGADDs plus a PAT_VL8 VADD tree per packed result block. + Full-chunk row-local assignment, including S=64 and S=256 f32 cases, uses + #pto.vmi.layout and has focused + layout-assignment/vmi-to-vpto lit coverage; the explicit slots=1 generic + VCADD row-local lowering is selected locally from the current op attrs and + assigned layouts. + group_reduce_addi is implemented for i32 accumulator values. i8/i16 storage + must be widened explicitly before grouped reduction because narrow integer + reduction instructions widen their result. + +group_broadcast: + explicit slots=8/1 source layouts select + packed or row-local VSELR lowerings locally. Deinterleaved block-fragment + results use the result layout block_elems as the local vselr selection group, + so + `deinterleaved = 4, block_elems = 8` broadcasts one group slot across each + 32B row fragment. VSELR index vectors are materialized per physical result + chunk. For small-group results, layout assignment has already fixed the + result layout, and vmi-to-vpto computes: + `firstGroup = first logical group covered by this result chunk`, + `sourceChunk = firstGroup / slots`, and + `baseGroupSlot = firstGroup % slots`. The generated index vector selects + `baseGroupSlot .. baseGroupSlot + groupsPerResultChunk - 1`; it must not be + reused across result chunks. + +group_load: + contiguous full-chunk path is selected from a contiguous result layout. + S=16/S=32 block-aligned strided loads are selected from + #pto.vmi.layout, and lower to one + vsldb per 32B row fragment and physical chunk. The explicit block8 support + is checked by pto-validate-vmi-layout-ir before vmi-to-vpto. + The dedicated S=16 unit-stride vldsx2/BDINTLV lowering remains a local + peephole target. 
+ S=16/S=32 group_load with a non-constant, non-positive, or non-8-f32-aligned + row_stride is rejected by vmi-layout-assignment because the stable gather + fallback is not implemented. + +truncf group-slot cast: + layout assignment and vmi-to-vpto support group_slots(G, slots=1) + f32 -> f16 from source/result layouts and element widths. The reduce->truncf + -> group_store slots=1 flow has focused lit coverage and no longer relies on + vmi-to-vpto inspecting the truncf producer. + +group_store: + row-local group_slots(G, slots=1) lowering is implemented as one lane-0 + vsts per group for packed unit-stride output, or as one 1PT store per group + for non-unit row strides. The packed path is covered by the + reduce->truncf->group_store lit case, while the point-store path is covered + by `test/lit/vmi/vmi_to_vpto_group_store_slots1_1pt.pto`. + Packed group_slots(G, slots=8) group_store is implemented only when + num_groups is a multiple of 8 and row_stride is constant 1; it emits one + PAT_VL8 store per packed slot block. Non-unit packed group stores remain a + design target unless a strided packed-lane store lowering is made explicit. +``` + +Current implementation contract for type-generic grouped reduction: + +```text +ODS/verifiers: + pto.vmi.group_reduce_addi is the integer counterpart to group_reduce_addf. + group_reduce_addi accepts i32 accumulator element types; i8/i16 direct + grouped reduction is rejected with a diagnostic that points users to + extsi/extui. + extsi/extui/trunci carry integer signedness across storage/accumulator + boundaries without overloading add semantics. + +Layout assignment: + compute VLaneElems and L from the accumulator/reduce element type: + VLaneElems = 32B / sizeof(accumulator T) + L = 256B / sizeof(accumulator T) + use the same S formula for f16/f32/i32 once the typed reduce op and target + capability say the type is legal. + route f8 storage through extf to f32 before group_reduce_addf. 
+ route i8/i16 storage through extsi/extui to i32 before group_reduce_addi. + route integer narrowing to i8 through trunci; direct i8 compute remains + illegal unless target capability and explicit op semantics define that + lowering. + diagnose direct f8/i8 compute use with a message that points at the offending + op and suggests inserting the explicit cast when the op is meant to consume + storage data. + +Layout fact helpers: + replace f32-shaped checks with width-parametric group-reduce classifiers: + one_vlane_reduce layout fact + two_vlane_reduce_deinterleaved layout fact + four_vlane_reduce_deinterleaved layout fact + full_chunk_row_local_reduce layout fact + key legality on accumulator byte width, source/mask layout, result + group_slots layout, num_groups, and target instruction capability. + +VMI-to-VPTO: + lower group_reduce_addi through the same VCGADD/VADD skeleton used for + floating-point where the target supports the integer accumulator type. + materialize integer casts explicitly before reduction; direct i8 group reduce + and direct i16 group reduce must not silently become a widening reduction in + this pass. + keep VPTO lowering local: it consumes assigned layouts and current-op + attrs/operands, but does not invent a new global layout plan. + +Tests: + cover f16 direct and i16-storage-to-i32 grouped reductions. + add i32 S=8/S=16/S=32/S=64 group-reduce cases. + add f8 storage -> extf -> f32 group_reduce_addf cases. + add i8/i16 storage -> extsi/extui -> i32 group_reduce_addi cases. + add invalid direct f8/i8/i16 grouped-reduce diagnostics. 
+``` + +Examples: + +```text +group_reduce_add{f|i}, lowering=two_vlane_reduce_deinterleaved: + consume deinterleaved=2, block_elems=1 + emit two VCGADDs and one VADD + +group_reduce_add{f|i}, lowering=two_vlane_reduce_block8: + consume deinterleaved=2, block_elems=8 + emit two VCGADDs and one VADD + +group_reduce_add{f|i}, lowering=four_vlane_reduce_dintlv4: + consume deinterleaved=4 + emit four VCGADDs and a reduction tree + +group_broadcast: + consume group_slots + emit VSELR or VDUP depending on slots and target dense layout + +group_slot_load slots=8: + emit one packed block load for unit stride + +group_slot_load slots=1: + emit row-local lane-0 loads for constant positive 32B-aligned strides +``` + +## 9. Validation Passes + +### 9.1 Surface Validation + +Before assignment: + +```text +VMI types may omit layout. +VPTO physical op must not consume VMI values. +Public/external VMI function ABI rejected unless enabled. +Unsupported vector-to-scalar extract rejected. +``` + +### 9.2 Layout Validation + +After assignment: + +```text +Every VMI value has layout. +Every VMI mask has layout and granularity plan. +Every lowering choice is locally deterministic or explicit in attrs/layouts. +Every ensure_* helper has a materialization path. +Every control-flow edge has matching VMI layouts. 
+``` + +### 9.3 `vmi-to-vpto` Context Read Audit + +`vmi-to-vpto` may still read defining ops in narrowly scoped cases that do not +select a layout or plan: + +```text +allowed: + arith.constant for the current op's scalar operands + create_mask/create_group_mask internals when lowering that mask op itself + ensure_mask_layout / ensure_mask_granularity stripping for static mask facts + memref.subview only to improve an already-failed non-identity memref + diagnostic + +not allowed: + walking from a consumer to a producer to decide a lowering + walking from a consumer to a mask producer to decide whether a lowering is legal + inspecting users to choose a result layout or materialization + recovering full_footprint_readable from surrounding MTE/caller context +``` + +Current audit result: + +```text +3.44 partial S=32 create_group_mask: + assignment writes explicit contiguous and deinterleaved mask values. When + lowering the deinterleaved create_group_mask itself, vmi-to-vpto first + materializes contiguous grouped predicate chunks and then applies predicate + pdintlv in the same tree shape as the data vdintlv. It still does not walk + from group_reduce_addf to the mask defining op to choose or reject lowering. + The dynamic active_elems_per_group form is also op-local: vmi-to-vpto lowers + contiguous chunks with vci/vshrs/vshls/vsub/vcmps, then uses the same + predicate pdintlv tree for S=32 deinterleaved masks. + +masked_load: + direct lowering is load + vsel. It does not inspect the mask producer to + choose a different load form; memory safety is provided by full physical + chunks or shaped memref proof. + +memref.subview: + mentioned only after identity lane-to-address planning fails. It is not used + to recover a hidden base/stride lowering. +``` + +## 10. 
Diagnostics + +Implement diagnostics with stable prefixes: + +```text +VMI-LAYOUT-CONTRACT +VMI-UNSUPPORTED-PLAN +VMI-MISSING-CAPABILITY +VMI-PUBLIC-ABI +VMI-MASK-GRANULARITY +VMI-CONTROL-FLOW-LAYOUT +``` + +Minimum diagnostic payload: + +```text +op name +logical type +actual layout +requested layout +selected/missing support path +recommended rewrite or option +``` + +Example: + +```text +VMI-LAYOUT-CONTRACT: + pto.vmi.truncf requires + #pto.vmi.layout, but the source value is + fixed to #pto.vmi.layout by the selected + strided group_load layout. Register a rematerialization or preserving + materialization path, or avoid consuming this block-loaded value with truncf. +``` + +## 11. Test And Simulator Acceptance + +Each numbered endpoint in `vmi-layout-lowering-cases.md` should become: + +```text +1. a layout-assignment lit test +2. a vmi-to-vpto lit test +3. a simulator case when the VPTO sequence is supported by the current backend +4. a diagnostic lit test when the case is explicitly unsupported +``` + +Repository locations: + +```text +test/lit/vmi/ +test/vpto/cases/vmi/ +``` + +The current repository uses descriptive flat lit names rather than +case-numbered subdirectories. New tests should follow the existing prefixes: + +```text +vmi_layout_assignment_.pto +vmi_to_vpto_.pto +/kernel.pto +``` + +The case number should still be recoverable from the coverage table in this +document and from the corresponding section in `vmi-layout-lowering-cases.md`. + +### 11.1 Layout Assignment Checks + +Each positive layout-assignment test must check: + +```text +assigned data layouts +assigned mask layouts +assigned op attrs +direct vmi-to-vpto local lowering +inserted ensure_layout/rematerialized producers +control-flow/function signature specialization +``` + +Negative tests check diagnostic text. 
+ +### 11.2 VMI-to-VPTO Checks + +Each positive vmi-to-vpto test must check: + +```text +no pto.vmi ops remain +VPTO op sequence matches the case lowering +physical value arity and ordering are correct +mask granularity is correct +stores preserve observable logical memory order +``` + +### 11.3 Simulator Checks + +Simulator cases should compare final memory against the memory result written in +the case catalog. + +Current broad runtime sweep: + +```text +WORK_SPACE=$PWD/.tmp/vmi-runtime-batch-final CASE_PREFIX='vmi/' JOBS=4 \ + test/vpto/scripts/run_host_vpto_validation_parallel.sh + +TOTAL_CASES=47 +PASS=47 FAIL=0 +summary: .tmp/vmi-runtime-batch-final/parallel-summary.tsv +result: all summary entries are PASS +``` + +The `find: Permission denied` messages printed while discovering CANN simulator +paths are environment noise and are not treated as simulator failures. + +Required groups: + +```text +dense conversion: + 3.1, 3.2, 3.3, 3.31, 3.32 + +group reduce: + 3.4, 3.5.1, 3.5.2, 3.5.3 + 3.6.1, 3.6.2, 3.6.3 + 3.7.1, 3.7.2, 3.7.3 + 3.7.4 diagnostic + +layout/rematerialization: + 3.8, 3.10, 3.17, 3.18, 3.19.1, 3.22, 3.23, 3.31, + 3.32, 3.33, 3.34, 3.35, 3.36, 3.38, 3.40, 3.41 + +mask/tail: + 3.11.1, 3.15.1, 3.15.2, 3.21, 3.24, 3.26, 3.29, + 3.30, 3.44, 3.45 + +strided/group-slot memory: + 3.27, 3.28, 3.37, 3.39 + +function/control-flow: + 3.12, 3.20, 3.22, 3.25.1, 3.42, 3.43 + +histogram: + 3.56 positive dhist layout/lowering and simulator case when backend support + is enabled + 3.57 diagnostic chist case until CHISTv2 range semantics are classified +``` + +Aggregate catalog headings are covered through their endpoint subcases: + +```text +3.11 partial tail groups: + 3.11.1 positive S=64 active-row tail + 3.11.2 diagnostic S=32 tail without full_footprint_readable + +3.15 compact S=12 written as logical S=16: + 3.15.1 positive source row stride 16 + 3.15.2 positive source row stride greater than 16 + 3.15.3 diagnostic compact source row stride 12 + +3.16 
group_slot_load layout contract: + 3.16.1 packed slots=8 positive and non-unit-stride diagnostic + 3.16.2 row-local slots=1 positive plus dynamic/unaligned diagnostics + +3.25 function boundary layout specialization: + 3.25.1 private/internal boundary lit and runtime coverage + 3.25.2 public/external boundary diagnostics +``` + +Current coverage audit result: + +```text +SIM-backed positive endpoints: + 3.1, 3.2, 3.3, 3.4, 3.5.1, 3.5.2, 3.5.3, + 3.6.1, 3.6.2, 3.6.3, 3.7.1, 3.7.2, 3.7.3, + 3.8, 3.10, 3.11.1, 3.12, 3.15.1, 3.15.2, + 3.16.1 positive, 3.16.2 positive, 3.17, 3.18, + 3.19.1, 3.20, 3.21, 3.22, 3.23, 3.24, 3.25.1, 3.26, + 3.27 positive, 3.28 positive, 3.29, 3.31, 3.32, + 3.33, 3.34, 3.35, 3.36, 3.37, 3.38, 3.39, + 3.40, 3.41, 3.42, 3.43, 3.44, 3.45 + +diagnostic endpoints: + 3.7.4, 3.9, 3.11.2, 3.13, 3.14, 3.15.3, + 3.16.1 non-unit slots=8 source stride, + 3.16.2 dynamic/unaligned slots=1 source stride, + 3.19.2, 3.25.2, 3.27 unaligned source_group_stride, + 3.30 unsafe masked_load tail + +repository evidence: + all concrete lit/runtime paths listed below exist + all 47 runtime case directories contain kernel.pto, launch.cpp, main.cpp, + golden.py, and compare.py + latest broad VMI runtime sweep passed: PASS=47 FAIL=0 + latest full VMI lit sweep passed: 350/350 + this historical sweep predates 3.56/3.57; histogram endpoints require new + lit/SIM or diagnostic tests before they can be counted as implemented +``` + +Current checked-in coverage for 3.3 dense f8->f32->compute->f8: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_f8_compute_f8.pto + +runtime SIM: + test/vpto/cases/vmi/f8-compute-f8 +``` + +Current checked-in coverage for 3.1/3.2 dense f16/f32 conversion stores: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_dense_f16_f32_store.pto + +runtime SIM: + test/vpto/cases/vmi/widen-f16-to-f32-store-reduce + test/vpto/cases/vmi/quant-f32-to-f16-tail +``` + +Current checked-in coverage for basic packed group_reduce -> group_store paths +for 
3.4, 3.5.1, and 3.6.1: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_group_reduce_slots8_store.pto + test/lit/vmi/vmi_layout_assignment_group_reduce_s16_store.pto + test/lit/vmi/vmi_layout_assignment_group_reduce_s32_store.pto + +runtime SIM: + test/vpto/cases/vmi/group-reduce-basic-store +``` + +Current checked-in coverage for S=16 group broadcast continuation: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_group_slots_fanout.pto + +runtime SIM: + test/vpto/cases/vmi/group-reduce-s16-broadcast-reduce-store +``` + +Current checked-in coverage for 3.35 group_slots fanout to direct group_store +and group_broadcast: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_group_slots_fanout.pto + +runtime SIM: + test/vpto/cases/vmi/group-slots-fanout-store-broadcast +``` + +Current checked-in coverage for 3.8 `group_reduce -> group_broadcast -> +truncf -> dense store` and 3.17 `group_broadcast` feeding a +deinterleaved consumer: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_group_reduce_s16_truncf_broadcast_store.pto + +runtime SIM: + test/vpto/cases/vmi/group-reduce-s16-truncf-broadcast-store +``` + +Current checked-in coverage for 3.18 one dense value with dense and +group-reduce consumers: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_dense_group_reduce_multi_consumer.pto + +runtime SIM: + test/vpto/cases/vmi/dense-group-reduce-multi-consumer +``` + +Current checked-in coverage for 3.10 non-load producer feeding S=32 +`group_reduce`: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_non_load_s32_reduce.pto + +runtime SIM: + test/vpto/cases/vmi/group-reduce-s32-add-bias-store +``` + +Current checked-in coverage for 3.23 group_broadcast with multiple dense +consumers: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_group_broadcast_multi_consumer.pto + +runtime SIM: + test/vpto/cases/vmi/group-broadcast-multi-consumer +``` + +Current checked-in coverage for S=32 contiguous group broadcast continuation: + +```text +lit: + 
test/lit/vmi/vmi_layout_assignment_group_reduce_s32_broadcast_reduce.pto + +runtime SIM: + test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store +``` + +Current checked-in coverage for 3.21 S=32 tail with a statically safe +full-read source: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_full_tile.pto + +runtime SIM: + test/vpto/cases/vmi/group-reduce-s32-tail-full-tile-store + This case has `ptoas.flags` with `--enable-vmi`, because the partial pointer + load must run through layout assignment before VPTO/LLVM emission. +``` + +Current checked-in coverage for 3.44 masked_load grouped tail feeding S=32 +reduce: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_masked_load_group_tail_s32.pto + +runtime SIM: + test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store +``` + +Current checked-in coverage for 3.45 dynamic S=32 `create_group_mask`: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_create_group_mask_s32_dynamic.pto + +runtime SIM: + test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store + +runtime scalar source: + active_cols is passed as a kernel i32 scalar argument and cast to index inside + vecscope before pto.vmi.create_group_mask. This is an explicit scalar ABI, + not a value recovered by vmi-to-vpto from producer/consumer context. 
+``` + +Current checked-in runtime coverage for 3.12 control-flow join before S=32 +`group_reduce`: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_cf_branch.pto + test/lit/vmi/vmi_to_vpto_cf_branch.pto + +runtime SIM: + test/vpto/cases/vmi/group-reduce-s32-cf-join-store +``` + +Current checked-in runtime coverage for 3.20 `group_slots` control-flow join: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_group_slots_cf_join.pto + +runtime SIM: + test/vpto/cases/vmi/group-slots-cf-join-store +``` + +Current checked-in runtime coverage for 3.22 `scf.for` loop-carried VMI layout: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_scf_for.pto + test/lit/vmi/vmi_to_vpto_scf_for.pto + +runtime SIM: + test/vpto/cases/vmi/scf-for-loop-carried-store +``` + +Current checked-in runtime coverage for 3.42 `group_slots` `scf.for` +loop-carried accumulator: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_group_slots_scf_for.pto + +runtime SIM: + test/vpto/cases/vmi/group-slots-scf-for-store +``` + +Current checked-in coverage for 3.25.1 private function result boundary: + +```text +lit: + test/lit/vmi/vmi_ptoas_private_call_inline.pto + +runtime SIM: + test/vpto/cases/vmi/private-call-inline-store + +implementation note: + after vmi-to-vpto physicalizes the private helper, ptoas inlines private + single-block helpers whose signatures contain !pto.vreg or !pto.mask. This + happens before VPTO vecscope/backend emission, so physical vector values do + not escape through a function return. 
+``` + +Current checked-in coverage for 3.43 internal function argument boundary +materialization: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_call_argument_boundary.pto + test/lit/vmi/vmi_ptoas_call_boundary_vecscope.pto + +runtime SIM: + test/vpto/cases/vmi/private-call-argument-boundary-store + +implementation note: + private physical helper inlining also covers void helper calls with physical + VMI arguments, so the backend no longer sees a physical VPTO vector function + ABI for this internal boundary. +``` + +Current checked-in coverage for packed group-slot RHS elementwise continuations +for 3.5.3 and 3.6.2: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_group_slot_load_dual_layout.pto + +runtime SIM: + test/vpto/cases/vmi/group-reduce-slot-add-store +``` + +Current checked-in coverage for S=64 row-local group broadcast continuation +with aligned row_stride: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_group_reduce_s64_broadcast_reduce.pto + +runtime SIM: + test/vpto/cases/vmi/group-reduce-s64-broadcast-reduce-store +``` + +Current checked-in coverage for S=64 active-row tail with aligned row_stride: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_group_reduce_s64_tail_store.pto + +runtime SIM: + test/vpto/cases/vmi/group-reduce-s64-tail-store +``` + +The companion lit case for non-unit slots=1 point-store lowering is: + +```text +test/lit/vmi/vmi_to_vpto_group_store_slots1_1pt.pto +``` + +Current checked-in coverage for S=64 row-local group-slot RHS elementwise +continuation with aligned source_group_stride and aligned output row_stride: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_group_slot_load_dual_layout.pto + +runtime SIM: + test/vpto/cases/vmi/group-reduce-s64-slot-add-store +``` + +Current checked-in coverage for 3.34 S=64 `slots = 1` group-slot f32->f16 cast: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_group_reduce_s64_truncf.pto + +runtime SIM: + test/vpto/cases/vmi/group-reduce-s64-truncf-store 
+``` + +The companion negative lit cases for dynamic or unaligned `%c2` slots=1 +group_slot_load, and non-unit `slots = 8` group_slot_load, are: + +```text +test/lit/vmi/vmi_to_vpto_group_slot_load_nonunit_slots8_invalid.pto +test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_dynamic_stride_invalid.pto +test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_unaligned_stride_invalid.pto +``` + +Current checked-in coverage for the strided block-load cases: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_group_load_s16_stride_store.pto + test/lit/vmi/vmi_layout_assignment_group_load_s16_unaligned_stride_invalid.pto + test/lit/vmi/vmi_layout_assignment_group_load_s32_stride_store.pto + test/lit/vmi/vmi_layout_assignment_group_load_s32_stride_broadcast_reduce.pto + test/lit/vmi/vmi_layout_assignment_group_load_s32_unaligned_stride_invalid.pto + +runtime SIM: + test/vpto/cases/vmi/group-load-s16-stride-store + test/vpto/cases/vmi/group-load-s32-stride-store + test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce +``` + +Current checked-in coverage for grouped mask S=16 tail: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_create_group_mask_s16.pto + test/lit/vmi/vmi_create_group_mask_invalid.pto + +runtime SIM: + test/vpto/cases/vmi/group-reduce-s16-group-mask-tail-store + test/vpto/cases/vmi/group-reduce-s16-stride-group-mask-tail-store + test/vpto/cases/vmi/group-reduce-s16-group-mask-broadcast-reduce-store +``` + +Current checked-in coverage for 3.24 mask/select/masked-store semantics: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_mask_select_store.pto + +runtime SIM: + test/vpto/cases/vmi/mask-select-store +``` + +Current checked-in coverage for 3.29 one semantic mask with f32 and f16 +consumers: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_mask_granularity_f32_f16_store.pto + +runtime SIM: + test/vpto/cases/vmi/mask-granularity-f32-f16-store +``` + +Current checked-in coverage for 3.31 f16->f32 feeding dense store and S=16 
+reduce: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_widen_f16_store_reduce.pto + +runtime SIM: + test/vpto/cases/vmi/widen-f16-to-f32-store-reduce +``` + +Current checked-in lit coverage for the first `vmi-layout-fold` +optimization is: + +```text +test/lit/vmi/vmi_layout_fold_store.pto +test/lit/vmi/vmi_layout_fold_masked_store.pto +test/lit/vmi/vmi_layout_fold_deint4.pto +``` + +Current checked-in lit coverage for the first `vmi-layout-rematerialize` +optimization is: + +```text +test/lit/vmi/vmi_layout_rematerialize_data.pto +test/lit/vmi/vmi_layout_rematerialize_mask.pto +``` + +Current checked-in lit coverage for the first +`vmi-layout-sink-materialization` optimization is: + +```text +test/lit/vmi/vmi_layout_sink_materialization_binary.pto // unary, binary, fma, cmp, and select data ops +test/lit/vmi/vmi_layout_sink_materialization_mask.pto +``` + +Current checked-in lit coverage for canonicalized VMI control-flow restoration is: + +```text +test/lit/vmi/vmi_legalize_arith_select.pto +test/lit/vmi/vmi_ptoas_cli_control_flow.pto +``` + +Current checked-in lit coverage for the first semantic local-lowering layout gate +is: + +```text +test/lit/vmi/vmi_layout_gate_group_slot_load_support_invalid.pto +test/lit/vmi/vmi_layout_gate_group_load_support_invalid.pto +test/lit/vmi/vmi_layout_gate_group_store_support_invalid.pto +test/lit/vmi/vmi_layout_gate_group_slots_unsupported_slots_invalid.pto +test/lit/vmi/vmi_layout_gate_store_support_invalid.pto +test/lit/vmi/vmi_layout_gate_helper_materialization_shape_invalid.pto +test/lit/vmi/vmi_layout_gate_group_reduce_support_invalid.pto +test/lit/vmi/vmi_layout_gate_group_reduce_slots1_support_invalid.pto +test/lit/vmi/vmi_layout_gate_group_broadcast_support_invalid.pto +test/lit/vmi/vmi_layout_gate_truncf_support_invalid.pto +test/lit/vmi/vmi_layout_gate_extf_support_invalid.pto +test/lit/vmi/vmi_layout_gate_bitcast_support_invalid.pto +test/lit/vmi/vmi_layout_gate_bitcast_group_slots.pto +``` + +Current 
checked-in direct `vmi-to-vpto` preflight coverage for bitcast local +lowering is: + +```text +test/lit/vmi/vmi_to_vpto_bitcast_footprint_invalid.pto +test/lit/vmi/vmi_to_vpto_bitcast_group_slots.pto +``` + +Current checked-in coverage for 3.32 f32 feeding f8 store and S=32 reduce: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_f32_f8_store_reduce.pto + +runtime SIM: + test/vpto/cases/vmi/f32-to-f8-store-reduce +``` + +Current checked-in coverage for multi-tile group-slot arity: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_group_reduce_s32_multitile_store.pto + +runtime SIM: + test/vpto/cases/vmi/group-reduce-s32-multitile-store +``` + +Current checked-in coverage for 3.40 scalar broadcast feeding dense and grouped +users: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_broadcast_dense_group_users.pto + +runtime SIM: + test/vpto/cases/vmi/broadcast-dense-group-users +``` + +Current checked-in coverage for 3.41 non-rematerializable `masked_load` feeding +dense and grouped users: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_masked_load_dense_group_users.pto + +runtime SIM: + test/vpto/cases/vmi/masked-load-dense-group-users +``` + +Diagnostic-only cases: + +```text +3.9 dense store of group slots +3.11.2 S=32 tail without full_footprint_readable +3.7.4 S=64 slots=1 group_store with unit output stride +3.13 packed group-slot f32 -> f16 cast +3.14 unsupported group size +3.15.3 compact source row stride 12 +3.16.1 group_slot_load slots=8 non-unit stride +3.16.2 group_slot_load slots=1 dynamic or unaligned stride +3.27 S=32 source_group_stride not divisible by 8 f32 elements +3.19.2 block_elems=8 value consumed by truncf without materialization path +3.25.2 public/external VMI boundary +3.30 unsafe masked_load tail without stable masked/gather fallback +``` + +Current checked-in diagnostic coverage for 3.9/3.13/3.14: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_dense_store_group_slots_invalid.pto + 
test/lit/vmi/vmi_layout_assignment_packed_group_slots_truncf_invalid.pto + test/lit/vmi/vmi_layout_assignment_group_reduce_s12_invalid.pto +``` + +Current checked-in diagnostic coverage for the remaining non-SIM diagnostic +entries: + +```text +lit: + test/lit/vmi/vmi_layout_gate_helper_support_invalid.pto + test/lit/vmi/vmi_layout_gate_helper_materialization_shape_invalid.pto + test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_no_full_tile_invalid.pto + test/lit/vmi/vmi_layout_assignment_group_load_s16_compact_stride12_invalid.pto + test/lit/vmi/vmi_to_vpto_group_slot_load_nonunit_slots8_invalid.pto + test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_dynamic_stride_invalid.pto + test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_unaligned_stride_invalid.pto + test/lit/vmi/vmi_layout_assignment_group_load_block8_truncf_invalid.pto + test/lit/vmi/vmi_to_vpto_group_store_slots1_1pt.pto + test/lit/vmi/vmi_layout_assignment_group_load_s16_unaligned_stride_invalid.pto + test/lit/vmi/vmi_layout_assignment_group_load_s32_unaligned_stride_invalid.pto + test/lit/vmi/vmi_ptoas_public_abi_invalid.pto + test/lit/vmi/vmi_ptoas_public_result_abi_invalid.pto + test/lit/vmi/vmi_layout_assignment_external_call_invalid.pto + test/lit/vmi/vmi_layout_assignment_external_decl_invalid.pto + test/lit/vmi/vmi_to_vpto_masked_load_nonfull_invalid.pto + test/lit/vmi/vmi_to_vpto_stable_gather_masked_load_todo_invalid.pto +``` + +Capability boundaries and runtime evidence notes: + +```text +private physical function ABI: + 3.25.1 and 3.43 runtime coverage is closed for private/internal single-block + helpers by inlining private physical VMI helpers after vmi-to-vpto and before + VPTO vecscope/backend emission. Public/external VMI boundaries are still + rejected until a stable VMI ABI is defined. 
+ +memory-proof runtime coverage: + 3.21 S=32 rounded tail-mask coverage is provided by a runtime case that loads + a full 256xf32 UB pointer vector and uses a 192-lane mask to define the active + logical rows. No surrounding MTE, caller/body context, or producer/user scan is + inspected to justify partial pointer reads. +``` + +## 12. Implementation Slices + +### Slice 1: IR Skeleton And Verifiers + +```text +layout attrs +vmi.vreg/vmi.mask types +surface op definitions +surface/layout validators +``` + +### Slice 2: Straight-Line Dense Assignment/Lowering + +```text +3.1 f16->f32->store +3.2 f32->f16->store +3.3 f8->f32->compute->f8 +``` + +### Slice 3: Group Slots And Reductions + +```text +3.4 S=8 +3.5 S=16 parity/block8 +3.6 S=32 +3.7 S=64 +group_slot_load +group_broadcast +group_store +``` + +### Slice 4: Layout Conflicts And Materialization + +```text +3.8 cast commute through group_broadcast +3.18 dense/group-reduce multi-consumer +3.19 block_elems layout selection +3.23 group_broadcast multi-consumer +3.32 f32 feeding f8 store and S=32 reduce +3.33 S=16/S=32 reduce multi-consumer rematerialization +3.34 slots=1 group-slot f32->f16 cast +3.35 group_slots fanout to group_store and group_broadcast +3.36 group_slot_load expressed as explicit slots=8/slots=1 producers +3.38 multi-tile group_slots arity +3.40 scalar broadcast materialized for dense/grouped users +3.41 non-rematerializable value with ensure_layout +``` + +### Slice 5: Masks, Tail, And Memory Legality + +```text +create_mask +create_group_mask +masked_store +safe full-read proof +compact/gather diagnostics +mask granularity per use +group_load stride greater than group size +group_slot_load slots=1 aligned non-unit stride plus dynamic/unaligned diagnostic +group_store slots=1 non-unit output stride +strided group_load feeding broadcast and a second reduce +masked_load grouped tail feeding S=32 reduce +``` + +### Slice 6: Control Flow And Functions + +```text +scf.if +scf.for +group_slots across 
control flow +group_slots loop-carried accumulator +internal function specialization +internal function argument boundary materialization +public ABI diagnostic +``` + +### Slice 7: Histogram + +```text +3.56 full 256-bin dhist logical op +3.57 chist semantic capability diagnostic +``` + +## 13. Completion Checklist + +Current evidence for the case-catalog objective: + +```text +1. every pre-histogram catalog endpoint is mapped in section 6.6 to an + assignment owner, assignment artifact, and vmi-to-vpto contract +2. every pre-histogram SIM-backed positive endpoint is listed in section 11.3 + and has a checked-in runtime case directory +3. every existing runtime case directory contains kernel.pto, launch.cpp, + main.cpp, golden.py, and compare.py +4. the latest historical broad VMI runtime sweep passed: PASS=47 FAIL=0 +5. the latest historical full VMI lit sweep passed: 350/350 +6. every pre-histogram unsupported endpoint listed in section 11.3 has a + diagnostic lit test +7. vmi-to-vpto decisions are represented by current-op attrs/operands, + assigned layouts, helper ops, rematerialization, or diagnostics +8. no separate lowering-plan string attr is emitted or consumed +9. release docs remain untouched; this is still a design/implementation plan + under docs/designs +10. 
new histogram endpoints 3.56/3.57 are mapped in section 6.6, but their + implementation evidence is intentionally pending new lit/SIM or diagnostic + tests +``` diff --git a/docs/designs/vmi-layout-assignment-lowering-design.md b/docs/designs/vmi-layout-assignment-lowering-design.md new file mode 100644 index 0000000000..98988eb667 --- /dev/null +++ b/docs/designs/vmi-layout-assignment-lowering-design.md @@ -0,0 +1,1131 @@ +# VMI Layout Assignment And Lowering Design + +本文是新的 VMI layout assignment / lowering 设计文档。它只以 +`docs/designs/vmi-layout-lowering-cases.md` 为 source of truth,不继承早期 +VMI 草稿的 layout 设计,以避免旧上下文污染。 + +目标: + +```text +VMI surface IR + -> pto-validate-vmi-ir + -> vmi-layout-assignment // hard legalization baseline + -> canonicalize/cse + -> vmi-layout-rematerialize // optional optimization + -> canonicalize/cse + -> vmi-layout-fold // optional optimization over remat-exposed helpers + -> canonicalize/cse + -> vmi-layout-sink-materialization // optional optimization + -> canonicalize/cse + -> vmi-legalize-arith-select + -> pto-validate-vmi-layout-ir + -> layout-assigned and optimized VMI IR + -> vmi-to-vpto + -> VPTO IR +``` + +核心验收约束: + +```text +vmi-to-vpto 不允许通过上下文猜 lowering。 + +任何需要 producer/consumer/control-flow/memory/mask 上下文才能决定的事, +必须在 vmi-layout-assignment 或后续 VMI layout optimization 阶段变成显式 IR: + +1. vmi.vreg/vmi.mask 的 layout +2. current-op attrs/operands that make the local lowering deterministic +3. use-site ensure_layout / ensure_mask_layout / ensure_mask_granularity +4. rematerialized or cloned producer +5. target capability diagnostic +``` + +## 0. Hard Legalization And Optimization Boundary + +Layout assignment is a stage, not necessarily one monolithic pass. 
The design +separates correctness from optimization: + +```text +hard legalization: + produces legal layout-assigned VMI IR for all supported semantics + inserts conservative ensure_* helpers at incompatible uses + may choose a simple canonical layout even when a fused consumer lowering exists + must diagnose unsupported semantics before vmi-to-vpto has to guess + +layout optimization: + rewrites already legal VMI IR into cheaper but equivalent VMI IR + may fold ensure_layout into a layout-aware consumer + may clone/rematerialize cheap producers for different use-site layouts + may sink or hoist layout materialization through pure elementwise chains + may specialize private VMI function signatures +``` + +The driver currently runs MLIR's normal `canonicalize` and `cse` between these +VMI-specific passes. They are allowed to clean up trivially unused helpers, +merge identical rematerialized producers, and expose simpler use-def shapes. +They are not a source of hidden lowering information; after every optimization, +the IR must still carry enough local information for `vmi-to-vpto`. + +The baseline hard pass may emit: + +```text +%x_c = pto.vmi.ensure_layout %x : deinterleaved=2 -> contiguous +pto.vmi.store %x_c +``` + +A later optimization may replace that use with: + +```text +pto.vmi.store %x : deinterleaved=2 +``` + +only if the store op itself has a local deterministic lowering for preserving the +same row-major memory effect, such as a layout-aware `vstsx2 INTLV` lowering. +Both forms are semantically complete. The second form is an optimization, not +a hard requirement for correctness. + +## 1. 
Source Case Coverage + +设计必须覆盖 case catalog 中的端到端场景: + +```text +dense cast: + f16 -> f32 -> store + f32 -> f16 -> store + f8 -> f32 -> compute -> f8 + f8 -> f32 accumulator -> group_reduce_addf + i8/i16 -> signed/unsigned integer cast to i32 accumulator + -> group_reduce_addi + f8/i8 appear as cast source or cast destination at compute boundaries + integer narrowing back to i8 is an explicit cast, not implicit arithmetic + f16 -> f32 shared by dense store and S=16 reduce + f32 shared by f8 store and S=32 reduce + +group reduce: + 32-bit accumulator: S=8, S=16, S=32, S=64 + 16-bit accumulator: S=16, S=32, S=64, S=128 + 8-bit storage reduces only through an explicit accumulator cast + reduce -> group_store + reduce -> group_slot_load/elemwise -> group_store + reduce -> group_broadcast -> elemwise -> reduce -> store + one group_slots result fanning out to group_store and group_broadcast + grouped tail -> broadcast -> reduce -> store + +layout conflict: + one value with dense and group-reduce consumers + one value with S=16 and S=32 group-reduce consumers + one scalar broadcast materialized for dense and grouped users, with optional rematerialization + one non-rematerializable value materialized with use-site ensure_layout + one scalar group-slot source expressed as explicit slots=8 and slots=1 producers + S=16 block_elems=1/8 layout selection + dense consumer of group_slots diagnostic + packed group-slot width-changing cast diagnostic + S=64 slots=1 group-slot width-changing cast + +control flow: + scf.if before group_reduce + group_slots across scf.if + scf.for loop-carried layout fixed point + group_slots as scf.for loop-carried accumulator + internal function boundary specialization + internal function argument boundary materialization + public/external VMI ABI diagnostic + +mask and tail: + prefix mask + group-periodic mask + dynamic group-periodic mask + masked_load tail with explicit passthrough instead of padding + masked_load grouped tail feeding group_reduce 
+ masked select/store + one semantic mask used by multiple predicate granularities + S=32 tail with and without full_footprint_readable + compact S=12 diagnostic + +strided memory: + group_load source stride greater than logical group size + strided group_load feeding broadcast and a second group_reduce + group_slot_load slots=1 with non-unit source stride + group_store slots=1 with non-unit output stride + +value-indexed accumulation: + full 256-bin distribution histogram over Nxui8 source lanes + VPTO low/high bin range split hidden behind one logical 256xui16 VMI result + cumulative histogram is a semantic boundary until CHISTv2 range semantics are verified +``` + +### 1.1 Case-Set Sufficiency + +The current case set is sufficient to define the first implementation of layout +assignment and lowering. It covers every decision axis that has changed the +design so far: + +```text +physical dense layout: + contiguous, deinterleaved=2/4, block_elems=1/8 + +group-slot result layout: + group_slots(G, slots=8) for packed VCG results + group_slots(G, slots=1) for row-local S=64 results + +producer-driven layout: + load, group_load, group_slot_load, broadcast, create_mask, + create_group_mask + +consumer-driven pressure: + dense store, group_reduce, group_store, group_broadcast, truncf, + elementwise/select, masked_load/masked_store + +conflict resolution: + explicit ensure_layout, explicit ensure_mask_layout, explicit diagnostics + optimization passes may later replace the helpers with rematerialization or + layout-aware consumers + +control-flow propagation: + scf.if, scf.for iter_args/results, internal/private function boundaries, + public ABI rejection + +memory legality: + full_footprint_readable proof, grouped masks, predicate granularity, aligned + strided group memory, stable gather diagnostic + +value-indexed accumulation: + histogram source/result shape, b8 source mask, and fixed low/high VPTO bin + split for a logical 256-bin result +``` + +No extra layout kind 
should be added unless a new case proves that the existing +layouts and explicit helper contracts cannot express the logical behavior. The remaining open +items are not missing layout semantics: + +```text +dynamic active_elems_per_group runtime source: + create_group_mask layout lowering is defined and has both lit and SIM + coverage. The supported runtime source is a kernel scalar argument cast to + index inside vecscope; vmi-to-vpto does not recover this value from GM/UB + scalar loads or surrounding context. + +private vector function runtime: + private/internal single-block helpers are runtime-covered by ptoas inlining + private physical VMI helpers after vmi-to-vpto and before VPTO vecscope/backend + emission. This is a post-physicalization backend hygiene step; vmi-to-vpto + still lowers only from assigned layouts and helper ops. + +diagnostic-only cases: + compact S=12 gather fallback, packed slots=8 width-changing cast, public VMI + ABI, unsafe masked_load tail, and unaligned/dynamic group memory remain + explicit capability boundaries. +``` + +## 2. Layout Domain + +Layout is a property of a layout-assigned VMI value, not a property inferred by +the final lowering pattern. + +Type policy: + +```text +storage boundary: + f8-like/i8/f16/i16/f32/i32 may appear in load/store values when the target + memory instruction supports the physical width. + +cast boundary: + f8-like participates through extf/truncf. + i8 participates through extsi/extui/trunci. Signedness is carried by the + cast op semantics, not by a separate layout. + On the current VPTO target, 32-bit to 8-bit integer narrowing is only a + baseline lowering for unsigned i8 results because the available VCVTII forms + are s32/u32 -> u8. + +compute boundary: + baseline floating compute uses f16/f32. + baseline integer grouped reduction compute uses i32 accumulators. i8/i16 + storage must be widened first because integer reduction instructions widen + narrow inputs. 
+ f8/i8 are not baseline accumulator/compute element types. + +value-indexed accumulation boundary: + pto.vmi.dhist consumes ui8 source lanes and produces a logical 256xui16 + accumulator/result. It is not a group_reduce family member because result + bins are selected by source values rather than by source lane/group position. + pto.vmi.chist uses the same surface shape only after the target CHISTv2 + range semantics are verified. +``` + +### 2.1 Dense Layouts + +```text +#pto.vmi.layout<contiguous> +#pto.vmi.layout<deinterleaved = F, block_elems = B> +``` + +`block_elems` defaults to `1`: + +```text +#pto.vmi.layout<deinterleaved = F> + == #pto.vmi.layout<deinterleaved = F, block_elems = 1> +``` + +Dense layouts preserve one semantic value for every logical lane. + +Lane map for `deinterleaved = F, block_elems = B`: + +```text +logical lane i +block q = i / B +in-block lane r = i % B +part p = q % F +part block t = q / F + +physical part p, physical lane t * B + r +``` + +Important consequence: + +```text +deinterleaved=2, block_elems=1 +deinterleaved=2, block_elems=8 +``` + +are different layouts. They cannot be treated as compatible merely because `F` is the +same. + +See `vmi-lane-stride-generalization-design.md` for the planned extension that +allows dense layouts to carry `lane_stride` as an additional lane-map field. +That extension keeps dense lane stride separate from the existing group-slot +carrier lowering use case. Non-zero lane phase is left as a future extension +and is not required for the first dense-stride optimization. + +### 2.2 Group-Slot Layouts + +```text +#pto.vmi.layout<group_slots = G, slots = K> +#pto.vmi.layout<group_slots = G, slots = K, lane_stride = LS> +``` + +Only `G` lanes have semantic values: + +```text +slot_block(g) = g / K +slot_lane(g) = (g % K) * LS +``` + +All non-slot lanes are undefined and may only be read by group-aware operations. +Ordinary dense `add/mul/store/truncf` cannot consume `group_slots`. + +`LS` defaults to 1 and is measured in logical element-sized physical slots. It +is not a new group semantic; it records regular physical spacing for each stored +group slot.
For example, `ui8 lane_stride=4` maps slot values to byte lanes +0, 4, 8, ... and lets `group_store` lower through a b32 carrier `PK4_B32` +store. + +`K` is selected by the assigned producer/result contract: + +```text +S=8/16/32 packed VCG result -> slots=8 +S=64 row-local result -> slots=1 +``` + +Histogram does not add a layout family. A full logical histogram result uses: + +```text +!pto.vmi.vreg<256xui16, #pto.vmi.layout<contiguous>> +``` + +and physicalizes to two ordered VPTO parts: + +```text +part0 = logical bins 0..127 +part1 = logical bins 128..255 +``` + +The VPTO `#bin` selector is therefore an op-local lowering detail, not a VMI +layout attribute and not a user-visible operand on `pto.vmi.dhist`. + +## 3. Lowering Context Must Become Explicit IR Output + +`vmi-to-vpto` may inspect only: + +```text +1. op name and explicit op attrs +2. converted operand/result types with layout +3. helper/materialization ops written by layout assignment +4. inserted helper ops +5. target capability registry +``` + +It must not: + +```text +1. walk to defining op to infer layout +2. inspect all users to choose a lowering path +3. infer memory legality from a later mask +4. decide S=16 block_elems=1 vs block_elems=8 locally +5. decide whether group_broadcast should be materialized for one or many users +6. specialize function signatures during vmi-to-vpto +``` + +Any of those decisions belongs to the layout stage before `vmi-to-vpto`. + +## 4.
Explicit Assignment Products + +After `vmi-layout-assignment`, every VMI data and mask value must be in one of +these states: + +```text +layout-assigned type: + !pto.vmi.vreg<NxT, #pto.vmi.layout<...>> + !pto.vmi.mask<..., #pto.vmi.layout<...>> + +or explicit helper: + pto.vmi.ensure_layout + pto.vmi.ensure_mask_layout + pto.vmi.ensure_mask_granularity +``` + +`vmi-to-vpto` is allowed to choose a deterministic lowering from local +information on the current op: + +```text +current op name +current op attrs +operand/result types and layouts +current op operand values such as stride and offset +target capability and pass options +``` + +This is not context inference. What remains forbidden is walking to producers, +users, sibling users, branch/loop bodies, callees/callers, or nearby memory/MTE +ops to recover a lowering decision or a memory-safety proof. + +If a decision cannot be made from that local information, layout assignment +must rewrite the IR until the decision is explicit in attrs, operand/result +layouts, helper ops, or diagnostics. Later optimization passes may replace +helpers with cloned/rematerialized producers, but `vmi-to-vpto` must not +consume a separate string lowering-plan attr. + +### 4.1 Local Lowering Contract + +The lowering path is derived from op + assigned operand/result layouts + +explicit attrs/operands. If two legal lowerings cannot be distinguished from +that local information, the IR is missing a semantic carrier and must be +extended before that lowering is implemented. + +The shared abstraction is a layout fact classifier, not a central lowering-plan +registry.
A classifier may answer questions such as: + +```text +cast layout fact: + f16/i16 -> f32/i32 requires contiguous source and deinterleaved=2 result + f8/i8 -> f32/i32 requires contiguous source and deinterleaved=4 result + f32/i32 -> f16/i16 requires deinterleaved=2 source and contiguous result + f32/i32 -> f8/i8 requires deinterleaved=4 source and contiguous result + +group_reduce layout fact: + define E = sizeof(accumulator T), VLaneElems = 32B / E, + L = 256B / E, S = N / G. + S == VLaneElems requires contiguous source/mask and + group_slots(G, slots=8) result. + S == 2 * VLaneElems requires deinterleaved=2 source/mask and + group_slots(G, slots=8) result. + S == 4 * VLaneElems requires deinterleaved=4 source/mask and + group_slots(G, slots=8) result. + S >= L && S % L == 0 requires contiguous source/mask and + group_slots(G, slots=1) result. + +memory safety fact: + full physical chunks are legal for pointer sources. Partial logical loads + need a shaped safe-tail memref proof or an explicit fallback option. +``` + +These helpers return semantic layout requirements and capability diagnostics. +They do not return VPTO instruction names, cost decisions, clone decisions, or +multi-user plans. + +The useful shared fact is the part that would otherwise be recomputed by two or +more stages and must stay identical for correctness: + +```text +cast width ratio: + assignment uses it to request source/result layouts and insert ensure_layout. + validation uses it to reject unsupported assigned cast shapes. + lowering uses it to check the local op shape before emitting VPTO. + +group_reduce lane partition: + assignment uses N/G and accumulator element width to request source/mask and + result layouts. + validation uses the same math to reject legacy or incomplete group_slots. + lowering uses the already assigned layouts to select the local VPTO sequence. + +layout materialization shape: + assignment may insert ensure_layout without proving every physical sequence. 
+ validation and lowering use one support query to decide whether that explicit + helper is materializable on the target. + optimization uses the same query only when it wants to fold/sink/remove an + explicit helper. +``` + +The helper is not useful when it only renames one local pattern. A single +`if (is this op with this attr)` that is not shared by assignment, validation, +lowering, or an optimization should stay local to that pass. The support layer +exists to prevent divergent layout math, not to move every branch into a table. + +Forbidden non-local lowering recovery: + +```text +No pattern may recover a lowering decision or memory proof by: + - walking from group_reduce to the load/group_load producer + - walking from store/broadcast/truncf to the group_reduce producer + - scanning sibling users of a group_slots value + - inspecting branch bodies or loop bodies from a control-flow boundary + - inspecting private callee bodies while lowering a call +``` + +If the current op lacks enough local information, `vmi-to-vpto` emits +`VMI-LAYOUT-CONTRACT` at the current op and prints the op name, logical type, +assigned layouts, and the missing decision class. + +## 5. Layout Requests, Helpers, And Optimization + +The compiler must not carry a target-aware lowering-plan registry as the shared +contract between assignment, optimization, validation, and lowering. The +shared contract is: + +```text +1. assigned layouts on VMI types +2. explicit use-site helpers: ensure_layout, ensure_mask_layout, + ensure_mask_granularity +3. explicit op attrs/operands that are part of the semantic op +4. small layout fact classifiers shared only where they remove duplicated + layout math +5. 
target capability diagnostics +``` + +This split makes optimization simpler only when optimization is phrased as +rewriting explicit helper IR: + +```text +baseline: + %x_d2 = pto.vmi.extf %x_f16 + %a = pto.vmi.addf %x_d2, %k_d2 + %a_c = pto.vmi.ensure_layout %a : deinterleaved=2 -> contiguous + pto.vmi.store %a_c, %out0 + %x_c = pto.vmi.ensure_layout %x_d2 : deinterleaved=2 -> contiguous + pto.vmi.store %x_c, %out1 + +fold-consumers: + checks only each local ensure_layout + store use. + If VMILayoutSupport says the store can preserve row-major memory from the + source layout, rewrite that use to store the source directly. + It does not inspect sibling users of %x_d2 and does not recompute the layout + assignment. + +rematerialize: + checks only cheap producer + ensure_layout. + If the producer can directly create the requested layout, clone/rematerialize + that producer for the use. + Memory producers such as group_slot_load are excluded until a separate proof + says cloning is semantically and economically valid. + +sink-materialization: + checks only explicit ensure_* operands of a layout-transparent op. + If every operand helper is compatible, rebuild the op in the source layout and + leave one ensure_* on the result. +``` + +If an optimization needs a global cost decision, it should produce a new +explicit IR shape and then rely on canonicalize/CSE. It must not communicate a +private decision to `vmi-to-vpto`. 
+ +### 5.1 Baseline Dense Layout Requests + +```text +f16 -> f32: + source contiguous f16 + result deinterleaved=2, block_elems=1 + +f8 -> f32: + source contiguous f8 + result deinterleaved=4, block_elems=1 + +f32 -> f16: + source deinterleaved=2, block_elems=1 + result contiguous f16 + +f32 -> f8: + source deinterleaved=4, block_elems=1 + result contiguous f8 + +elementwise dense: + all dense operands/results share the same layout + +dense store: + requests contiguous source + if the stored value is assigned deinterleaved, baseline assignment inserts + ensure_layout at the store use + +two-way interleaved memory ops: + `pto.vmi.deinterleave_load` produces two dense logical streams and requests + contiguous layouts for both results + `pto.vmi.interleave_store` consumes two dense logical streams and requests + contiguous layouts for both inputs + the deinterleave/interleave memory pattern is op semantics, not a VMI layout +``` + +### 5.2 Baseline Group Layout Requests + +```text +group_reduce_add{f|i}: + uses the group_reduce layout fact in section 4.1. + The source and mask operands request the computed dense layout. + The result is assigned group_slots(G, slots=8) or group_slots(G, slots=1). + Floating-point `group_reduce_addf` carries `reassoc`; integer + `group_reduce_addi` does not. 
+ +group_slot_load: + result group_slots(G, slots=8) for packed slots + result group_slots(G, slots=1) for row-local slots + +group_broadcast: + source requests group_slots(G,K) + result requests one dense layout + incompatible dense consumers are represented by ensure_layout after the + broadcast result; a later optimization may clone/rematerialize the broadcast + +group_store: + source requests group_slots(G,K) + explicit output stride attrs/operands decide store legality + +group_slot_cast f32 -> f16: + slots=1 row-local source/result is legal + slots=8 packed source is illegal unless a future explicit helper or semantic + op defines the packed slot-preserving transform +``` + +### 5.3 Tail And Memory Safety + +Mask semantics and memory legality are separate: + +```text +mask: + decides which logical lanes participate in compute/store semantics + +full_footprint_readable: + decides whether a rounded-up physical load is allowed to read inactive lanes +``` + +The full-tile-readable proof must be explicit. It may be carried by a +statically shaped memref source. Pointer-source runtime kernels should load a +rounded physical vector and use a mask to express logical active lanes. +`vmi-to-vpto` consumes only the op/type-local proof carrier; it does not inspect +surrounding MTE copies, producer bodies, callers, or later consumers to decide +whether inactive physical lanes are safe to read. 
+ +Example: + +```text +S=32 tail num_groups=6: + without full_footprint_readable: + fast DINTLV_B32 full-tile load is illegal + + with full_footprint_readable: + full 8-row physical tile may be loaded + compute mask is PAT_VL48 per physical part + group store mask is PAT_VL6 + +S=16 grouped tail active_elems_per_group=12: + low 8-lane row half uses PAT_ALL + high 8-lane row half uses lane_mod_8 < 4 + the same split applies before and after group_broadcast + +one mask used by f32 and f16 consumers: + f32 use materializes a b32 predicate + f16 use materializes a b16 predicate + vmi-to-vpto consumes the assigned per-use mask materialization +``` + +### 5.4 Case-Driven Request Matrix + +The first implementation should build requests from the following finite table. +This table is deliberately case-derived; adding a new request kind requires a +new catalog case or a proof that it is equivalent to one listed here. + +```text +dense store: + requests dense contiguous source + if source is deinterleaved, baseline assignment inserts ensure_layout at the + store use. A later optimization may fold that helper into a layout-aware + store lowering such as vstsx2. 
+ +truncf f32 -> f16: + requests source deinterleaved=2, block_elems=1 + requests result contiguous f16 + +truncf f32 -> f8: + requests source deinterleaved=4, block_elems=1 + requests result contiguous f8 + +group_reduce_add{f|i}: + computes E = sizeof(accumulator type), VLaneElems = 32B / E, + L = 256B / E, and S = logical_lanes / num_groups + S=VLaneElems requests source contiguous and result group_slots(G, slots=8) + S=2*VLaneElems requests source deinterleaved=2 and result + group_slots(G, slots=8) + S=4*VLaneElems requests source deinterleaved=4 and result + group_slots(G, slots=8) + S>=L && S%L==0 requests source contiguous and result + group_slots(G, slots=1) + 8-bit storage reaches this request only after an explicit cast to the + accumulator type + +group_broadcast: + requests source group_slots(num_groups, slots=K) + produces one assigned dense result layout + incompatible dense consumers are represented by ensure_layout uses; a later + optimization may clone/rematerialize the group_broadcast per consumer + +group_store: + requests source group_slots(num_groups, slots=K) + explicit output stride attrs/operands decide store legality + +dense elementwise add/mul/fma/min/max/select: + requests all dense data operands and results use one dense layout + mask operands request the same data layout and the consumer element + granularity + +group-slot elementwise add/mul/select: + requests all group-slot operands and results use the same + group_slots(num_groups, slots=K) + rejects mixing dense and group_slots without explicit group_broadcast or + group_store + +group_slot_load: + requests result group_slots(num_groups, slots=8) for packed unit-stride slots + requests result group_slots(num_groups, slots=1) for row-local aligned slots + +group_load: + requests result deinterleaved=2/4, block_elems=8 for S=16/S=32 block + fragments, or contiguous for row-local full chunks + +masked_load: + requests result layout from its consumers + requests mask layout matching 
the result + requires explicit passthrough; padding is not synthesized + +masked_store: + requests dense source layout required by the store op + requests mask layout matching the source layout and store element granularity + does not choose memory safety for an earlier load + +create_mask/create_group_mask: + produces one assigned mask layout and granularity + incompatible mask consumers are represented by ensure_mask_layout or + ensure_mask_granularity; optimization may clone/rematerialize the mask op + +dhist: + requests acc/result contiguous !pto.vmi.vreg<256xui16> + requests source contiguous !pto.vmi.vreg<Nxui8> + requests mask contiguous with b8 granularity + lowers each 256-lane source chunk by carrying two accumulator parts: + bins 0..127 use VPTO histogram #bin=0, bins 128..255 use #bin=1 + final partial source chunks are represented by AND-ing the user mask with a + valid-lane prefix mask before the VPTO histogram op + +chist: + same layout requests as dhist + baseline lowering is disabled until target capability records whether the + high-range VPTO cumulative result is global or range-local + +scf.if/scf.for/call/return: + requests equality across carried VMI values, yielded values, call operands, + callee arguments, and function results + baseline private/internal functions materialize at boundaries; optimization + may specialize signatures + public/external VMI boundaries are diagnostics until an ABI is defined +``` + +Important negative requests: + +```text +ordinary dense add/mul/store/truncf cannot request group_slots +packed group_slots(slots=8) cannot request width-changing cast unless a packed +slot-preserving cast transform is explicitly represented +slots=1 group_store cannot request unit-stride row-major output until a pack or +unaligned-store transform is explicitly represented +``` + +### 5.5 Optimization Hooks + +Baseline assignment resolves incompatible use-site requests by keeping one +assigned layout on the value and inserting explicit 
helpers at the use sites +that need another layout. It does not clone producers, rematerialize cheap +ops, choose memory-fused layouts by cost, or specialize private function +signatures for performance. + +Those choices belong to later VMI layout optimization passes. They consume +the explicit helper IR and may rewrite it when the rewrite preserves the same +logical value and externally visible memory effect: + +```text +ensure_layout + store: + fold into a layout-aware store if the store can directly consume the source + layout and still write row-major memory + +producer + ensure_layout: + clone/rematerialize the producer for that use only when the producer is cheap + or has an explicit safe-read proof + +elementwise chain + ensure_layout: + sink or hoist materialization through pure layout-transparent ops + +group_broadcast + incompatible dense consumers: + type each group_broadcast op for its consumer layout; do not force one result + layout across independent group_broadcast users + +create_mask/create_group_mask + incompatible mask consumers: + clone/rematerialize the mask producer per layout or predicate granularity + +private function boundary: + specialize function signatures only in an optimization pass; baseline + assignment materializes at boundary uses +``` + +If no helper materialization or optimization rewrite is legal, the diagnostic +must name the value's assigned layout, the use-site requested layout, and the +op that requested it. + +## 6. Layout Assignment Algorithm + +`vmi-layout-assignment` is module-level. It must see function/call/control-flow +connections before choosing layouts. + +### 6.1 Variables + +Create a layout variable for: + +```text +1. every VMI OpResult +2. every VMI BlockArgument +3. every function argument/result that is allowed to carry VMI +4. every VMI mask value +``` + +Create a use-site request for: + +```text +1. every operand use that requires a specific layout +2. every control-flow yield/branch/call/return edge +3. 
every memory operation that requires an explicit memory legality proof +``` + +### 6.2 Constraints + +Hard constraints: + +```text +group_slots cannot feed ordinary dense consumers +direct group-slot width-changing cast requires an explicit slot-preserving transform +public/external VMI function boundary requires a stable ABI or diagnostic +S=32 fast tail load requires full_footprint_readable or gather fallback +``` + +`slots = 1` row-local cast may satisfy the slot-preserving transform requirement. +Packed `slots = 8` f32->f16 remains a diagnostic unless a separate packed cast +or unpack/materialization transform is represented explicitly. + +Equivalence constraints: + +```text +dense add/mul/select: + operands/results use same dense layout unless an explicit materialization is + inserted at a use site + +scf.if/scf.for: + region yield operands and block arguments must have the same assigned layout + as the region result/iter_arg +``` + +Canonical baseline constraints: + +```text +S=16 group_reduce: + request deinterleaved=2; baseline uses block_elems=1 unless the producer + result already carries block_elems=8 as an explicit layout + +one dense value feeding S=16 and S=32 group_reduce: + keep the value's assigned layout and insert ensure_layout at both use sites + that need deinterleaved=2 or deinterleaved=4 + +load/group_load: + use the op's assigned result layout and explicit memory-safety attrs only + +group_broadcast: + keep one assigned dense result layout and communicate other dense use layouts + through ensure_layout +``` + +### 6.3 Solving + +Recommended solving order: + +```text +1. Build function/control-flow SCCs. +2. Collect natural producer layouts and hard use-site layout requests. +3. Propagate equality constraints through dense elementwise ops and CFG edges. +4. Choose one deterministic assigned layout for each value or equivalence + class. +5. 
Insert ensure_layout / ensure_mask_layout / ensure_mask_granularity at uses + whose requested layout differs from the assigned layout. +6. Emit diagnostics for unsupported semantic constraints or missing explicit + memory-safety proofs. +7. Rewrite VMI types and insert explicit helper ops. +``` + +Tie-breaking must be deterministic and deliberately simple. Suggested priority: + +```text +1. Preserve an explicit user-provided layout attr. +2. Preserve a unique producer natural layout when present. +3. Preserve an equality-class non-contiguous layout when required by a hard op. +4. Otherwise choose contiguous. +``` + +## 7. Control Flow And Functions + +### 7.1 `scf.if` + +All branch yields for one result must agree on one assigned layout. If they do +not, assignment inserts materialization before `scf.yield` where possible. +The `scf.if` result type after assignment carries that layout, so +`vmi-to-vpto` does not need to inspect either branch body. + +### 7.2 `scf.for` + +Loop-carried VMI values are fixed-point variables: + +```text +initial iter_arg layout +body block argument layout +yield operand layout +loop result layout +``` + +must converge to one layout. If a body consumer needs another layout, it is a +use-site request inside the loop body. +The loop body block argument has no defining op. Its layout is therefore part +of the block argument type after assignment, not information reconstructed from +the initial value or previous iteration during lowering. + +### 7.3 Calls + +Internal/private VMI function boundaries must make layout choices explicit in +the assigned IR. The baseline implementation keeps function arguments in a +contiguous VMI ABI and inserts callee-entry `ensure_layout` helpers when the +callee body needs another layout. Private helpers are then physicalized by +`vmi-to-vpto` and inlined before VPTO vecscope/backend emission so physical +`!pto.vreg`/`!pto.mask` values do not become a backend function ABI. 
A later +private-function optimization may specialize signatures directly: + +```text +func @producer() -> !vmi.vreg<256xf32, deinterleaved=4> +``` + +The specialized function is then physicalized by `vmi-to-vpto` into multiple +VPTO function results. + +Public/external VMI function boundaries are rejected until a stable VMI ABI is +defined. + +## 8. vmi-to-vpto Contract + +`vmi-to-vpto` receives layout-assigned VMI. It performs no global reasoning. + +For each op, the pattern: + +```text +1. reads operand/result layouts +2. reads current op attrs and operand values +3. asks TypeConverter for ordered physical values +4. emits the locally implied VPTO lowering +5. fails if target capability or required local proof is absent +``` + +The pattern must not: + +```text +1. inspect all users to decide result layout +2. inspect defining ops to decide source layout +3. choose between S=16 block_elems=1 and block_elems=8 +4. decide whether a load is full_footprint_readable +5. decide function signature specialization +``` + +Allowed local reads are deliberately narrower: + +```text +arith.constant defining op: + allowed only to materialize an operand of the current op, such as + create_mask active_lanes or a constant memory offset + +current VMI op body/attrs: + allowed for op-local semantics, such as create_group_mask + active_elems_per_group when lowering the create_group_mask op itself + +helper materialization chain: + allowed only to strip ensure_mask_layout / ensure_mask_granularity for + static predicate analysis that does not choose a different layout or lowering + +diagnostic embellishment: + allowed only to improve an already-failed capability message, such as naming + memref.subview after identity lane-to-address planning has failed +``` + +Anything else is a layout-assignment responsibility. In particular, an +unsupported producer/consumer combination must be rejected before assignment +emits layout-assigned IR. 
Section 3.44 is the model for supported partial S=32 +grouped masks: assignment emits explicit contiguous and deinterleaved mask +values, and `vmi-to-vpto` lowers the deinterleaved mask op itself through +contiguous grouped-mask materialization followed by predicate deinterleave. It +does not walk from `group_reduce_addf` to the mask producer to choose or reject +the lowering. Dynamic `active_elems_per_group` follows the same rule: the +`create_group_mask` op lowers its own SSA scalar with vci/vshrs/vshls/vsub/vcmps +for contiguous chunks before any predicate deinterleave. + +## 9. Physical Value Ordering + +The OneToN lowering order is fixed. + +```text +contiguous: + chunk0, chunk1, ... + +deinterleaved=F: + part0_chunk0, part0_chunk1, ..., + part1_chunk0, part1_chunk1, ..., + ... + part(F-1)_chunk0, ... + +group_slots(G,K): + slot_block0, slot_block1, ... +``` + +Two physical bundle entries may alias the same VPTO SSA value when the current +op semantics prove they have the same contents, such as group_broadcast feeding both +parts of a `deinterleaved=2` broadcast result. Arity still follows the layout; +aliasing is not a different layout. + +## 10. Diagnostics + +Diagnostics are part of the design. They must name: + +```text +1. the VMI op +2. source logical type +3. assigned source layout +4. requested layout +5. missing local proof or disabled fallback +6. suggested rewrite when available +``` + +Examples: + +```text +dense store of group_slots: + use group_store, group_broadcast, or explicit group-pack + +packed group-slot f32->f16: + group_broadcast before truncf, or keep group_store as f32 + +S=32 tail without full_footprint_readable: + mark source full_footprint_readable or enable stable gather fallback + +S=32 group_load with unaligned source_group_stride: + choose a stride divisible by 8 f32 elements or enable stable gather fallback + +public VMI function boundary: + make function internal, inline before assignment, or define ABI layout +``` + +## 11. 
Implementation Migration Checks + +The design is useful only if the implementation removes duplicated decision +points instead of renaming them. The migration target is: + +```text +assignment: + computes assigned layouts, records use-site requests, inserts ensure_* helpers, + and diagnoses unsupported semantics + does not clone/rematerialize producers + does not choose memory-fused layouts by cost + does not inspect sibling users to optimize a value + +layout optimization: + consumes explicit ensure_* helpers + may fold ensure_layout into layout-aware consumers + may clone/rematerialize cheap producers + may sink/hoist materialization through pure elementwise chains + may specialize private function signatures + +vmi-to-vpto: + consumes current op attrs/operands, assigned operand/result layouts, and + explicit helper ops + performs local physical shape and target-capability checks + does not recover layout plans from producers, sibling users, CFG regions, or + callees/callers +``` + +Concrete implementation debt to remove: + +```text +1. Move assignment-side data/mask rematerialization into + vmi-layout-rematerialize. Baseline assignment should insert ensure_* for + mismatched uses. +2. Keep `VMILayoutSupport` as target capability and layout-shape queries, not + as a shared plan table. Group-reduce layout math now lives in + `getPreferredGroupReduceLayoutFact`. Dense cast layout shape now lives in + `getPreferredCastLayoutFact`. Helper materialization gates use + `canMaterializeDataLayout`, `canMaterializeMaskLayout`, and + `canMaterializeMaskGranularity`. +3. Assignment, validation, and lowering may call layout fact helpers, but must + not each independently derive VLaneElems/groupSize/factor/slots rules. +4. Keep store-fold, rematerialization, and sink/hoist as local rewrites over + explicit ensure_* IR. They must not walk sibling users to rediscover why the + helper exists. +5. 
Update pass descriptions, diagnostics, and tests so "assignment only" output + is legal with helpers, and optimized output is a separate, equivalent IR + form. +``` + +Regression tests should prove the boundary: + +```text +assignment only: + multi-consumer values keep one assigned layout and use ensure_* at mismatched + uses + +fold-consumers: + ensure_layout + store becomes a layout-aware store only when the consumer can + preserve the same row-major memory effect + +rematerialize: + cheap producer + ensure_layout becomes a cloned/rematerialized producer; with + the pass disabled, the ensure_layout form remains legal + +vmi-to-vpto: + rejects any residual need for producer/user context with VMI-LAYOUT-CONTRACT +``` + +## 12. Design Completion Criteria + +The design is complete only when: + +```text +1. every case in vmi-layout-lowering-cases.md maps to assignment requests, + explicit helpers, or a precise diagnostic +2. every VMI-to-VPTO lowering can be emitted without looking at producer/user + context +3. every unsupported case has a precise capability diagnostic +4. every control-flow/function boundary materializes, specializes in an + optimization pass, or diagnoses +5. every mask has explicit data layout and predicate granularity +6. every positive case has end-to-end lit coverage +7. every simulator-supported positive case has simulator validation +``` diff --git a/docs/designs/vmi-layout-lowering-cases.md b/docs/designs/vmi-layout-lowering-cases.md new file mode 100644 index 0000000000..9b26ecbde7 --- /dev/null +++ b/docs/designs/vmi-layout-lowering-cases.md @@ -0,0 +1,6447 @@ +# VMI Layout Lowering Cases + +本文是 VMI layout/lowering 的典型 case catalog,不是完整设计总文档。它只回答一个问题: +一个 VMI logical vector 在某个场景下选择某种 layout 后,`vmi-to-vpto` 必须生成什么 +VPTO 结果。这里不写动机式描述;每个场景都给出 layout assignment 和 lowering result。 + +## 1. 
Layout Families + +### 1.1 Dense Layout + +Dense layout 的每个 logical lane 都有语义值。 + +```text +#pto.vmi.layout<contiguous> +``` + +Physical ordering: + +```text +chunk c, lane l -> logical lane c * L + l +``` + +`L` is the number of physical lanes per 256B VPTO vector register for the element type. + +```text +#pto.vmi.layout<deinterleaved = F, block_elems = B> +``` + +`block_elems` defaults to `1`. Existing spellings are shorthands: + +```text +#pto.vmi.layout<deinterleaved = 2> + == #pto.vmi.layout<deinterleaved = 2, block_elems = 1> + +#pto.vmi.layout<deinterleaved = 4> + == #pto.vmi.layout<deinterleaved = 4, block_elems = 1> +``` + +Logical-to-physical mapping: + +```text +logical lane i +block q = i / B +in_block lane r = i % B +part p = q % F +part_block t = q / F + +physical part p, physical lane t * B + r +``` + +Required invariants: + +```text +F > 0 +B > 0 +N % (F * B) == 0 for the direct full-chunk paths in this document +``` + +### 1.2 Group-Slot Layout + +Group-slot layout is not dense. Only `G` lanes have semantic values. + +```text +#pto.vmi.layout +#pto.vmi.layout +``` + +Physical slot mapping: + +```text +N = logical lane count +S = N / G // logical lanes per source group + +slot_block(g) = g / K +slot_lane(g) = (g % K) * LS +``` + +Required invariants: + +```text +G > 0 +K > 0 +G % K == 0 +K must fit in the physical vreg element count +LS > 0 +``` + +`LS` defaults to 1 and is counted in logical element-sized physical slots. It +is used when the group result value is intentionally stored with a regular lane +gap. For example, `ui8 lane_stride=4` places group slots in byte positions 0, +4, 8, ... and can be lowered to a b32 carrier plus `PK4_B32` store. + +`K` is selected by the producer/consumer layout support rule. It is not always 8. For +`VCGADD`-packed results, `K = 8` matches the eight 32B block results written to +the low lanes of one destination vreg. For row-local reductions where each +logical group already occupies one full 256B vreg, `K = 1` keeps each group's +scalar result in lane 0 of its own physical vreg and avoids an unsupported +cross-vreg scalar pack. 
+ +Only these lanes are semantic: + +```text +physical slot block slot_block(g), lane slot_lane(g) +``` + +All other lanes are undefined for ordinary VMI consumers. They may only be read +by group-aware ops that define how to interpret group slots. + +## 2. Layout Support Selection Rules + +VMI cast ops must not hard-code one physical `vcvt` lowering as their semantic +layout rule. Layout assignment records the required value layout; target +support queries only answer whether that layout can be materialized or lowered. + +```text +dense cast: + source/result are dense layouts. + lowering may require deinterleaved(F, block_elems=1) around VCVT. + +group-slot cast: + source/result are both group_slots(G,K). + lowering preserves slot_block(g) and slot_lane(g). Width-changing casts are + legal only when slot-preserving VPTO lowering support exists, or when the cast + can be commuted through a later group-aware consumer such as group_broadcast. +``` + +Illegal consumer mix: + +```text +group_slots value -> ordinary dense store/add/mul +``` + +This must fail unless an explicit semantic op converts the group-slot value: + +```text +group_broadcast +group_store +future explicit group-pack op +``` + +Contiguous memory loads may produce a non-contiguous physical value directly +when the requested result layout is a dense deinterleaved layout. This is a +lowering choice, not a separate layout family. + +```text +pto.vmi.load -> #pto.vmi.layout<contiguous> + lower as: + vlds NORM for each physical chunk + +pto.vmi.load -> #pto.vmi.layout<deinterleaved = 2> + lower as: + vldsx2 DINTLV_B* for each pair of physical chunks + +pto.vmi.load -> #pto.vmi.layout<deinterleaved = 4> + lower as: + two vldsx2 DINTLV_B* operations for each four-chunk group + followed by two vdintlv operations to split mod4 parts + +pto.vmi.load -> #pto.vmi.layout<deinterleaved = F, block_elems = 8> + lower using the producer-specific path or fall back to explicit + materialization. Do not treat DINTLV_B* as a block-fragment layout. 
+``` + +The `deinterleaved = 4` result order remains the normal VMI physical part +order: + +```text +results = [part0 chunks..., part1 chunks..., part2 chunks..., part3 chunks...] +``` + +For one full `256xf32` tile: + +```text +%even0, %odd0 = pto.vldsx2 %base[%off0], "DINTLV_B32" +%even1, %odd1 = pto.vldsx2 %base[%off128], "DINTLV_B32" + +%part0, %part2 = pto.vdintlv %even0, %even1 +%part1, %part3 = pto.vdintlv %odd0, %odd1 + +replace pto.vmi.load with [%part0, %part1, %part2, %part3] +``` + +This optimization is legal only for full physical chunks and supported +`DINTLV_B8/B16/B32` element widths. Tail and masked loads keep their explicit +safe lowering until a masked or guarded `vldsx2` strategy is designed. + +Two-way logical interleaved memory access is represented by dedicated VMI ops, +not by exposing assigned layouts in surface IR: + +```mlir +%x, %y = pto.vmi.deinterleave_load %src[%off] + : !pto.ptr -> !pto.vmi.vreg<64xf32>, !pto.vmi.vreg<64xf32> + +pto.vmi.interleave_store %x, %y, %dst[%off] + : !pto.vmi.vreg<64xf32>, !pto.vmi.vreg<64xf32>, !pto.ptr +``` + +Each VMI value is an ordinary dense logical vector. Layout assignment requests +contiguous layouts for both streams. Lowering maps full-chunk 8/16/32-bit cases +to `vldsx2 DINTLV_B*` and `vstsx2 INTLV_B*`. + +## 3. Lowering Results + +The following examples use symbolic VPTO names. `PAT_ALL_B*` means an all-true +predicate with the element granularity required by the instruction. `PAT_VLk` +means a prefix predicate for the first `k` lanes. + +Completeness rule for this section: every numbered endpoint below must contain +VMI input, assigned layouts, VPTO lowering result, and either a memory result or +an explicit diagnostic. Non-endpoint layout notes may appear only as setup for +the immediately following complete endpoints. 
+ +```text +3.1 f16 -> f32 -> store complete +3.2 f32 -> f16 -> store complete +3.3 f8 -> f32 -> compute -> f8 complete +3.4 group_reduce S=8 -> group_store complete +3.5.1 group_reduce S=16 -> group_store complete +3.5.2 group_reduce S=16 -> broadcast -> compute -> reduce -> store + complete +3.5.3 group_reduce S=16 -> elemwise(rhs) -> group_store complete +3.6.1 group_reduce S=32 -> group_store complete +3.6.2 group_reduce S=32 -> elemwise(rhs) -> group_store complete +3.6.3 group_reduce S=32 -> broadcast -> compute -> reduce -> store + complete +3.7.1 group_reduce S=64 -> aligned group_store complete +3.7.2 group_reduce S=64 -> elemwise(rhs) -> aligned group_store + complete +3.7.3 group_reduce S=64 -> broadcast -> compute -> reduce -> store + complete +3.7.4 group_reduce S=64 -> unit-stride group_store illegal diagnostic +3.8 group_reduce -> truncf -> broadcast -> dense store complete +3.9 dense store of group slots illegal diagnostic +3.10 non-load producer feeding S=32 group_reduce complete +3.11 partial tail groups complete/diagnostic +3.12 control-flow join before group_reduce complete +3.13 packed group-slot f32 -> f16 cast illegal diagnostic +3.14 unsupported group size illegal diagnostic +3.15 compact S=12 written as logical S=16 complete/diagnostic +3.16 group_slot_load layout contract complete +3.17 group_broadcast feeding deinterleaved consumer complete +3.18 one value with dense and group-reduce consumers complete/materialization +3.19 S=16 reduce block_elems support selection complete/diagnostic +3.20 group_slots control-flow join complete +3.21 S=32 tail with full-tile-readable source complete +3.22 scf.for loop-carried layout complete +3.23 group_broadcast with multiple dense consumers complete +3.24 mask with elementwise/select/store complete +3.25 function boundary layout specialization complete +3.26 S=16 grouped tail through broadcast/reduce/store complete +3.27 S=32 group_load with stride greater than group size complete +3.28 group_slot_load 
slots=1 aligned non-unit stride complete +3.29 one semantic mask with f32 and f16 consumers complete +3.30 masked_load tail without padding complete/diagnostic +3.31 f16->f32 feeding dense store and S=16 reduce complete +3.32 f32 feeding f8 store and S=32 reduce complete +3.33 one dense value feeding S=16 and S=32 reduces complete/materialization +3.34 S=64 group-slot result f32->f16 cast complete +3.35 group_slots fanout to group_store and broadcast complete +3.36 same scalar source materialized as slots=8/slots=1 complete/materialization +3.37 S=64 group_store with non-unit output stride complete +3.38 multi-tile S=32 group_reduce complete +3.39 strided S=32 group_load through broadcast/reduce complete +3.40 scalar broadcast feeding dense and grouped users complete/materialization +3.41 non-rematerializable value with incompatible users complete/materialization +3.42 group_slots scf.for loop-carried accumulator complete +3.43 internal function argument boundary materialization complete +3.44 masked_load grouped tail feeding S=32 reduce complete +3.45 dynamic S=32 create_group_mask complete +3.46 extf value and derived elemwise value both stored complete/optimization +3.47-3.55 typed group-reduce generalization complete/diagnostic +3.56 full 256-bin distribution histogram complete +3.57 full 256-bin cumulative histogram design boundary +``` + +### 3.1 `f16 -> f32 -> store` + +VMI input: + +```text +%x16 = pto.vmi.load %base[%off] + : memref<128xf16> -> !pto.vmi.vreg<128xf16> +%x32 = pto.vmi.extf %x16 + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> +pto.vmi.store %x32, %out[%off] +``` + +Assigned layouts: + +```text +%x16 : !pto.vmi.vreg<128xf16, #pto.vmi.layout> +%x32 : !pto.vmi.vreg<128xf32, #pto.vmi.layout> +``` + +VPTO lowering result: + +```text +%x16_0 = pto.vlds %base[%off] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<128xf16> + +%x32_p0 = pto.vcvt %x16_0, PAT_ALL_B16 {part = "EVEN"} + : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +%x32_p1 = 
pto.vcvt %x16_0, PAT_ALL_B16 {part = "ODD"} + : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> + +pto.vstsx2 %x32_p0, %x32_p1, %out[%off], "INTLV_B32", PAT_ALL_B32 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.ptr, index, + !pto.mask +``` + +Alternative complete VPTO lowering result if `vstsx2 INTLV_B32` is unavailable: + +```text +%x16_0 = pto.vlds %base[%off] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<128xf16> + +%x32_p0 = pto.vcvt %x16_0, PAT_ALL_B16 {part = "EVEN"} + : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +%x32_p1 = pto.vcvt %x16_0, PAT_ALL_B16 {part = "ODD"} + : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> + +%x32_d0, %x32_d1 = pto.vintlv %x32_p0, %x32_p1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +pto.vsts %x32_d0, %out[%off], PAT_ALL_B32 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +pto.vsts %x32_d1, %out[%off_plus_64], PAT_ALL_B32 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for i = 0..127: + out[off + i] = extf(base[off + i]) +``` + +### 3.2 Dense `f32 -> f16 -> store` + +VMI input: + +```text +%x32 = pto.vmi.load %base[%off] + : memref<128xf32> -> !pto.vmi.vreg<128xf32> +%x16 = pto.vmi.truncf %x32 + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> +pto.vmi.store %x16, %out[%off] +``` + +Assigned layouts: + +```text +%x32 : !pto.vmi.vreg<128xf32, #pto.vmi.layout> +%x16 : !pto.vmi.vreg<128xf16, #pto.vmi.layout> +``` + +VPTO lowering result: + +```text +%x32_p0, %x32_p1 = pto.vldsx2 %base[%off], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%part0 = pto.vcvt %x32_p0, PAT_ALL_B32 + {part = "EVEN", rnd = "R", sat = "SAT"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> + +%part1 = pto.vcvt %x32_p1, PAT_ALL_B32 + {part = "ODD", rnd = "R", sat = "SAT"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> + +%x16_0 = pto.vor %part0, %part1, PAT_ALL_B16 + : !pto.vreg<128xf16> + +pto.vsts 
%x16_0, %out[%off], PAT_ALL_B16 {dist = "NORM_B16"} + : !pto.vreg<128xf16>, !pto.ptr, !pto.mask +``` + +Alternative complete VPTO lowering result if the source has already been loaded +as two contiguous f32 chunks and must be materialized to `deinterleaved=2` before +the conversion: + +```text +%x32_d0 = pto.vlds %base[%off] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<64xf32> +%x32_d1 = pto.vlds %base[%off_plus_64] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<64xf32> +%x32_p0, %x32_p1 = pto.vdintlv %x32_d0, %x32_d1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%part0 = pto.vcvt %x32_p0, PAT_ALL_B32 + {part = "EVEN", rnd = "R", sat = "SAT"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> + +%part1 = pto.vcvt %x32_p1, PAT_ALL_B32 + {part = "ODD", rnd = "R", sat = "SAT"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> + +%x16_0 = pto.vor %part0, %part1, PAT_ALL_B16 + : !pto.vreg<128xf16> + +pto.vsts %x16_0, %out[%off], PAT_ALL_B16 {dist = "NORM_B16"} + : !pto.vreg<128xf16>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for i = 0..127: + out[off + i] = truncf(base[off + i]) +``` + +### 3.3 Dense `f8 -> f32 -> compute -> f8` + +VMI input: + +```text +%x8 = pto.vmi.load %base[%off] +%x32 = pto.vmi.extf %x8 +%scale = pto.vmi.broadcast %scale_s : f32 -> !pto.vmi.vreg<256xf32> +%y32 = pto.vmi.mulf %x32, %scale +%y8 = pto.vmi.truncf %y32 +pto.vmi.store %y8, %out[%off] +``` + +Assigned layouts: + +```text +%x8 : !pto.vmi.vreg<256xf8, #pto.vmi.layout> +%x32 : !pto.vmi.vreg<256xf32, #pto.vmi.layout> +%scale : !pto.vmi.vreg<256xf32, #pto.vmi.layout> +%y32 : !pto.vmi.vreg<256xf32, #pto.vmi.layout> +%y8 : !pto.vmi.vreg<256xf8, #pto.vmi.layout> +``` + +VPTO lowering result: + +```text +%x8_0 = pto.vlds %base[%off] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<256xf8> + +%x32_p0 = pto.vcvt %x8_0, PAT_ALL_B8 {part = "P0"} + : !pto.vreg<256xf8>, !pto.mask -> !pto.vreg<64xf32> +%x32_p1 = pto.vcvt %x8_0, PAT_ALL_B8 {part = "P1"} + : 
!pto.vreg<256xf8>, !pto.mask -> !pto.vreg<64xf32> +%x32_p2 = pto.vcvt %x8_0, PAT_ALL_B8 {part = "P2"} + : !pto.vreg<256xf8>, !pto.mask -> !pto.vreg<64xf32> +%x32_p3 = pto.vcvt %x8_0, PAT_ALL_B8 {part = "P3"} + : !pto.vreg<256xf8>, !pto.mask -> !pto.vreg<64xf32> + +%scale_p0 = pto.vdup %scale_s, PAT_ALL_B32 + : f32, !pto.mask -> !pto.vreg<64xf32> +%scale_p1 = pto.vdup %scale_s, PAT_ALL_B32 + : f32, !pto.mask -> !pto.vreg<64xf32> +%scale_p2 = pto.vdup %scale_s, PAT_ALL_B32 + : f32, !pto.mask -> !pto.vreg<64xf32> +%scale_p3 = pto.vdup %scale_s, PAT_ALL_B32 + : f32, !pto.mask -> !pto.vreg<64xf32> + +%y32_p0 = pto.vmul %x32_p0, %scale_p0, PAT_ALL_B32 +%y32_p1 = pto.vmul %x32_p1, %scale_p1, PAT_ALL_B32 +%y32_p2 = pto.vmul %x32_p2, %scale_p2, PAT_ALL_B32 +%y32_p3 = pto.vmul %x32_p3, %scale_p3, PAT_ALL_B32 + +%y8_p0 = pto.vcvt %y32_p0, PAT_ALL_B32 + {part = "P0", rnd = "R", sat = "SAT"} -> !pto.vreg<256xf8> +%y8_p1 = pto.vcvt %y32_p1, PAT_ALL_B32 + {part = "P1", rnd = "R", sat = "SAT"} -> !pto.vreg<256xf8> +%y8_p2 = pto.vcvt %y32_p2, PAT_ALL_B32 + {part = "P2", rnd = "R", sat = "SAT"} -> !pto.vreg<256xf8> +%y8_p3 = pto.vcvt %y32_p3, PAT_ALL_B32 + {part = "P3", rnd = "R", sat = "SAT"} -> !pto.vreg<256xf8> + +%y8_01 = pto.vor %y8_p0, %y8_p1, PAT_ALL_B8 +%y8_23 = pto.vor %y8_p2, %y8_p3, PAT_ALL_B8 +%y8_0 = pto.vor %y8_01, %y8_23, PAT_ALL_B8 + +pto.vsts %y8_0, %out[%off], PAT_ALL_B8 {dist = "NORM_B8"} + : !pto.vreg<256xf8>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for i = 0..255: + out[off + i] = truncf(extf(base[off + i]) * scale_s) +``` + +### 3.4 `group_reduce` S=8 f32 + +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<64xf32> -> !pto.vmi.vreg<64xf32> +%mask = pto.vmi.create_mask %c64 : index -> !pto.vmi.mask<64xpred> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} + : !pto.vmi.vreg<64xf32>, !pto.vmi.mask<64xpred> + -> !pto.vmi.vreg<64xf32> +pto.vmi.group_store %sum, %sum_out[%group_off], %c1 {num_groups = 8} +``` + +Assigned 
layouts: + +```text +%x : !pto.vmi.vreg<64xf32, #pto.vmi.layout> +%mask : !pto.vmi.mask<64xpred, #pto.vmi.layout> +%sum : !pto.vmi.vreg<64xf32, + #pto.vmi.layout> +``` + +VPTO lowering result for one full 8-row tile: + +```text +%mask_chunk = pto.pge_b32 "PAT_ALL" + +%x_chunk = pto.vlds %base[%tile_off] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<64xf32> + +%sum_block = pto.vcgadd %x_chunk, %mask_chunk + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +%store8 = pto.pge_b32 "PAT_VL8" +pto.vsts %sum_block, %sum_out[%group_tile_off], %store8 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Lowering result for one chunk, per the `visa.txt` VCGADD contract: + +```text +%sum_block lane 0 = reduce %x lanes 0..7 +%sum_block lane 1 = reduce %x lanes 8..15 +... +%sum_block lane 7 = reduce %x lanes 56..63 +all non-slot lanes are non-semantic +``` + +Layout result: + +```text +G = N / 8 +K = 8 + +slot_block(g) = g / 8 +slot_lane(g) = g % 8 +``` + +Memory result: + +```text +for r = 0..7: + sum_out[group_tile_off + r] = reduce(row_r[0..7]) +``` + +### 3.5 `group_reduce` S=16 f32, load-fused split + +The facts used by this lowering are checked against the current repo: + +```text +pto.vldsx2 supports "BDINTLV". +pto.vstsx2 supports only "INTLV_B8" / "INTLV_B16" / "INTLV_B32". +visa.txt says VCGADD writes one 32B-block result continuously to destination +LSBs; the current repository golden tests follow lanes 0..7 for f32. 
+``` + +There are three complete consumers for this layout today: + +```text +load -> group_reduce -> group_store(sum) +load -> group_reduce -> elementwise compute on group-slot values + -> group_store +load -> group_reduce -> group_broadcast -> elementwise compute + -> group_reduce -> group_store +``` + +#### 3.5.1 Reduce And Store Group Sums + +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref -> !pto.vmi.vreg +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = N / 16} +pto.vmi.group_store %sum, %sum_out[%group_off], %c1 {num_groups = N / 16} +``` + +Assigned layouts: + +```text +%x : !pto.vmi.vreg> + +%sum : !pto.vmi.vreg> +``` + +For each 8-row tile: + +```text +row r = 16xf32 = row_r.lo8, row_r.hi8 +``` + +VPTO lowering result for one full 8-row tile: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%sum_mask = pto.pge_b32 "PAT_VL8" + +%lo, %hi = pto.vldsx2 %base[%tile_off], "BDINTLV" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%lo lanes 0..7 = row0.lo8 +%lo lanes 8..15 = row1.lo8 +... +%lo lanes 56..63 = row7.lo8 + +%hi lanes 0..7 = row0.hi8 +%hi lanes 8..15 = row1.hi8 +... +%hi lanes 56..63 = row7.hi8 + +%lo_sum = pto.vcgadd %lo, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%hi_sum = pto.vcgadd %hi, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +%sum_block = pto.vadd %lo_sum, %hi_sum, %sum_mask + : !pto.vreg<64xf32> + +%store8 = pto.pge_b32 "PAT_VL8" +pto.vsts %sum_block, %sum_out[%group_tile_off], %store8 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +`BDINTLV` here denotes the ISA `#bdintlv` block-based interleaving load mode: +it loads `2 * VL` bytes and sends even 32B blocks to the first destination +register and odd 32B blocks to the second destination register. For f32, +one 32B block is `8xf32`, matching `block_elems = 8`. 
+ +Tail tiles use the same dataflow with `%all_b32` replaced by masks derived from +the VMI mask for the low and high 8-lane halves of each row. + +Layout result: + +```text +G = N / 16 +K = 8 + +slot_block(g) = g / 8 +slot_lane(g) = g % 8 + +%sum_block lane 0 = reduce row0 lanes 0..15 +%sum_block lane 1 = reduce row1 lanes 0..15 +... +%sum_block lane 7 = reduce row7 lanes 0..15 +``` + +No VMI value exposes `%lo_sum` or `%hi_sum`. They are internal VPTO values. + +Memory result: + +```text +sum_out[group_tile_off + 0] = reduce row0 lanes 0..15 +sum_out[group_tile_off + 1] = reduce row1 lanes 0..15 +... +sum_out[group_tile_off + 7] = reduce row7 lanes 0..15 +``` + +This endpoint is fully specified: the only group-slot value is `%sum`; `group_store` +stores the low 8 slot lanes with an ordinary prefix store. + +#### 3.5.2 Reduce, Broadcast, Elementwise, Reduce, Store + +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref -> !pto.vmi.vreg +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = N / 16} +%b = pto.vmi.group_broadcast %sum {num_groups = N / 16} +%y = pto.vmi.mulf %x, %b +%ysum = pto.vmi.group_reduce_addf %y, %mask {num_groups = N / 16} +pto.vmi.group_store %ysum, %out[%group_off], %c1 {num_groups = N / 16} +``` + +Assigned layouts: + +```text +%x : !pto.vmi.vreg> +%sum : !pto.vmi.vreg> +%b : !pto.vmi.vreg> +%y : !pto.vmi.vreg> +%ysum : !pto.vmi.vreg> +``` + +VPTO lowering result for one full 8-row tile: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%sum_mask = pto.pge_b32 "PAT_VL8" + +%x_lo, %x_hi = pto.vldsx2 %base[%tile_off], "BDINTLV" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%x_lo_sum = pto.vcgadd %x_lo, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%x_hi_sum = pto.vcgadd %x_hi, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +%sum_block = pto.vadd %x_lo_sum, %x_hi_sum, %sum_mask + : !pto.vreg<64xf32> + +%lane_id = pto.vci %c0_i32 : i32 -> !pto.vreg<64xi32> +%broadcast_idx = 
pto.vshrs %lane_id, %c3_i16, %all_b32 + : !pto.vreg<64xi32>, i16, !pto.mask -> !pto.vreg<64xi32> + +// This is the materialization of pto.vmi.group_broadcast. The group sums are +// in %sum_block lanes 0..7; vselr expands each sum to the 8 lanes of the +// corresponding row half. The following vmul/vcgadd consume an ordinary dense +// physical vector. +%b_rows = pto.vselr %sum_block, %broadcast_idx + : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> + +%y_lo = pto.vmul %x_lo, %b_rows, %all_b32 + : !pto.vreg<64xf32> +%y_hi = pto.vmul %x_hi, %b_rows, %all_b32 + : !pto.vreg<64xf32> + +%y_lo_sum = pto.vcgadd %y_lo, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%y_hi_sum = pto.vcgadd %y_hi, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Final per-row reduction and store. +%ysum_block = pto.vadd %y_lo_sum, %y_hi_sum, %sum_mask + : !pto.vreg<64xf32> + +%store8 = pto.pge_b32 "PAT_VL8" +pto.vsts %ysum_block, %out[%group_tile_off], %store8 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +This trace processes 8 logical rows at once. `num_groups = N / 16` means each +logical group is one `16xf32` row, and one full f32 VPTO tile covers 8 such +groups: + +```text +64 f32 lanes per physical part = 8 rows * 8 f32 lanes per half-row +``` + +Tail tiles use the same dataflow with `%all_b32` replaced by masks derived from +the VMI mask for the low and high 8-lane halves of each row. + +Physical lane result for the tile: + +```text +%x_lo lanes 0..7 = row0[0..7] +%x_lo lanes 8..15 = row1[0..7] +... +%x_lo lanes 56..63 = row7[0..7] + +%x_hi lanes 0..7 = row0[8..15] +%x_hi lanes 8..15 = row1[8..15] +... +%x_hi lanes 56..63 = row7[8..15] + +%sum_block lanes 0..7 = + reduce(row0[0..15]), reduce(row1[0..15]), ..., reduce(row7[0..15]) + +%b_rows lanes 0..7 = reduce(row0[0..15]) +%b_rows lanes 8..15 = reduce(row1[0..15]) +... 
+%b_rows lanes 56..63 = reduce(row7[0..15]) + +For each row `r` in this 8-row tile: + +%y_lo lanes r*8 .. r*8+7 = + row_r[0..7] * reduce(row_r[0..15]) + +%y_hi lanes r*8 .. r*8+7 = + row_r[8..15] * reduce(row_r[0..15]) + +Concretely: +%y_lo lanes 0..7 = row0[0..7] * reduce(row0[0..15]) +%y_lo lanes 8..15 = row1[0..7] * reduce(row1[0..15]) +... +%y_lo lanes 56..63 = row7[0..7] * reduce(row7[0..15]) + +%y_hi lanes 0..7 = row0[8..15] * reduce(row0[0..15]) +%y_hi lanes 8..15 = row1[8..15] * reduce(row1[0..15]) +... +%y_hi lanes 56..63 = row7[8..15] * reduce(row7[0..15]) + +%ysum_block lanes 0..7 = + reduce(%y row0), reduce(%y row1), ..., reduce(%y row7) +``` + +Memory result: + +```text +out[group_tile_off + r] = + reduce_i((row_r[i] * reduce_j(row_r[j])) for i in 0..15) + = reduce(row_r[0..15]) * reduce(row_r[0..15]) +for r = 0..7 +``` + +If a later consumer requires row-major contiguous order, `vmi-to-vpto` must +materialize: + +```text +deinterleaved=2, block_elems=8 -> contiguous +``` + +This materialization cannot be implemented with `vstsx2 INTLV_B32`, because +that instruction interleaves individual b32 elements, not 32B row halves. Until +a concrete block-interleave register materialization or store op is selected, +row-major store of this layout must be rejected with: + +```text +VMI-LAYOUT-CONTRACT: + pto.vmi.store requires materializing + #pto.vmi.layout to contiguous, but no + VPTO block-interleave materialization/store support exists. +``` + +#### 3.5.3 Reduce Result, Elementwise, Store + +This case computes a per-row reduction, applies an elementwise operation to the +reduced values themselves, and stores one result per group. There is no +`group_broadcast` in this flow because the elementwise op is not applied to the +original `8x16xf32` matrix elements. 
+ +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<128xf32> -> !pto.vmi.vreg<128xf32> +%rhs = pto.vmi.group_slot_load %rhs_base[%rhs_off], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<128xf32> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +%outv = pto.vmi.addf %sum, %rhs +pto.vmi.group_store %outv, %out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%x for reduce: + !pto.vmi.vreg<128xf32, + #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%rhs: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%outv: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> +``` + +For this endpoint, the RHS is a packed per-group vector: + +```text +rhs_base[rhs_off + r] = rhs(row r), for r = 0..7 +``` + +Layout assignment must treat `group_slot_load` as a group-slot producer: one +f32 value per group is placed in the live slot lanes. It must not use +`group_load`, which loads `group_size` data elements per group instead of one +per-group scalar. + +The elementwise op runs only on the live group-slot lanes: + +```text +%sum lanes 0..7 = + reduce(row0[0..15]), reduce(row1[0..15]), ..., reduce(row7[0..15]) + +%rhs lanes 0..7 = + rhs(row0), rhs(row1), ..., rhs(row7) + +%outv lanes 0..7 = + %sum lanes 0..7 + %rhs lanes 0..7 + +lanes 8..63 remain dead/zero and are masked off by PAT_VL8. +``` + +VPTO lowering result for one full 8-row tile: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%sum_mask = pto.pge_b32 "PAT_VL8" +%one_block = pto.pge_b32 "PAT_VL1" + +// Reduction path: use BDINTLV to feed two VCG reductions. +%x_lo, %x_hi = pto.vldsx2 %base[%tile_off], "BDINTLV" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%x_lo_sum = pto.vcgadd %x_lo, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%x_hi_sum = pto.vcgadd %x_hi, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +%sum_block = pto.vadd %x_lo_sum, %x_hi_sum, %sum_mask + : !pto.vreg<64xf32> + +// Packed RHS group-slot load. 
%rhs_tile_base points to rhs_base[rhs_off]. +// One 32B block contains 8 f32 RHS values and materializes lanes 0..7; all +// other lanes are dead/zero. +%rhs_block = pto.vsldb %rhs_tile_base, %c0_i16, %c0_i16, %one_block + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> + +// Elementwise compute on group-slot values. Only lanes 0..7 are live. +%outv_block = pto.vadd %sum_block, %rhs_block, %sum_mask + : !pto.vreg<64xf32> + +pto.vsts %outv_block, %out[%group_tile_off], %sum_mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + s = reduce(row_r[0..15]) + out[group_tile_off + r] = s + rhs[r] +``` + +### 3.6 `group_reduce` S=32 f32, 4-way split + +This case covers one `8x32xf32` tile. Each logical row is 128B, so it must be +split into four 32B partial rows before `vcgadd` can reduce it efficiently. + +The canonical layout for the input is: + +```text +%x : !pto.vmi.vreg<256xf32, #pto.vmi.layout> +``` + +With `deinterleaved = 4`, physical part `p` contains the columns whose logical +column index `c` satisfies `c mod 4 = p`: + +```text +%x_p0 lanes r*8 .. r*8+7 = + row_r[0], row_r[4], row_r[8], ..., row_r[28] + +%x_p1 lanes r*8 .. r*8+7 = + row_r[1], row_r[5], row_r[9], ..., row_r[29] + +%x_p2 lanes r*8 .. r*8+7 = + row_r[2], row_r[6], row_r[10], ..., row_r[30] + +%x_p3 lanes r*8 .. r*8+7 = + row_r[3], row_r[7], row_r[11], ..., row_r[31] +``` + +Each physical part now has exactly 8 f32 values per row, so one `vcgadd` per +part computes one partial sum per row. The four partial sums are then added +under `PAT_VL8`. + +The full contiguous-to-4-way materialization for one tile should fuse the first +deinterleave level into the load. `vldsx2 DINTLV_B32` loads `2 * VL` bytes and +splits even/odd f32 elements into two physical vectors. Two such loads cover +the `8x32xf32` tile, and a second register `vdintlv` level splits even columns +into `mod4 = 0/2` and odd columns into `mod4 = 1/3`. 
+ +This setup documentation is repeated inside every complete 32-wide endpoint +below. + +```text +%x_even_0, %x_odd_0 = pto.vldsx2 %base[%tile_off_0], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_even_1, %x_odd_1 = pto.vldsx2 %base[%tile_off_1], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%x_p0, %x_p2 = pto.vdintlv %x_even_0, %x_even_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_p1, %x_p3 = pto.vdintlv %x_odd_0, %x_odd_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +``` + +Each endpoint below inlines this materialization before the first consumer of +`%x_p0..%x_p3`. + +#### 3.6.1 Reduce And Store Group Sums + +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<256xf32> -> !pto.vmi.vreg<256xf32> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +pto.vmi.group_store %sum, %sum_out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%x: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> +``` + +VPTO lowering result for one full 8-row tile: + +```text +%x_even_0, %x_odd_0 = pto.vldsx2 %base[%tile_off_0], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_even_1, %x_odd_1 = pto.vldsx2 %base[%tile_off_1], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%x_p0, %x_p2 = pto.vdintlv %x_even_0, %x_even_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_p1, %x_p3 = pto.vdintlv %x_odd_0, %x_odd_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%all_b32 = pto.pge_b32 "PAT_ALL" +%sum_mask = pto.pge_b32 "PAT_VL8" + +%s0 = pto.vcgadd %x_p0, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s1 = pto.vcgadd %x_p1, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s2 = pto.vcgadd %x_p2, %all_b32 + : 
!pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s3 = pto.vcgadd %x_p3, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +%s01 = pto.vadd %s0, %s1, %sum_mask : !pto.vreg<64xf32> +%s23 = pto.vadd %s2, %s3, %sum_mask : !pto.vreg<64xf32> +%sum_block = pto.vadd %s01, %s23, %sum_mask : !pto.vreg<64xf32> + +pto.vsts %sum_block, %sum_out[%group_tile_off], %sum_mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + sum_out[group_tile_off + r] = reduce(row_r[0..31]) +``` + +#### 3.6.2 Reduce Result, Elementwise, Store + +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<256xf32> -> !pto.vmi.vreg<256xf32> +%rhs = pto.vmi.group_slot_load %rhs_base[%rhs_off], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<256xf32> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +%outv = pto.vmi.addf %sum, %rhs +pto.vmi.group_store %outv, %out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%x: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +%sum, %rhs, %outv: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> +``` + +VPTO lowering result for one full 8-row tile: + +```text +%x_even_0, %x_odd_0 = pto.vldsx2 %base[%tile_off_0], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_even_1, %x_odd_1 = pto.vldsx2 %base[%tile_off_1], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%x_p0, %x_p2 = pto.vdintlv %x_even_0, %x_even_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_p1, %x_p3 = pto.vdintlv %x_odd_0, %x_odd_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%all_b32 = pto.pge_b32 "PAT_ALL" +%sum_mask = pto.pge_b32 "PAT_VL8" +%one_block = pto.pge_b32 "PAT_VL1" + +%s0 = pto.vcgadd %x_p0, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s1 = pto.vcgadd %x_p1, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s2 = 
pto.vcgadd %x_p2, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s3 = pto.vcgadd %x_p3, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +%s01 = pto.vadd %s0, %s1, %sum_mask : !pto.vreg<64xf32> +%s23 = pto.vadd %s2, %s3, %sum_mask : !pto.vreg<64xf32> +%sum_block = pto.vadd %s01, %s23, %sum_mask : !pto.vreg<64xf32> + +// Packed RHS group-slot load. %rhs_tile_base points to rhs_base[rhs_off]. +%rhs_block = pto.vsldb %rhs_tile_base, %c0_i16, %c0_i16, %one_block + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> + +%outv_block = pto.vadd %sum_block, %rhs_block, %sum_mask + : !pto.vreg<64xf32> + +pto.vsts %outv_block, %out[%group_tile_off], %sum_mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + out[group_tile_off + r] = reduce(row_r[0..31]) + rhs[r] +``` + +#### 3.6.3 Reduce, Broadcast, Elementwise, Reduce, Store + +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<256xf32> -> !pto.vmi.vreg<256xf32> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +%b = pto.vmi.group_broadcast %sum {num_groups = 8} +%y = pto.vmi.mulf %x, %b +%ysum = pto.vmi.group_reduce_addf %y, %mask {num_groups = 8} +pto.vmi.group_store %ysum, %out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%x, %b, %y: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +%sum, %ysum: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> +``` + +VPTO lowering result for one full 8-row tile: + +```text +%x_even_0, %x_odd_0 = pto.vldsx2 %base[%tile_off_0], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_even_1, %x_odd_1 = pto.vldsx2 %base[%tile_off_1], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%x_p0, %x_p2 = pto.vdintlv %x_even_0, %x_even_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_p1, %x_p3 = pto.vdintlv %x_odd_0, %x_odd_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> 
!pto.vreg<64xf32>, !pto.vreg<64xf32> + +%all_b32 = pto.pge_b32 "PAT_ALL" +%sum_mask = pto.pge_b32 "PAT_VL8" + +%s0 = pto.vcgadd %x_p0, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s1 = pto.vcgadd %x_p1, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s2 = pto.vcgadd %x_p2, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s3 = pto.vcgadd %x_p3, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +%s01 = pto.vadd %s0, %s1, %sum_mask : !pto.vreg<64xf32> +%s23 = pto.vadd %s2, %s3, %sum_mask : !pto.vreg<64xf32> +%sum_block = pto.vadd %s01, %s23, %sum_mask : !pto.vreg<64xf32> + +%lane_id = pto.vci %c0_i32 : i32 -> !pto.vreg<64xi32> +%broadcast_idx = pto.vshrs %lane_id, %c3_i16, %all_b32 + : !pto.vreg<64xi32>, i16, !pto.mask -> !pto.vreg<64xi32> + +// group_broadcast materialized for each deinterleaved=4 physical part. +%b_p0 = pto.vselr %sum_block, %broadcast_idx + : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> +%b_p1 = pto.vselr %sum_block, %broadcast_idx + : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> +%b_p2 = pto.vselr %sum_block, %broadcast_idx + : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> +%b_p3 = pto.vselr %sum_block, %broadcast_idx + : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> + +%y_p0 = pto.vmul %x_p0, %b_p0, %all_b32 : !pto.vreg<64xf32> +%y_p1 = pto.vmul %x_p1, %b_p1, %all_b32 : !pto.vreg<64xf32> +%y_p2 = pto.vmul %x_p2, %b_p2, %all_b32 : !pto.vreg<64xf32> +%y_p3 = pto.vmul %x_p3, %b_p3, %all_b32 : !pto.vreg<64xf32> + +%ys0 = pto.vcgadd %y_p0, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%ys1 = pto.vcgadd %y_p1, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%ys2 = pto.vcgadd %y_p2, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%ys3 = pto.vcgadd %y_p3, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +%ys01 = pto.vadd %ys0, %ys1, %sum_mask : !pto.vreg<64xf32> +%ys23 = 
pto.vadd %ys2, %ys3, %sum_mask : !pto.vreg<64xf32> +%ysum_block = pto.vadd %ys01, %ys23, %sum_mask : !pto.vreg<64xf32> + +pto.vsts %ysum_block, %out[%group_tile_off], %sum_mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + s = reduce(row_r[0..31]) + out[group_tile_off + r] = + reduce_i(row_r[i] * s for i = 0..31) + = s * s +``` + +### 3.7 `group_reduce` S=64 f32, row-local reduction + +This case covers one `8x64xf32` tile. Each logical row is exactly 256B, so the +input does not need a deinterleaved layout: + +```text +row r = 64xf32 = one !pto.vreg<64xf32> +``` + +The reduction is two-stage but row-local: + +```text +vcgadd(row_r) -> 8 partial sums in lanes 0..7 +vcadd(PAT_VL8) -> one row sum in lane 0 +``` + +The result layout is therefore not `slots = 8`. It is: + +```text +#pto.vmi.layout +``` + +Physical slot mapping for this tile: + +```text +slot_block(r) = r +slot_lane(r) = 0 + +%sum0 lane 0 = reduce row0 lanes 0..63 +%sum1 lane 0 = reduce row1 lanes 0..63 +... +%sum7 lane 0 = reduce row7 lanes 0..63 +``` + +Trying to canonicalize this result to `slots = 8` would require packing lane 0 +from eight different physical vregs into lanes 0..7 of one vreg. This document +does not use that packing transform. `slots = 1` is the canonical layout for +S=64 row-local group reductions. 
+ +#### 3.7.1 Reduce And Store Group Sums + +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<512xf32> -> !pto.vmi.vreg<512xf32> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +%c8 = arith.constant 8 : index +pto.vmi.group_store %sum, %sum_out[%group_off], %c8 {num_groups = 8} +``` + +Assigned layouts: + +```text +%x: + !pto.vmi.vreg<512xf32, #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<512xf32, #pto.vmi.layout> +``` + +VPTO lowering result for one full 8-row tile: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%block8 = pto.pge_b32 "PAT_VL8" +%one_b32 = pto.pge_b32 "PAT_VL1" + +%x0 = pto.vlds %base[%row_off_0] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<64xf32> +%x1 = pto.vlds %base[%row_off_1] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<64xf32> +%x2 = pto.vlds %base[%row_off_2] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<64xf32> +%x3 = pto.vlds %base[%row_off_3] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<64xf32> +%x4 = pto.vlds %base[%row_off_4] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<64xf32> +%x5 = pto.vlds %base[%row_off_5] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<64xf32> +%x6 = pto.vlds %base[%row_off_6] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<64xf32> +%x7 = pto.vlds %base[%row_off_7] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<64xf32> + +%p0 = pto.vcgadd %x0, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%p1 = pto.vcgadd %x1, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%p2 = pto.vcgadd %x2, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%p3 = pto.vcgadd %x3, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%p4 = pto.vcgadd %x4, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%p5 = pto.vcgadd %x5, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%p6 = pto.vcgadd %x6, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%p7 = pto.vcgadd %x7, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +%sum0 = pto.vcadd %p0, %block8 : 
!pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%sum1 = pto.vcadd %p1, %block8 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%sum2 = pto.vcadd %p2, %block8 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%sum3 = pto.vcadd %p3, %block8 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%sum4 = pto.vcadd %p4, %block8 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%sum5 = pto.vcadd %p5, %block8 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%sum6 = pto.vcadd %p6, %block8 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%sum7 = pto.vcadd %p7, %block8 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +pto.vsts %sum0, %sum_out[%group_tile_off_0], %one_b32 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +pto.vsts %sum1, %sum_out[%group_tile_off_1], %one_b32 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +pto.vsts %sum2, %sum_out[%group_tile_off_2], %one_b32 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +pto.vsts %sum3, %sum_out[%group_tile_off_3], %one_b32 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +pto.vsts %sum4, %sum_out[%group_tile_off_4], %one_b32 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +pto.vsts %sum5, %sum_out[%group_tile_off_5], %one_b32 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +pto.vsts %sum6, %sum_out[%group_tile_off_6], %one_b32 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +pto.vsts %sum7, %sum_out[%group_tile_off_7], %one_b32 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + sum_out[group_tile_off + r * 8] = reduce(row_r[0..63]) +``` + +#### 3.7.2 Reduce Result, Elementwise, Store + +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<512xf32> -> !pto.vmi.vreg<512xf32> +%rhs = pto.vmi.group_slot_load %rhs_base[%rhs_off], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<512xf32> +%sum = pto.vmi.group_reduce_addf %x, 
%mask {num_groups = 8} +%outv = pto.vmi.addf %sum, %rhs +%c8 = arith.constant 8 : index +pto.vmi.group_store %outv, %out[%group_off], %c8 {num_groups = 8} +``` + +Assigned layouts: + +```text +%x: + !pto.vmi.vreg<512xf32, #pto.vmi.layout> + +%sum, %rhs, %outv: + !pto.vmi.vreg<512xf32, #pto.vmi.layout> +``` + +VPTO lowering result for one full 8-row tile: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%block8 = pto.pge_b32 "PAT_VL8" +%one_b32 = pto.pge_b32 "PAT_VL1" + +%x0 = pto.vlds %base[%row_off_0] {dist = "NORM"} : !pto.ptr -> !pto.vreg<64xf32> +%x1 = pto.vlds %base[%row_off_1] {dist = "NORM"} : !pto.ptr -> !pto.vreg<64xf32> +%x2 = pto.vlds %base[%row_off_2] {dist = "NORM"} : !pto.ptr -> !pto.vreg<64xf32> +%x3 = pto.vlds %base[%row_off_3] {dist = "NORM"} : !pto.ptr -> !pto.vreg<64xf32> +%x4 = pto.vlds %base[%row_off_4] {dist = "NORM"} : !pto.ptr -> !pto.vreg<64xf32> +%x5 = pto.vlds %base[%row_off_5] {dist = "NORM"} : !pto.ptr -> !pto.vreg<64xf32> +%x6 = pto.vlds %base[%row_off_6] {dist = "NORM"} : !pto.ptr -> !pto.vreg<64xf32> +%x7 = pto.vlds %base[%row_off_7] {dist = "NORM"} : !pto.ptr -> !pto.vreg<64xf32> + +%p0 = pto.vcgadd %x0, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%p1 = pto.vcgadd %x1, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%p2 = pto.vcgadd %x2, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%p3 = pto.vcgadd %x3, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%p4 = pto.vcgadd %x4, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%p5 = pto.vcgadd %x5, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%p6 = pto.vcgadd %x6, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%p7 = pto.vcgadd %x7, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +%sum0 = pto.vcadd %p0, %block8 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%sum1 = pto.vcadd %p1, %block8 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%sum2 = pto.vcadd %p2, 
%block8 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%sum3 = pto.vcadd %p3, %block8 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%sum4 = pto.vcadd %p4, %block8 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%sum5 = pto.vcadd %p5, %block8 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%sum6 = pto.vcadd %p6, %block8 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%sum7 = pto.vcadd %p7, %block8 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +%rhs0 = pto.vsldb %rhs_ptr_0, %c0_i16, %c0_i16, %one_b32 + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> +%rhs1 = pto.vsldb %rhs_ptr_1, %c0_i16, %c0_i16, %one_b32 + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> +%rhs2 = pto.vsldb %rhs_ptr_2, %c0_i16, %c0_i16, %one_b32 + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> +%rhs3 = pto.vsldb %rhs_ptr_3, %c0_i16, %c0_i16, %one_b32 + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> +%rhs4 = pto.vsldb %rhs_ptr_4, %c0_i16, %c0_i16, %one_b32 + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> +%rhs5 = pto.vsldb %rhs_ptr_5, %c0_i16, %c0_i16, %one_b32 + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> +%rhs6 = pto.vsldb %rhs_ptr_6, %c0_i16, %c0_i16, %one_b32 + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> +%rhs7 = pto.vsldb %rhs_ptr_7, %c0_i16, %c0_i16, %one_b32 + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> + +%out0 = pto.vadd %sum0, %rhs0, %one_b32 : !pto.vreg<64xf32> +%out1 = pto.vadd %sum1, %rhs1, %one_b32 : !pto.vreg<64xf32> +%out2 = pto.vadd %sum2, %rhs2, %one_b32 : !pto.vreg<64xf32> +%out3 = pto.vadd %sum3, %rhs3, %one_b32 : !pto.vreg<64xf32> +%out4 = pto.vadd %sum4, %rhs4, %one_b32 : !pto.vreg<64xf32> +%out5 = pto.vadd %sum5, %rhs5, %one_b32 : !pto.vreg<64xf32> +%out6 = pto.vadd %sum6, %rhs6, %one_b32 : !pto.vreg<64xf32> +%out7 = pto.vadd %sum7, %rhs7, %one_b32 : !pto.vreg<64xf32> + +pto.vsts %out0, %out[%group_tile_off_0], %one_b32 {dist = "NORM_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +pto.vsts 
%out1, %out[%group_tile_off_1], %one_b32 {dist = "NORM_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +pto.vsts %out2, %out[%group_tile_off_2], %one_b32 {dist = "NORM_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +pto.vsts %out3, %out[%group_tile_off_3], %one_b32 {dist = "NORM_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +pto.vsts %out4, %out[%group_tile_off_4], %one_b32 {dist = "NORM_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +pto.vsts %out5, %out[%group_tile_off_5], %one_b32 {dist = "NORM_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +pto.vsts %out6, %out[%group_tile_off_6], %one_b32 {dist = "NORM_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +pto.vsts %out7, %out[%group_tile_off_7], %one_b32 {dist = "NORM_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + out[group_tile_off + r * 8] = reduce(row_r[0..63]) + rhs[r] +``` + +#### 3.7.3 Reduce, Broadcast, Elementwise, Reduce, Store + +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<512xf32> -> !pto.vmi.vreg<512xf32> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +%b = pto.vmi.group_broadcast %sum {num_groups = 8} +%y = pto.vmi.mulf %x, %b +%ysum = pto.vmi.group_reduce_addf %y, %mask {num_groups = 8} +%c8 = arith.constant 8 : index +pto.vmi.group_store %ysum, %out[%group_off], %c8 {num_groups = 8} +``` + +Assigned layouts: + +```text +%x, %b, %y: + !pto.vmi.vreg<512xf32, #pto.vmi.layout> + +%sum, %ysum: + !pto.vmi.vreg<512xf32, #pto.vmi.layout> +``` + +VPTO lowering result for one full 8-row tile: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%block8 = pto.pge_b32 "PAT_VL8" +%one_b32 = pto.pge_b32 "PAT_VL1" + +// The compiler emits this row-local block once for each r in 0..7. 
+%x_r = pto.vlds %base[%row_off_r] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<64xf32> + +%p_r = pto.vcgadd %x_r, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%sum_r = pto.vcadd %p_r, %block8 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// This vdup is the lowering of pto.vmi.group_broadcast for slots=1. +%b_r = pto.vdup %sum_r, %all_b32 {position = "LOWEST"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +%y_r = pto.vmul %x_r, %b_r, %all_b32 : !pto.vreg<64xf32> + +%yp_r = pto.vcgadd %y_r, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%ysum_r = pto.vcadd %yp_r, %block8 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +pto.vsts %ysum_r, %out[%group_tile_off_r], %one_b32 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +The row-local block above is not a runtime loop requirement. It is the repeated +VPTO shape for row offsets `%row_off_0` through `%row_off_7` and store offsets +`%group_tile_off_0` through `%group_tile_off_7`. + +Memory result: + +```text +for r = 0..7: + s = reduce(row_r[0..63]) + out[group_tile_off + r * 8] = + reduce_i(row_r[i] * s for i = 0..63) + = s * s +``` + +#### 3.7.4 Slots=1 Store Lowers To Packed Or Point Stores + +The row-local S=64 result uses one physical vreg per group with the semantic +value in lane 0: + +```text +%sum_r lane 0 = reduce(row_r[0..63]) +``` + +The current VPTO lowering for `slots = 1` group_store has two paths. + +For unit-stride output where all groups fit in one physical vector, the +lowering packs the lane-0 values into one dense vector and stores that vector +with a normal `vsts`. + +For non-unit row strides, each group stores its lane-0 scalar with a point +store. That emits `vsts` with `dist = "1PT_B32"` for f32 and only requires the +natural 4B alignment of the scalar element. 
+ +VMI input: + +```text +%c1 = arith.constant 1 : index +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +pto.vmi.group_store %sum, %out[%group_off], %c1 {num_groups = 8} +``` + +Current checked-in coverage for the point-store path is: + +```text +test/lit/vmi/vmi_to_vpto_group_store_slots1_1pt.pto +``` + +### 3.8 `group_reduce -> truncf -> group_broadcast -> store` + +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<128xf32> -> !pto.vmi.vreg<128xf32> +%sum32 = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +%sum16 = pto.vmi.truncf %sum32 +%b16 = pto.vmi.group_broadcast %sum16 {num_groups = 8} +pto.vmi.store %b16, %out[%off] +``` + +Assigned layouts: + +```text +%x : !pto.vmi.vreg<128xf32, + #pto.vmi.layout> +%sum32 : !pto.vmi.vreg<128xf32, + #pto.vmi.layout> +%sum16 : semantic value only; not materialized as a group-slot VPTO value +%b32_dense : !pto.vmi.vreg<128xf32, #pto.vmi.layout> +%b32_split : !pto.vmi.vreg<128xf32, #pto.vmi.layout> +%b16 : !pto.vmi.vreg<128xf16, #pto.vmi.layout> +``` + +This case is supported by commuting `truncf` after `group_broadcast`: + +```text +group_broadcast(truncf(group_reduce(x))) + == truncf(group_broadcast(group_reduce(x))) +``` + +This avoids materializing a group-slot f16 value. Current lowering makes the +layout transition explicit: `group_broadcast` first produces a dense contiguous +f32 value, then `pto.vmi.ensure_layout` materializes the deinterleaved=2 f32 +view required by dense `f32 -> f16` truncation. A future direct +`group_broadcast -> deinterleaved=2` lowering may remove that materialization, +but the `group_broadcast` result layout must make that support path explicit rather +than hiding it inside `truncf` lowering. 
+ +VPTO lowering result for one full 8-row tile: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%sum_mask = pto.pge_b32 "PAT_VL8" + +%x_lo, %x_hi = pto.vldsx2 %base[%tile_off], "BDINTLV" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%x_lo_sum = pto.vcgadd %x_lo, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%x_hi_sum = pto.vcgadd %x_hi, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +%sum32_block = pto.vadd %x_lo_sum, %x_hi_sum, %sum_mask + : !pto.vreg<64xf32> + +%lane_id = pto.vci %c0_i32 : i32 -> !pto.vreg<64xi32> +%broadcast_idx_lo = compute index vector [0 repeated 16, 1 repeated 16, + 2 repeated 16, 3 repeated 16] + : !pto.vreg<64xi32> +%broadcast_idx_hi = compute index vector [4 repeated 16, 5 repeated 16, + 6 repeated 16, 7 repeated 16] + : !pto.vreg<64xi32> + +// These vselr ops are the VPTO lowering of pto.vmi.group_broadcast for the two +// dense contiguous f32 physical chunks. +%b32_rows_lo = pto.vselr %sum32_block, %broadcast_idx_lo + : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> +%b32_rows_hi = pto.vselr %sum32_block, %broadcast_idx_hi + : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> + +// ensure_layout contiguous -> deinterleaved=2 materializes the two f32 parity +// inputs expected by f32 -> f16 truncation. 
+%b32_even_input, %b32_odd_input = pto.vdintlv %b32_rows_lo, %b32_rows_hi + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%b16_even = pto.vcvt %b32_even_input, %all_b32 {part = "EVEN", rnd = "R", sat = "SAT"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +%b16_odd = pto.vcvt %b32_odd_input, %all_b32 {part = "ODD", rnd = "R", sat = "SAT"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> + +%all_b16 = pto.pge_b16 "PAT_ALL" +%b16 = pto.vor %b16_even, %b16_odd, %all_b16 + : !pto.vreg<128xf16> + +pto.vsts %b16, %out[%off], %all_b16 {dist = "NORM_B16"} + : !pto.vreg<128xf16>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + s32 = reduce(row_r[0..15]) + s16 = truncf(s32) + out[r * 16 + 0 .. r * 16 + 15] = splat(s16) +``` + +### 3.9 Illegal Dense Consumer Of Group Slots + +VMI input: + +```text +%sum32 = pto.vmi.group_reduce_addf %x, %mask {num_groups = G} +pto.vmi.store %sum32, %out[%off] +``` + +Assigned layouts before the illegal consumer: + +```text +%sum32 : group_slots(G,K) +``` + +Required diagnostic: + +```text +VMI-LAYOUT-CONTRACT: + pto.vmi.store cannot consume #pto.vmi.layout + as a dense vector. Use pto.vmi.group_store, pto.vmi.group_broadcast, or an + explicit group-pack op. +``` + +Instead of emitting this diagnostic, the converter must not lower the store as: + +```text +dense store materializes group slots implicitly +``` + +That behavior would silently reinterpret a group-slot value as a dense +vector. + +### 3.10 Non-Load Producer Feeding S=32 `group_reduce` + +This case proves that layout assignment is consumer-driven. The producer of the +S=32 input is an elementwise op, not a load. The S=32 `group_reduce` still +requires the elementwise result to be `deinterleaved = 4`, and that requirement +must propagate backward through the elementwise op to both operands. 
+ +VMI input: + +```text +%a = pto.vmi.load %base[%off] + : memref<256xf32> -> !pto.vmi.vreg<256xf32> +%bias = pto.vmi.broadcast %bias_s + : f32 -> !pto.vmi.vreg<256xf32> +%x = pto.vmi.addf %a, %bias +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +pto.vmi.group_store %sum, %out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%a, %bias, %x: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> +``` + +VPTO lowering result for one full `8x32xf32` tile: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%sum_mask = pto.pge_b32 "PAT_VL8" + +%a_even_0, %a_odd_0 = pto.vldsx2 %base[%tile_off_0], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%a_even_1, %a_odd_1 = pto.vldsx2 %base[%tile_off_1], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%a_p0, %a_p2 = pto.vdintlv %a_even_0, %a_even_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%a_p1, %a_p3 = pto.vdintlv %a_odd_0, %a_odd_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%bias_p0 = pto.vdup %bias_s, %all_b32 : f32, !pto.mask -> !pto.vreg<64xf32> +%bias_p1 = pto.vdup %bias_s, %all_b32 : f32, !pto.mask -> !pto.vreg<64xf32> +%bias_p2 = pto.vdup %bias_s, %all_b32 : f32, !pto.mask -> !pto.vreg<64xf32> +%bias_p3 = pto.vdup %bias_s, %all_b32 : f32, !pto.mask -> !pto.vreg<64xf32> + +%x_p0 = pto.vadd %a_p0, %bias_p0, %all_b32 : !pto.vreg<64xf32> +%x_p1 = pto.vadd %a_p1, %bias_p1, %all_b32 : !pto.vreg<64xf32> +%x_p2 = pto.vadd %a_p2, %bias_p2, %all_b32 : !pto.vreg<64xf32> +%x_p3 = pto.vadd %a_p3, %bias_p3, %all_b32 : !pto.vreg<64xf32> + +%s0 = pto.vcgadd %x_p0, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s1 = pto.vcgadd %x_p1, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s2 = pto.vcgadd %x_p2, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s3 = pto.vcgadd %x_p3, %all_b32 : 
!pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +%s01 = pto.vadd %s0, %s1, %sum_mask : !pto.vreg<64xf32> +%s23 = pto.vadd %s2, %s3, %sum_mask : !pto.vreg<64xf32> +%sum_block = pto.vadd %s01, %s23, %sum_mask : !pto.vreg<64xf32> + +pto.vsts %sum_block, %out[%group_tile_off], %sum_mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + out[group_tile_off + r] = + reduce_i(base[row_r, i] + bias_s for i = 0..31) +``` + +### 3.11 Partial Tail Groups + +Tail handling must be separated by the physical input layout. Row-local S=64 +can avoid inactive rows entirely. Load-fused S=16/S=32 cannot safely do that +with the current `vldsx2` materialization unless the source is known to be +full-tile readable. + +#### 3.11.1 S=64 Active Row Tail + +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<384xf32> -> !pto.vmi.vreg<384xf32> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 6} +%c8 = arith.constant 8 : index +pto.vmi.group_store %sum, %out[%group_off], %c8 {num_groups = 6} +``` + +Assigned layouts: + +```text +%x: + !pto.vmi.vreg<384xf32, #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<384xf32, #pto.vmi.layout> +``` + +VPTO lowering result: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%block8 = pto.pge_b32 "PAT_VL8" +%one_b32 = pto.pge_b32 "PAT_VL1" + +// Emit this row-local block for r = 0..5 only. No load or store is emitted for +// rows 6 and 7. 
+%x_r = pto.vlds %base[%row_off_r] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<64xf32> +%p_r = pto.vcgadd %x_r, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%sum_r = pto.vcadd %p_r, %block8 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +pto.vsts %sum_r, %out[%group_tile_off_r], %one_b32 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..5: + out[group_tile_off + r * 8] = reduce(row_r[0..63]) +``` + +#### 3.11.2 S=32 Tail Without Full-Tile Read Contract + +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<192xf32> -> !pto.vmi.vreg<192xf32> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 6} +pto.vmi.group_store %sum, %out[%group_off], %c1 {num_groups = 6} +``` + +Assigned layout requested by the consumer: + +```text +%x: + !pto.vmi.vreg<192xf32, #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<192xf32, #pto.vmi.layout> +``` + +Required diagnostic when the source does not carry a full-tile-readable +contract: + +```text +VMI-LAYOUT-CONTRACT: + pto.vmi.group_reduce_addf with group size 32 and num_groups tail 6 requires + materializing #pto.vmi.layout. The fast lowering support + uses vldsx2 DINTLV_B32 over a full 8-row tile. This source is not marked + full-tile-readable, and the stable gather tail fallback is not implemented. +``` + +If a future option enables the stable gather tail fallback, the same VMI input +may lower by gathering only the active lanes. Until that support exists, the +converter must not silently issue the full-tile `vldsx2` loads. + +### 3.12 Control-Flow Join Before `group_reduce` + +The layout carried by a value must survive block arguments. In MLIR converter +terms, the logical VMI value lowered through control flow becomes a tuple of +physical VPTO values with one tuple type per assigned layout. 
+ +VMI input: + +```text +%x = scf.if %cond -> !pto.vmi.vreg<256xf32> { + %a = pto.vmi.load %a_base[%a_off] + : memref<256xf32> -> !pto.vmi.vreg<256xf32> + scf.yield %a : !pto.vmi.vreg<256xf32> +} else { + %b = pto.vmi.load %b_base[%b_off] + : memref<256xf32> -> !pto.vmi.vreg<256xf32> + scf.yield %b : !pto.vmi.vreg<256xf32> +} +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +pto.vmi.group_store %sum, %out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%a, %b, %x: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> +``` + +VPTO lowering result for the join: + +```text +%x_p0, %x_p1, %x_p2, %x_p3 = + scf.if %cond + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %a_even_0, %a_odd_0 = pto.vldsx2 %a_base[%a_tile_off_0], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %a_even_1, %a_odd_1 = pto.vldsx2 %a_base[%a_tile_off_1], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %a_p0, %a_p2 = pto.vdintlv %a_even_0, %a_even_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %a_p1, %a_p3 = pto.vdintlv %a_odd_0, %a_odd_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + scf.yield %a_p0, %a_p1, %a_p2, %a_p3 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32> + } else { + %b_even_0, %b_odd_0 = pto.vldsx2 %b_base[%b_tile_off_0], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %b_even_1, %b_odd_1 = pto.vldsx2 %b_base[%b_tile_off_1], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %b_p0, %b_p2 = pto.vdintlv %b_even_0, %b_even_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %b_p1, %b_p3 = pto.vdintlv %b_odd_0, %b_odd_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + scf.yield %b_p0, %b_p1, %b_p2, %b_p3 
+ : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +``` + +The consumer after the join uses the same S=32 reduction lowering support as +section 3.6: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%sum_mask = pto.pge_b32 "PAT_VL8" + +%s0 = pto.vcgadd %x_p0, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s1 = pto.vcgadd %x_p1, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s2 = pto.vcgadd %x_p2, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s3 = pto.vcgadd %x_p3, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +%s01 = pto.vadd %s0, %s1, %sum_mask : !pto.vreg<64xf32> +%s23 = pto.vadd %s2, %s3, %sum_mask : !pto.vreg<64xf32> +%sum_block = pto.vadd %s01, %s23, %sum_mask : !pto.vreg<64xf32> + +pto.vsts %sum_block, %out[%group_tile_off], %sum_mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + selected_row = cond ? a_row_r : b_row_r + out[group_tile_off + r] = reduce(selected_row[0..31]) +``` + +If the two branches cannot be assigned the same layout and no materialization +support exists before `scf.yield`, the required diagnostic is: + +```text +VMI-LAYOUT-CONTRACT: + scf.yield joins incompatible VMI layouts for !pto.vmi.vreg<256xf32>. + Expected #pto.vmi.layout on every incoming value. +``` + +### 3.13 Packed Group-Slot `f32 -> f16` Cast + +This case is intentionally illegal for the current S=16/S=32 packed +group-slot layout. It prevents the compiler from treating a width-changing +`vcvt` as if it preserved low-lane group slots. 
+ +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<128xf32> -> !pto.vmi.vreg<128xf32> +%sum32 = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +%sum16 = pto.vmi.truncf %sum32 +pto.vmi.group_store %sum16, %out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts before the illegal cast: + +```text +%x: + !pto.vmi.vreg<128xf32, + #pto.vmi.layout> + +%sum32: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> +``` + +Required diagnostic: + +```text +VMI-LAYOUT-CONTRACT: + pto.vmi.truncf cannot lower from + #pto.vmi.layout f32 to f16 because no + slot-preserving width-changing VPTO support exists. f32->f16 vcvt writes + even/odd sub-lanes, not lanes 0..7. Use group_broadcast before truncf, or + keep the group_store element type as f32. +``` + +This does not contradict section 3.8. Section 3.8 is legal because the cast is +commuted after `group_broadcast`, where the value is dense again. + +### 3.14 Unsupported Group Size + +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<96xf32> -> !pto.vmi.vreg<96xf32> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +pto.vmi.group_store %sum, %out[%group_off], %c1 {num_groups = 8} +``` + +Here `S = 96 / 8 = 12` f32 elements per group. The current VCG-based lowering +support uses 32B groups, i.e. 8 f32 elements per row fragment: + +```text +S = 8 -> one VCGADD block per group +S = 16 -> two 8-lane row fragments, add partial sums +S = 32 -> four 8-lane row fragments, add partial sums +S = 64 -> one full 256B row, VCGADD then VCADD +``` + +Required diagnostic: + +```text +VMI-LAYOUT-CONTRACT: + pto.vmi.group_reduce_addf with f32 group size 12 has no supported VPTO + layout/lowering path. Supported VCG-based f32 group sizes are 8, 16, 32, and 64. + A scalar/gather fallback or a rewrite to logical group size 16 with an + explicit per-group mask is required. 
+``` + +### 3.15 Compact S=12 Written As Logical S=16 + +If the program wants to use the S=16 lowering for data with 12 semantic f32 +elements per group, the IR must distinguish two sizes: + +```text +logical group size used by VMI ops: 16 +active elements per group: 12 +``` + +The mask is not a prefix mask over the whole vector. It is a per-group mask: + +```text +mask lane i is active iff (i % 16) < 12 +``` + +The group load surface carries the physical source stride as an SSA operand: + +```text +%x = pto.vmi.group_load %base[%off], %source_group_stride + {num_groups = G, group_size = S} + : !pto.ptr, index -> !pto.vmi.vreg +``` + +`source_group_stride` is in elements, not bytes. It is an operand because it may +come from a dynamic leading dimension, a subview, or a runtime tile descriptor. +Static strides use a constant index operand and can be canonicalized later. +`group_size` remains an attribute in this design because it selects the logical +load layout. `active_elems_per_group` belongs to the mask producer, not to the +load. + +Grouped masks use a paired `pto.vmi.create_group_mask` op. It is intentionally +separate from ordinary prefix `pto.vmi.create_mask` so the IR makes group +semantics explicit next to `pto.vmi.group_load` / `pto.vmi.group_reduce_*`: + +```text +%mask = pto.vmi.create_group_mask %active_elems_per_group + {num_groups = G, group_size = S} + : index -> !pto.vmi.mask<(G*S)xpred> +``` + +Semantics: + +```text +lane i is active iff (i % S) < active_elems_per_group +``` + +Current lowering support covers constant `active_elems_per_group`. Dynamic +grouped masks require a runtime lane-index predicate materializer and remain a +separate implementation item. 
+ +Ordinary `pto.vmi.create_mask %active_lanes` keeps the prefix-mask meaning: + +```text +lane i is active iff i < active_lanes +``` + +#### 3.15.1 Existing Design Works If Source Row Stride Is 16 + +If memory already has a 16-f32 row stride, the user can write a logical S=16 +tile and mask off the last four lanes of every group. + +VMI input: + +```text +%stride16 = arith.constant 16 : index +%x = pto.vmi.group_load %base[%off], %stride16 + {num_groups = 8, group_size = 16} + : !pto.ptr, index -> !pto.vmi.vreg<128xf32> +%c12 = arith.constant 12 : index +%mask = pto.vmi.create_group_mask %c12 {num_groups = 8, group_size = 16} + : index -> !pto.vmi.mask<128xpred> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +pto.vmi.group_store %sum, %out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%x: + !pto.vmi.vreg<128xf32, + #pto.vmi.layout> + +%mask: + !pto.vmi.mask<128xpred, + #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%x32_for_store: + pto.vmi.ensure_layout %x32 + : #pto.vmi.layout -> #pto.vmi.layout +``` + +VPTO lowering result for one `8x16xf32` tile: + +```text +%lo_mask = pto.pge_b32 "PAT_ALL" +%sum_mask = pto.pge_b32 "PAT_VL8" + +%lane = pto.vci %c0_i32 : i32 -> !pto.vreg<64xi32> +%row = pto.vshrs %lane, %c3_i16, %lo_mask + : !pto.vreg<64xi32>, i16, !pto.mask -> !pto.vreg<64xi32> +%row8 = pto.vshls %row, %c3_i16, %lo_mask + : !pto.vreg<64xi32>, i16, !pto.mask -> !pto.vreg<64xi32> +%col = pto.vsub %lane, %row8, %lo_mask + : !pto.vreg<64xi32> +%hi4_mask = pto.vcmps %col, %c4_i32, %lo_mask, "lt" + : !pto.vreg<64xi32>, i32, !pto.mask -> !pto.mask + +%lo, %hi = pto.vldsx2 %base[%tile_off], "BDINTLV" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%lo lanes r*8 .. r*8+7 = row_r[0..7] +%hi lanes r*8 .. r*8+3 = row_r[8..11] +%hi lanes r*8+4 .. 
r*8+7 = row_r[12..15] // inactive by mask + +%lo_sum = pto.vcgadd %lo, %lo_mask + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%hi_sum = pto.vcgadd %hi, %hi4_mask + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +%sum_block = pto.vadd %lo_sum, %hi_sum, %sum_mask + : !pto.vreg<64xf32> + +pto.vsts %sum_block, %out[%group_tile_off], %sum_mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + out[group_tile_off + r] = reduce(row_r[0..11]) +``` + +Design requirement added by this case: VMI mask lowering must support +group-periodic masks by generating the predicate from lane indices. It must not +rewrite this mask to `PAT_M4`: VISA defines `M4` as multiples of 4, not the +first four lanes of each 8-lane block. + +```text +lane = vci(0) +row = lane >> 3 +col = lane - (row << 3) +mask = col < 4 +``` + +#### 3.15.2 Source Row Stride Greater Than 16 + +For now, support the non-compact case where each physical row has at least 16 +f32 slots and the row stride is greater than 16. The fast strided-block path +requires the row stride to be a multiple of one 32B block: + +```text +source_group_stride % 8 == 0 +``` + +The example below uses `source_group_stride = 24`. 
Each row has 12 semantic +values, 4 masked-but-readable slots, and 8 extra skipped slots: + +```text +row_r[0..11] semantic +row_r[12..15] readable but inactive for the S=16 logical group +row_r[16..23] outside the logical group +``` + +VMI input: + +```text +%stride24 = arith.constant 24 : index +%x = pto.vmi.group_load %base[%off], %stride24 + {num_groups = 8, group_size = 16} + : !pto.ptr, index -> !pto.vmi.vreg<128xf32> +%c12 = arith.constant 12 : index +%mask = pto.vmi.create_group_mask %c12 {num_groups = 8, group_size = 16} + : index -> !pto.vmi.mask<128xpred> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +pto.vmi.group_store %sum, %out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts are the same as section 3.15.1: + +```text +%x, %mask: + #pto.vmi.layout +%sum: + #pto.vmi.layout +``` + +VPTO lowering result: + +```text +%lo_mask = pto.pge_b32 "PAT_ALL" +%sum_mask = pto.pge_b32 "PAT_VL8" + +%lane = pto.vci %c0_i32 : i32 -> !pto.vreg<64xi32> +%row = pto.vshrs %lane, %c3_i16, %lo_mask + : !pto.vreg<64xi32>, i16, !pto.mask -> !pto.vreg<64xi32> +%row8 = pto.vshls %row, %c3_i16, %lo_mask + : !pto.vreg<64xi32>, i16, !pto.mask -> !pto.vreg<64xi32> +%col = pto.vsub %lane, %row8, %lo_mask + : !pto.vreg<64xi32> +%hi4_mask = pto.vcmps %col, %c4_i32, %lo_mask, "lt" + : !pto.vreg<64xi32>, i32, !pto.mask -> !pto.mask + +// source_group_stride = 24 f32 = 3 * 32B blocks. +%stride_blocks = %c3_i16 + +%base_lo = %base + tile_off +%base_hi = %base + tile_off + 8 + +%lo = pto.vsldb %base_lo, %stride_blocks, %c0_i16, %lo_mask + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> +%hi = pto.vsldb %base_hi, %stride_blocks, %c0_i16, %lo_mask + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> + +%lo lanes r*8 .. r*8+7 = row_r[0..7] +%hi lanes r*8 .. 
r*8+7 = row_r[8..15] + +%lo_sum = pto.vcgadd %lo, %lo_mask + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%hi_sum = pto.vcgadd %hi, %hi4_mask + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +%sum_block = pto.vadd %lo_sum, %hi_sum, %sum_mask + : !pto.vreg<64xf32> + +pto.vsts %sum_block, %out[%group_tile_off], %sum_mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + out[group_tile_off + r] = + reduce(base[tile_off + r * 24 + 0 .. tile_off + r * 24 + 11]) +``` + +If `source_group_stride > 16` but is not a multiple of 8 f32 elements, this +strided-block path is not legal because `vsldb` block addresses are 32B based. +That case remains unsupported until a gather materialization is selected. + +#### 3.15.3 Compact Source Row Stride 12 + +Compact storage is explicitly out of scope for the first implementation: + +```text +row0[0..11], row1[0..11], row2[0..11], ... +``` + +Required diagnostic: + +```text +VMI-LAYOUT-CONTRACT: + logical group size 16 with active_elems_per_group 12 and + source_group_stride 12 requires compact-row gather materialization. This + plan is not part of the initial VMI layout lowering. +``` + +### 3.16 `group_slot_load` Layout Contract + +`group_slot_load` is separate from `group_load`. + +```text +group_load: + loads group_size data elements per group and produces dense grouped data. + +group_slot_load: + loads one scalar value per group and produces group slots. +``` + +Surface form: + +```text +%v = pto.vmi.group_slot_load %base[%off], %source_group_stride + {num_groups = G} + : !pto.ptr, index -> !pto.vmi.vreg +``` + +Semantics: + +```text +semantic group slot g = base[off + g * source_group_stride] +``` + +The result logical lane count `N` remains the surrounding VMI value shape. Only +the `G` group slots are semantic. 
Layout assignment chooses the group-slot physical +placement requested by the consumer: + +```text +#pto.vmi.layout +#pto.vmi.layout +``` + +#### 3.16.1 Packed `group_slot_load`, `slots = 8` + +VMI input: + +```text +%rhs = pto.vmi.group_slot_load %rhs_base[%rhs_off], %c1 {num_groups = 8} + : !pto.ptr, index -> !pto.vmi.vreg<128xf32> +pto.vmi.group_store %rhs, %out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layout: + +```text +%rhs: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> +``` + +VPTO lowering result: + +```text +%slot_mask = pto.pge_b32 "PAT_VL8" +%one_block = pto.pge_b32 "PAT_VL1" + +// source_group_stride = 1, so one 32B block contains all 8 scalar group slots. +%rhs_block = pto.vsldb %rhs_base[%rhs_off], %c0_i16, %c0_i16, %one_block + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> + +pto.vsts %rhs_block, %out[%group_off], %slot_mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for g = 0..7: + out[group_off + g] = rhs_base[rhs_off + g] +``` + +If `source_group_stride != 1`, this packed `slots = 8` layout requires a +strided/gather group-slot load materializer. Until that support exists, +`group_slot_load` with `slots = 8` and non-unit stride must diagnose instead of +silently using full-group `group_load`. + +#### 3.16.2 Row-Local `group_slot_load`, `slots = 1` + +VMI input: + +```text +%c8 = arith.constant 8 : index +%rhs = pto.vmi.group_slot_load %rhs_base[%rhs_off], %c8 {num_groups = 8} + : !pto.ptr, index -> !pto.vmi.vreg<512xf32> +pto.vmi.group_store %rhs, %out[%group_off], %c8 {num_groups = 8} +``` + +Assigned layout: + +```text +%rhs: + !pto.vmi.vreg<512xf32, #pto.vmi.layout> +``` + +VPTO lowering result: + +```text +%one_b32 = pto.pge_b32 "PAT_VL1" + +// Emit this shape for r = 0..7. Each result value carries one semantic slot +// in lane 0, matching the S=64 row-local group_reduce result layout. +// For f32, source_group_stride = 8 elements = 32B, so every lane-0 vsldb is +// aligned. 
+%rhs_r = pto.vsldb %rhs_base[%rhs_off_plus_r], %c0_i16, %c0_i16, %one_b32 + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> + +pto.vsts %rhs_r, %out[%group_off_plus_r], %one_b32 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + out[group_off + r * 8] = rhs_base[rhs_off + r * 8] +``` + +Current lowering rule: + +```text +slots = 1 group_slot_load uses one lane-0 vsldb per semantic group slot. +For f32, source_group_stride must be a positive constant divisible by 8 +elements. For f16 it must be divisible by 16 elements, and for f8 it must be +divisible by 32 elements. +``` + +### 3.17 `group_broadcast` Feeding A Deinterleaved Consumer + +This case fixes a lowering invariant: `group_broadcast` itself does not infer a +consumer-specific deinterleaved result. It produces the layout selected by +layout assignment. If a later consumer requires another layout, assignment must +insert an explicit `ensure_layout`. + +The current endpoint is: + +```text +group_reduce -> group_broadcast(contiguous f32) + -> ensure_layout(deinterleaved = 2) + -> truncf(contiguous f16) +``` + +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<128xf32> -> !pto.vmi.vreg<128xf32> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +%b = pto.vmi.group_broadcast %sum {num_groups = 8} +%h = pto.vmi.truncf %b +pto.vmi.store %h, %out[%off] +``` + +Assigned layouts: + +```text +%x: + !pto.vmi.vreg<128xf32, + #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%b_dense: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%b_split = pto.vmi.ensure_layout %b_dense: + #pto.vmi.layout + -> #pto.vmi.layout + +%h: + !pto.vmi.vreg<128xf16, #pto.vmi.layout> +``` + +VPTO lowering result: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%sum_mask = pto.pge_b32 "PAT_VL8" + +%x_lo, %x_hi = pto.vldsx2 %base[%tile_off], "BDINTLV" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%lo_sum = pto.vcgadd %x_lo, 
%all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%hi_sum = pto.vcgadd %x_hi, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%sum_block = pto.vadd %lo_sum, %hi_sum, %sum_mask + : !pto.vreg<64xf32> + +// group_broadcast lowers to two contiguous f32 chunks. +%idx_lo = materialize indices [0 repeated 16, 1 repeated 16, + 2 repeated 16, 3 repeated 16] + : !pto.vreg<64xi32> +%idx_hi = materialize indices [4 repeated 16, 5 repeated 16, + 6 repeated 16, 7 repeated 16] + : !pto.vreg<64xi32> + +%b_lo = pto.vselr %sum_block, %idx_lo + : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> +%b_hi = pto.vselr %sum_block, %idx_hi + : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> + +// ensure_layout contiguous -> deinterleaved=2 is explicit in assigned VMI. +%b_even_input, %b_odd_input = pto.vdintlv %b_lo, %b_hi + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%h_even = pto.vcvt %b_even_input, %all_b32 {part = "EVEN", rnd = "R", sat = "SAT"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +%h_odd = pto.vcvt %b_odd_input, %all_b32 {part = "ODD", rnd = "R", sat = "SAT"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> + +%all_b16 = pto.pge_b16 "PAT_ALL" +%h0 = pto.vor %h_even, %h_odd, %all_b16 + : !pto.vreg<128xf16> + +pto.vsts %h0, %out[%off], %all_b16 {dist = "NORM_B16"} + : !pto.vreg<128xf16>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + s = reduce(row_r[0..15]) + out[r * 16 + 0 .. r * 16 + 15] = truncf(s) +``` + +Required assignment rule: + +```text +`group_broadcast` layout is chosen before `vmi-to-vpto`. A width-changing +consumer such as `truncf` may require a deinterleaved f32 source, but that +requirement must be represented by `ensure_layout`; `truncf` lowering must not +look through the defining `group_broadcast` and choose a hidden broadcast shape. 
+``` + +### 3.18 One Value With Dense And Group-Reduce Consumers + +This case forces layout assignment to handle a solvable use-site conflict. One +consumer requires an S=32 group-reduce layout; another consumer requires dense +row-major store. This is not semantically illegal. It must be solved by +explicit use-site materialization. A later optimization pass may fold the +materialization into a store or rematerialize a cheap producer when the required +support exists. + +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<256xf32> -> !pto.vmi.vreg<256xf32> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +pto.vmi.group_store %sum, %sum_out[%group_off], %c1 {num_groups = 8} +pto.vmi.store %x, %copy_out[%off] +``` + +Assigned layouts: + +```text +%x for group_reduce: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +%x for dense store: + requires #pto.vmi.layout +``` + +Baseline layout assignment keeps `%x` in the group-reduce layout and inserts +`ensure_layout` before the dense store use. A later rematerialization pass may +clone the load for the dense store if that is profitable. A later fold-consumer +pass may also fold `ensure_layout + store` into a layout-aware store lowering. 
+ +VPTO lowering result: + +```text +%x_even_0, %x_odd_0 = pto.vldsx2 %base[%tile_off_0], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_even_1, %x_odd_1 = pto.vldsx2 %base[%tile_off_1], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%x_p0, %x_p2 = pto.vdintlv %x_even_0, %x_even_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_p1, %x_p3 = pto.vdintlv %x_odd_0, %x_odd_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%all_b32 = pto.pge_b32 "PAT_ALL" +%sum_mask = pto.pge_b32 "PAT_VL8" + +%s0 = pto.vcgadd %x_p0, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s1 = pto.vcgadd %x_p1, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s2 = pto.vcgadd %x_p2, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s3 = pto.vcgadd %x_p3, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +%s01 = pto.vadd %s0, %s1, %sum_mask : !pto.vreg<64xf32> +%s23 = pto.vadd %s2, %s3, %sum_mask : !pto.vreg<64xf32> +%sum_block = pto.vadd %s01, %s23, %sum_mask : !pto.vreg<64xf32> + +pto.vsts %sum_block, %sum_out[%group_off], %sum_mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + +// Dense store materialization for the second consumer. 
+%even0, %even1 = pto.vintlv %x_p0, %x_p2 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%odd0, %odd1 = pto.vintlv %x_p1, %x_p3 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%d0, %d1 = pto.vintlv %even0, %odd0 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%d2, %d3 = pto.vintlv %even1, %odd1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +pto.vsts %d0, %copy_out[%off_0], %all_b32 {dist = "NORM_B32"} +pto.vsts %d1, %copy_out[%off_64], %all_b32 {dist = "NORM_B32"} +pto.vsts %d2, %copy_out[%off_128], %all_b32 {dist = "NORM_B32"} +pto.vsts %d3, %copy_out[%off_192], %all_b32 {dist = "NORM_B32"} +``` + +Memory result: + +```text +for r = 0..7: + sum_out[group_off + r] = reduce(row_r[0..31]) + +for i = 0..255: + copy_out[off + i] = base[off + i] +``` + +If `deinterleaved = 4 -> contiguous` materialization support does not exist, the +required diagnostic is: + +```text +VMI-LAYOUT-CONTRACT: + value %x is required as #pto.vmi.layout<deinterleaved = 4, block_elems = 1> by + pto.vmi.group_reduce_addf and as #pto.vmi.layout<contiguous> by + pto.vmi.store, but no materialization support exists at the store use site. +``` + +### 3.19 S=16 Reduce `block_elems` Support Selection + +S=16 f32 group reduction has two legal dense input layouts: + +```text +#pto.vmi.layout<deinterleaved = 2, block_elems = 1> +#pto.vmi.layout<deinterleaved = 2, block_elems = 8> +``` + +`block_elems = 1` is the element-parity layout required by f32->f16 `truncf`. +It is also a valid S=16 reduction layout: each physical part contains eight +values per row, so `VCGADD` can reduce each part and `VADD` can combine the two +partial sums. + +`block_elems = 8` is still useful when the producer is a block load shape such +as `BDINTLV` or `vsldb` over 32B row fragments. Baseline layout assignment must +express any mismatch with an explicit `ensure_layout`; producer rematerialization +or consumer folding can choose the cheaper equivalent form later.
Assignment +must not hard-code S=16 reduce to `block_elems = 8`. + +#### 3.19.1 Continuous S=16 Reduce And Truncf, `block_elems = 1` + +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<128xf32> -> !pto.vmi.vreg<128xf32> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +pto.vmi.group_store %sum, %sum_out[%group_off], %c1 {num_groups = 8} +%h = pto.vmi.truncf %x + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> +pto.vmi.store %h, %out[%off] +``` + +Assigned layouts: + +```text +%x: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%h: + !pto.vmi.vreg<128xf16, #pto.vmi.layout> +``` + +Physical lane map: + +```text +%x_p0 lanes r*8 .. r*8+7 = + row_r[0], row_r[2], row_r[4], ..., row_r[14] + +%x_p1 lanes r*8 .. r*8+7 = + row_r[1], row_r[3], row_r[5], ..., row_r[15] +``` + +VPTO lowering result: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%sum_mask = pto.pge_b32 "PAT_VL8" + +%x_p0, %x_p1 = pto.vldsx2 %base[%tile_off], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%s0 = pto.vcgadd %x_p0, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s1 = pto.vcgadd %x_p1, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%sum_block = pto.vadd %s0, %s1, %sum_mask + : !pto.vreg<64xf32> + +pto.vsts %sum_block, %sum_out[%group_off], %sum_mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + +%h_even = pto.vcvt %x_p0, %all_b32 {part = "EVEN", rnd = "R", sat = "SAT"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +%h_odd = pto.vcvt %x_p1, %all_b32 {part = "ODD", rnd = "R", sat = "SAT"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> + +%all_b16 = pto.pge_b16 "PAT_ALL" +%h0 = pto.vor %h_even, %h_odd, %all_b16 + : !pto.vreg<128xf16> +pto.vsts %h0, %out[%off], %all_b16 {dist = "NORM_B16"} + : !pto.vreg<128xf16>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + sum_out[group_off + r] = reduce(row_r[0..15]) + 
+for i = 0..127: + out[off + i] = truncf(base[off + i]) +``` + +#### 3.19.2 Block-Load Producer Fixed To `block_elems = 8` + +This is the real conflict case. The value is fixed to `block_elems = 8` +because the producer uses block-load support. A later `truncf` +requires element-parity `block_elems = 1`. + +VMI input: + +```text +%stride24 = arith.constant 24 : index +%x = pto.vmi.group_load %base[%off], %stride24 + {num_groups = 8, group_size = 16} + : !pto.ptr, index -> !pto.vmi.vreg<128xf32> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +pto.vmi.group_store %sum, %sum_out[%group_off], %c1 {num_groups = 8} +%h = pto.vmi.truncf %x + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> +pto.vmi.store %h, %out[%off] +``` + +Assigned layouts before the conflicting `truncf` use: + +```text +%x from strided block group_load: + !pto.vmi.vreg<128xf32, + #pto.vmi.layout<deinterleaved = 2, block_elems = 8>> + +%sum: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> +``` + +The reduction path is legal and uses the same `vsldb` block-load shape as +section 3.15.2. The `truncf` path is legal only if one of these transforms +exists: + +```text +1. rematerialize the original memory producer as block_elems=1 +2. materialize block_elems=8 -> block_elems=1 in registers +3. use an explicitly enabled scratch/reload fallback +``` + +If no such transform exists, the required diagnostic is: + +```text +VMI-LAYOUT-CONTRACT: + pto.vmi.truncf requires + #pto.vmi.layout<deinterleaved = 2, block_elems = 1>, but the source value is + fixed to #pto.vmi.layout<deinterleaved = 2, block_elems = 8> by the strided + group_load. Add rematerialization or preserving materialization support, or + avoid consuming this block-loaded value with truncf. +``` + +### 3.20 `group_slots` Control-Flow Join + +`group_slots` values must be allowed to cross control flow. The join type is a +group-slot physical tuple, not a dense vector.
+ +VMI input: + +```text +%sum = scf.if %cond -> !pto.vmi.vreg<128xf32> { + %x = pto.vmi.load %base[%off] + : memref<128xf32> -> !pto.vmi.vreg<128xf32> + %a = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} + scf.yield %a : !pto.vmi.vreg<128xf32> +} else { + %b = pto.vmi.group_slot_load %rhs_base[%rhs_off], %c1 {num_groups = 8} + : !pto.ptr, index -> !pto.vmi.vreg<128xf32> + scf.yield %b : !pto.vmi.vreg<128xf32> +} +%bias = pto.vmi.group_slot_load %bias_base[%bias_off], %c1 {num_groups = 8} + : !pto.ptr, index -> !pto.vmi.vreg<128xf32> +%outv = pto.vmi.addf %sum, %bias +pto.vmi.group_store %outv, %out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%x: + !pto.vmi.vreg<128xf32, + #pto.vmi.layout> + +%a, %b, %sum, %bias, %outv: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> +``` + +VPTO lowering result for the join: + +```text +%sum_block = scf.if %cond -> !pto.vreg<64xf32> { + %all_b32 = pto.pge_b32 "PAT_ALL" + %sum_mask = pto.pge_b32 "PAT_VL8" + + %x_lo, %x_hi = pto.vldsx2 %base[%tile_off], "BDINTLV" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %lo_sum = pto.vcgadd %x_lo, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %hi_sum = pto.vcgadd %x_hi, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %a_block = pto.vadd %lo_sum, %hi_sum, %sum_mask + : !pto.vreg<64xf32> + scf.yield %a_block : !pto.vreg<64xf32> +} else { + %one_block = pto.pge_b32 "PAT_VL1" + %b_block = pto.vsldb %rhs_base[%rhs_off], %c0_i16, %c0_i16, %one_block + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> + scf.yield %b_block : !pto.vreg<64xf32> +} + +%one_block = pto.pge_b32 "PAT_VL1" +%slot_mask = pto.pge_b32 "PAT_VL8" +%bias_block = pto.vsldb %bias_base[%bias_off], %c0_i16, %c0_i16, %one_block + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> +%out_block = pto.vadd %sum_block, %bias_block, %slot_mask + : !pto.vreg<64xf32> + +pto.vsts %out_block, %out[%group_off], %slot_mask {dist = "NORM_B32"} + : 
!pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + lhs = cond ? reduce(row_r[0..15]) : rhs_base[rhs_off + r] + out[group_off + r] = lhs + bias_base[bias_off + r] +``` + +### 3.21 S=32 Tail With Full-Tile-Readable Source + +This is the positive counterpart to section 3.11.2. Tail participation is +still expressed by masks, but the source must provide a static proof that +reading the rounded-up 8-row physical tile is memory-safe. That proof is +explicit for partial logical loads: it can come from a statically shaped memref +source. Pointer-source runtime kernels should instead load the rounded physical +vector and use a mask to express active logical lanes; this is not inferred from +surrounding MTE copies or caller context. + +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<256xf32> -> !pto.vmi.vreg<192xf32> +%mask = pto.vmi.create_mask %c192 : index -> !pto.vmi.mask<192xpred> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 6} +pto.vmi.group_store %sum, %out[%group_off], %c1 {num_groups = 6} +``` + +Equivalent pointer-source VMI input for runtime kernels: + +```text +%x = pto.vmi.load %base[%off] + : !pto.ptr -> !pto.vmi.vreg<256xf32> +%mask = pto.vmi.create_mask %c192 : index -> !pto.vmi.mask<256xpred> +``` + +Assigned layouts: + +```text +%x: + !pto.vmi.vreg<192xf32, #pto.vmi.layout> + +%mask: + !pto.vmi.mask<192xpred, + #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<192xf32, #pto.vmi.layout> +``` + +VPTO lowering result: + +```text +// A statically safe full-read proof allows the load plan to read the +// rounded-up 8-row tile. Only rows 0..5 are semantically active. 
+%x_c0 = pto.vlds %base[%tile_off_0] + : memref<256xf32> -> !pto.vreg<64xf32> +%x_c1 = pto.vlds %base[%tile_off_1] + : memref<256xf32> -> !pto.vreg<64xf32> +%x_c2 = pto.vlds %base[%tile_off_2] + : memref<256xf32> -> !pto.vreg<64xf32> +%x_c3 = pto.vlds %base[%tile_off_3] + : memref<256xf32> -> !pto.vreg<64xf32> + +%x_lo01, %x_hi01 = pto.vdintlv %x_c0, %x_c1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_lo23, %x_hi23 = pto.vdintlv %x_c2, %x_c3 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_p0, %x_p2 = pto.vdintlv %x_lo01, %x_lo23 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_p1, %x_p3 = pto.vdintlv %x_hi01, %x_hi23 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%data_mask0, %_ = pto.plt_b32 %c48_i32 + : i32 -> !pto.mask, i32 +%data_mask1, %_ = pto.plt_b32 %c48_i32 + : i32 -> !pto.mask, i32 +%data_mask2, %_ = pto.plt_b32 %c48_i32 + : i32 -> !pto.mask, i32 +%data_mask3, %_ = pto.plt_b32 %c48_i32 + : i32 -> !pto.mask, i32 +%sum_mask, %_ = pto.plt_b32 %c6_i32 + : i32 -> !pto.mask, i32 + +%s0 = pto.vcgadd %x_p0, %data_mask0 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s1 = pto.vcgadd %x_p1, %data_mask1 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s2 = pto.vcgadd %x_p2, %data_mask2 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s3 = pto.vcgadd %x_p3, %data_mask3 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +%s01 = pto.vadd %s0, %s1, %sum_mask : !pto.vreg<64xf32> +%s23 = pto.vadd %s2, %s3, %sum_mask : !pto.vreg<64xf32> +%sum_block = pto.vadd %s01, %s23, %sum_mask : !pto.vreg<64xf32> + +pto.vsts %sum_block, %out[%group_off], %sum_mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..5: + out[group_off + r] = reduce(row_r[0..31]) +``` + +Rows 6 and 7 may be physically loaded because of the safe full-read proof, but +their 
lanes are not active in `%data_mask*`, and their group slots are not +stored because `%sum_mask` is produced by `plt_b32 %c6_i32`. + +### 3.22 `scf.for` Loop-Carried Layout + +Loop-carried VMI values require a layout fixed point. The iter_arg, body block +argument, yield operand, loop result, and later consumer must all agree on one +layout, or `vmi-layout-assignment` must insert a materialization at a legal +dominating use site. + +VMI input: + +```text +%init = pto.vmi.load %base[%off] + : memref<256xf32> -> !pto.vmi.vreg<256xf32> +%acc = scf.for %i = %c0 to %steps step %c1 + iter_args(%arg = %init) -> !pto.vmi.vreg<256xf32> { + %bias = pto.vmi.broadcast %bias_s + : f32 -> !pto.vmi.vreg<256xf32> + %next = pto.vmi.addf %arg, %bias + scf.yield %next : !pto.vmi.vreg<256xf32> +} +%sum = pto.vmi.group_reduce_addf %acc, %mask {num_groups = 8} +pto.vmi.group_store %sum, %out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%init, %arg, %bias, %next, %acc: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> +``` + +VPTO lowering result: + +```text +%init_even_0, %init_odd_0 = pto.vldsx2 %base[%tile_off_0], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%init_even_1, %init_odd_1 = pto.vldsx2 %base[%tile_off_1], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%init_p0, %init_p2 = pto.vdintlv %init_even_0, %init_even_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%init_p1, %init_p3 = pto.vdintlv %init_odd_0, %init_odd_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%acc_p0, %acc_p1, %acc_p2, %acc_p3 = + scf.for %i = %c0 to %steps step %c1 + iter_args(%arg_p0 = %init_p0, %arg_p1 = %init_p1, + %arg_p2 = %init_p2, %arg_p3 = %init_p3) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %all_b32 = pto.pge_b32 "PAT_ALL" + %bias_p0 = pto.vdup %bias_s, 
%all_b32 + : f32, !pto.mask -> !pto.vreg<64xf32> + %bias_p1 = pto.vdup %bias_s, %all_b32 + : f32, !pto.mask -> !pto.vreg<64xf32> + %bias_p2 = pto.vdup %bias_s, %all_b32 + : f32, !pto.mask -> !pto.vreg<64xf32> + %bias_p3 = pto.vdup %bias_s, %all_b32 + : f32, !pto.mask -> !pto.vreg<64xf32> + + %next_p0 = pto.vadd %arg_p0, %bias_p0, %all_b32 : !pto.vreg<64xf32> + %next_p1 = pto.vadd %arg_p1, %bias_p1, %all_b32 : !pto.vreg<64xf32> + %next_p2 = pto.vadd %arg_p2, %bias_p2, %all_b32 : !pto.vreg<64xf32> + %next_p3 = pto.vadd %arg_p3, %bias_p3, %all_b32 : !pto.vreg<64xf32> + scf.yield %next_p0, %next_p1, %next_p2, %next_p3 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + +%all_b32 = pto.pge_b32 "PAT_ALL" +%sum_mask = pto.pge_b32 "PAT_VL8" +%s0 = pto.vcgadd %acc_p0, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s1 = pto.vcgadd %acc_p1, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s2 = pto.vcgadd %acc_p2, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s3 = pto.vcgadd %acc_p3, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s01 = pto.vadd %s0, %s1, %sum_mask : !pto.vreg<64xf32> +%s23 = pto.vadd %s2, %s3, %sum_mask : !pto.vreg<64xf32> +%sum_block = pto.vadd %s01, %s23, %sum_mask : !pto.vreg<64xf32> +pto.vsts %sum_block, %out[%group_off], %sum_mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + for c = 0..31: + acc[row_r, c] = base[row_r, c] + steps * bias_s + out[group_off + r] = reduce(acc[row_r, 0..31]) +``` + +### 3.23 `group_broadcast` With Multiple Dense Consumers + +One `group_slots` value may feed multiple `group_broadcast` uses with different +dense result layout requirements. Each `group_broadcast` op has its own result +layout, so layout assignment should type each op at its use site instead of +forcing one result layout onto all consumers. 
+ +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<128xf32> -> !pto.vmi.vreg<128xf32> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} + +%b_for_mul = pto.vmi.group_broadcast %sum {num_groups = 8} +%y = pto.vmi.mulf %x, %b_for_mul +%ysum = pto.vmi.group_reduce_addf %y, %mask {num_groups = 8} +pto.vmi.group_store %ysum, %sum_out[%group_off], %c1 {num_groups = 8} + +%b_for_cast = pto.vmi.group_broadcast %sum {num_groups = 8} +%h = pto.vmi.truncf %b_for_cast +pto.vmi.store %h, %dense_out[%off] +``` + +Assigned layouts in the current implementation: + +```text +%x: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%x_for_reduce: + !pto.vmi.vreg<128xf32, + #pto.vmi.layout> + +%sum, %ysum: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%b_for_mul, %y: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%y_for_reduce: + !pto.vmi.vreg<128xf32, + #pto.vmi.layout> + +%b_for_cast: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%b_for_cast_split: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%h: + !pto.vmi.vreg<128xf16, #pto.vmi.layout> +``` + +The important invariant is not that both dense consumers choose the same dense +layout. It is that each use has an explicit layout boundary: + +```text +%x_for_reduce = pto.vmi.ensure_layout %x +%y_for_reduce = pto.vmi.ensure_layout %y +%b_for_cast_split = pto.vmi.ensure_layout %b_for_cast +``` + +If a future direct `group_broadcast -> deinterleaved` support path is added, layout +assignment may assign `%b_for_mul` or `%b_for_cast` directly to that layout, but +the choice must still be visible in the assigned IR. 
+ +VPTO lowering result: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%sum_mask = pto.pge_b32 "PAT_VL8" + +%x_lo, %x_hi = pto.vldsx2 %base[%tile_off], "BDINTLV" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_lo_sum = pto.vcgadd %x_lo, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%x_hi_sum = pto.vcgadd %x_hi, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%sum_block = pto.vadd %x_lo_sum, %x_hi_sum, %sum_mask + : !pto.vreg<64xf32> + +%lane_id = pto.vci %c0_i32 : i32 -> !pto.vreg<64xi32> +%broadcast_idx = pto.vshrs %lane_id, %c3_i16, %all_b32 + : !pto.vreg<64xi32>, i16, !pto.mask -> !pto.vreg<64xi32> + +// Use 1: broadcast for the multiply path. Current lowering materializes two +// contiguous f32 chunks, multiplies them with the original contiguous chunks, +// then deinterleaves the product for the second group_reduce. +%b_rows_for_mul_0 = pto.vselr %sum_block, %broadcast_idx_0 + : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> +%b_rows_for_mul_1 = pto.vselr %sum_block, %broadcast_idx_1 + : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> +%y0 = pto.vmul %x0, %b_rows_for_mul_0, %all_b32 : !pto.vreg<64xf32> +%y1 = pto.vmul %x1, %b_rows_for_mul_1, %all_b32 : !pto.vreg<64xf32> +%y_lo, %y_hi = pto.vdintlv %y0, %y1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%y_lo_sum = pto.vcgadd %y_lo, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%y_hi_sum = pto.vcgadd %y_hi, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%ysum_block = pto.vadd %y_lo_sum, %y_hi_sum, %sum_mask + : !pto.vreg<64xf32> +pto.vsts %ysum_block, %sum_out[%group_off], %sum_mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + +// Use 2: rematerialize broadcast for the f32->f16 parity cast path. 
+%b_rows_for_cast_0 = pto.vselr %sum_block, %broadcast_idx_0 + : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> +%b_rows_for_cast_1 = pto.vselr %sum_block, %broadcast_idx_1 + : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> +%cast_lo, %cast_hi = pto.vdintlv %b_rows_for_cast_0, %b_rows_for_cast_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%h_even = pto.vcvt %cast_lo, %all_b32 + {part = "EVEN", rnd = "R", sat = "SAT"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +%h_odd = pto.vcvt %cast_hi, %all_b32 + {part = "ODD", rnd = "R", sat = "SAT"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +%all_b16 = pto.pge_b16 "PAT_ALL" +%h0 = pto.vor %h_even, %h_odd, %all_b16 : !pto.vreg<128xf16> +pto.vsts %h0, %dense_out[%off], %all_b16 {dist = "NORM_B16"} + : !pto.vreg<128xf16>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + s = reduce(row_r[0..15]) + sum_out[group_off + r] = reduce_i(row_r[i] * s for i = 0..15) + dense_out[r * 16 + 0 .. r * 16 + 15] = truncf(s) +``` + +### 3.24 Mask With Elementwise, Select, And Store + +This case separates compute masking from memory effects. A masked elementwise +operation with passthrough semantics can be represented as ordinary compute +plus `select`; a masked store uses the mask only on the store effect. 
+ +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<64xf32> -> !pto.vmi.vreg<64xf32> +%rhs = pto.vmi.load %rhs_base[%off] + : memref<64xf32> -> !pto.vmi.vreg<64xf32> +%mask = pto.vmi.create_mask %c48 + : index -> !pto.vmi.mask<64xpred> +%sum = pto.vmi.addf %x, %rhs +%passthrough = pto.vmi.select %mask, %sum, %x +pto.vmi.store %passthrough, %dense_out[%off] +pto.vmi.masked_store %sum, %masked_out[%off], %mask +``` + +Assigned layouts: + +```text +%x, %rhs, %sum, %passthrough: + !pto.vmi.vreg<64xf32, #pto.vmi.layout> + +%mask: + !pto.vmi.mask<64xpred, #pto.vmi.layout> +``` + +VPTO lowering result: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%m, %_ = pto.plt_b32 %c48_i32 : i32 -> !pto.mask, i32 + +%x0 = pto.vlds %base[%off] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<64xf32> +%rhs0 = pto.vlds %rhs_base[%off] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<64xf32> +%sum0 = pto.vadd %x0, %rhs0, %all_b32 : !pto.vreg<64xf32> + +%pass0 = pto.vsel %sum0, %x0, %m + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +pto.vsts %pass0, %dense_out[%off], %all_b32 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + +pto.vsts %sum0, %masked_out[%off], %m {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for i = 0..63: + if i < 48: + dense_out[off + i] = base[off + i] + rhs_base[off + i] + masked_out[off + i] = base[off + i] + rhs_base[off + i] + else: + dense_out[off + i] = base[off + i] + masked_out[off + i] is unchanged +``` + +### 3.25 Function Boundary Layout Specialization + +Function boundaries cannot rely on hidden layout side tables. Either the +function is internal and layout-specialized by `vmi-layout-assignment`, or a +public/external VMI boundary must diagnose until a stable VMI ABI is defined. 
+ +#### 3.25.1 Internal Function Specialized To Consumer Layout + +VMI input: + +```text +func.func private @producer(%base: !pto.ptr, %off: index) + -> !pto.vmi.vreg<256xf32> { + %x = pto.vmi.load %base[%off] + : memref<256xf32> -> !pto.vmi.vreg<256xf32> + return %x : !pto.vmi.vreg<256xf32> +} + +func.func @caller(%base: !pto.ptr, %off: index, %out: !pto.ptr) { + %x = call @producer(%base, %off) + : (!pto.ptr, index) -> !pto.vmi.vreg<256xf32> + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} + pto.vmi.group_store %sum, %out[%off], %c1 {num_groups = 8} + return +} +``` + +Assigned layouts: + +```text +@producer result: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +%x in @caller: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> +``` + +VPTO lowering result for the function boundary: + +```text +func.func private @producer(...) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %x_even_0, %x_odd_0 = pto.vldsx2 %base[%tile_off_0], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %x_even_1, %x_odd_1 = pto.vldsx2 %base[%tile_off_1], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %x_p0, %x_p2 = pto.vdintlv %x_even_0, %x_even_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %x_p1, %x_p3 = pto.vdintlv %x_odd_0, %x_odd_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + return %x_p0, %x_p1, %x_p2, %x_p3 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32> +} + +func.func @caller(...) { + %x_p0, %x_p1, %x_p2, %x_p3 = call @producer(...) + : (...) 
-> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>) + + %all_b32 = pto.pge_b32 "PAT_ALL" + %sum_mask = pto.pge_b32 "PAT_VL8" + %s0 = pto.vcgadd %x_p0, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %s1 = pto.vcgadd %x_p1, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %s2 = pto.vcgadd %x_p2, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %s3 = pto.vcgadd %x_p3, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %s01 = pto.vadd %s0, %s1, %sum_mask : !pto.vreg<64xf32> + %s23 = pto.vadd %s2, %s3, %sum_mask : !pto.vreg<64xf32> + %sum_block = pto.vadd %s01, %s23, %sum_mask : !pto.vreg<64xf32> + pto.vsts %sum_block, %out[%off], %sum_mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} +``` + +Memory result: + +```text +for r = 0..7: + out[off + r] = reduce(row_r[0..31]) +``` + +Runtime closure: + +```text +lit: + test/lit/vmi/vmi_ptoas_private_call_inline.pto + +runtime SIM: + test/vpto/cases/vmi/private-call-inline-store + +ptoas pipeline: + vmi-layout-assignment makes the private result layout explicit + vmi-to-vpto physicalizes the private helper result into !pto.vreg values + ptoas then inlines private physical VMI helpers before VPTO vecscope/backend + emission, so physical vector values do not escape through a function return +``` + +#### 3.25.2 Public Or External VMI Boundary + +VMI input: + +```text +func.func @public_producer(%base: !pto.ptr, %off: index) + -> !pto.vmi.vreg<256xf32> attributes {public} { + %x = pto.vmi.load %base[%off] + : memref<256xf32> -> !pto.vmi.vreg<256xf32> + return %x : !pto.vmi.vreg<256xf32> +} +``` + +Required diagnostic for the initial design: + +```text +VMI-LAYOUT-CONTRACT: + public or external function boundary returns !pto.vmi.vreg<256xf32> without a + stable VMI layout ABI. Mark the function internal for layout specialization, + inline it before vmi-layout-assignment, or define an explicit ABI layout.
+``` + +### 3.26 S=16 Grouped Tail Through Broadcast, Reduce, Store + +This case extends section 3.15.1 from `reduce -> group_store` to the full +grouped compute path. It is needed because `create_group_mask` must remain a +group-periodic mask after a `group_broadcast`; it cannot collapse to a prefix +mask or an all-true mask. + +VMI input: + +```text +%stride16 = arith.constant 16 : index +%x = pto.vmi.group_load %base[%off], %stride16 + {num_groups = 8, group_size = 16} + : !pto.ptr, index -> !pto.vmi.vreg<128xf32> +%c12 = arith.constant 12 : index +%mask = pto.vmi.create_group_mask %c12 {num_groups = 8, group_size = 16} + : index -> !pto.vmi.mask<128xpred> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +%b = pto.vmi.group_broadcast %sum {num_groups = 8} +%y = pto.vmi.mulf %x, %b +%ysum = pto.vmi.group_reduce_addf %y, %mask {num_groups = 8} +pto.vmi.group_store %ysum, %out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%x, %b, %y: + !pto.vmi.vreg<128xf32, + #pto.vmi.layout> + +%mask: + !pto.vmi.mask<128xpred, + #pto.vmi.layout> + +%sum, %ysum: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> +``` + +VPTO lowering result for one `8x16xf32` tile: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%sum_mask = pto.pge_b32 "PAT_VL8" + +%lane = pto.vci %c0_i32 : i32 -> !pto.vreg<64xi32> +%row = pto.vshrs %lane, %c3_i16, %all_b32 + : !pto.vreg<64xi32>, i16, !pto.mask -> !pto.vreg<64xi32> +%row8 = pto.vshls %row, %c3_i16, %all_b32 + : !pto.vreg<64xi32>, i16, !pto.mask -> !pto.vreg<64xi32> +%col = pto.vsub %lane, %row8, %all_b32 + : !pto.vreg<64xi32> +%hi4_mask = pto.vcmps %col, %c4_i32, %all_b32, "lt" + : !pto.vreg<64xi32>, i32, !pto.mask -> !pto.mask + +%x_lo, %x_hi = pto.vldsx2 %base[%tile_off], "BDINTLV" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%x_lo_sum = pto.vcgadd %x_lo, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%x_hi_sum = pto.vcgadd %x_hi, %hi4_mask + : !pto.vreg<64xf32>, !pto.mask -> 
!pto.vreg<64xf32> +%sum_block = pto.vadd %x_lo_sum, %x_hi_sum, %sum_mask + : !pto.vreg<64xf32> + +%broadcast_idx = pto.vshrs %lane, %c3_i16, %all_b32 + : !pto.vreg<64xi32>, i16, !pto.mask -> !pto.vreg<64xi32> +%b_rows = pto.vselr %sum_block, %broadcast_idx + : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> + +%y_lo = pto.vmul %x_lo, %b_rows, %all_b32 : !pto.vreg<64xf32> +%y_hi = pto.vmul %x_hi, %b_rows, %hi4_mask : !pto.vreg<64xf32> + +%y_lo_sum = pto.vcgadd %y_lo, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%y_hi_sum = pto.vcgadd %y_hi, %hi4_mask + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%ysum_block = pto.vadd %y_lo_sum, %y_hi_sum, %sum_mask + : !pto.vreg<64xf32> + +pto.vsts %ysum_block, %out[%group_tile_off], %sum_mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + s = reduce(row_r[0..11]) + out[group_tile_off + r] = + reduce_i(row_r[i] * s for i = 0..11) + = s * s +``` + +Required assignment rule: + +```text +%mask is a grouped mask with S=16 and active_elems_per_group=12. +For the low half, the physical predicate is PAT_ALL. +For the high half, the physical predicate is lane_mod_8 < 4. +The same split must be reused for both group_reduce operations. +``` + +### 3.27 S=32 `group_load` With Stride Greater Than Group Size + +This case is the S=32 counterpart to section 3.15.2. The logical group is +`32xf32`, but rows in memory have a larger stride. The fast plan is legal only +when the stride is a multiple of one 32B f32 block. 
+ +VMI input: + +```text +%stride40 = arith.constant 40 : index +%x = pto.vmi.group_load %base[%off], %stride40 + {num_groups = 8, group_size = 32} + : !pto.ptr, index -> !pto.vmi.vreg<256xf32> +%mask = pto.vmi.create_group_mask %c32 {num_groups = 8, group_size = 32} + : index -> !pto.vmi.mask<256xpred> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +pto.vmi.group_store %sum, %out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%x: + !pto.vmi.vreg<256xf32, + #pto.vmi.layout> + +%mask: + !pto.vmi.mask<256xpred, + #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> +``` + +VPTO lowering result: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%sum_mask = pto.pge_b32 "PAT_VL8" + +// source_group_stride = 40 f32 = 5 * 32B blocks. +%stride_blocks = %c5_i16 + +%frag0 = pto.vsldb %base_frag0, %stride_blocks, %c0_i16, %all_b32 + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> +%frag1 = pto.vsldb %base_frag1, %stride_blocks, %c0_i16, %all_b32 + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> +%frag2 = pto.vsldb %base_frag2, %stride_blocks, %c0_i16, %all_b32 + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> +%frag3 = pto.vsldb %base_frag3, %stride_blocks, %c0_i16, %all_b32 + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> + +%frag0 lanes r*8 .. r*8+7 = row_r[0..7] +%frag1 lanes r*8 .. r*8+7 = row_r[8..15] +%frag2 lanes r*8 .. r*8+7 = row_r[16..23] +%frag3 lanes r*8 .. 
r*8+7 = row_r[24..31] + +%s0 = pto.vcgadd %frag0, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s1 = pto.vcgadd %frag1, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s2 = pto.vcgadd %frag2, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s3 = pto.vcgadd %frag3, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +%s01 = pto.vadd %s0, %s1, %sum_mask : !pto.vreg<64xf32> +%s23 = pto.vadd %s2, %s3, %sum_mask : !pto.vreg<64xf32> +%sum_block = pto.vadd %s01, %s23, %sum_mask : !pto.vreg<64xf32> + +pto.vsts %sum_block, %out[%group_tile_off], %sum_mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + out[group_tile_off + r] = + reduce(base[tile_off + r * 40 + 0 .. tile_off + r * 40 + 31]) +``` + +Required diagnostic when the stride is not block-aligned: + +```text +VMI-LAYOUT-CONTRACT: + pto.vmi.group_load group_size 32 with source_group_stride not divisible by + 8 f32 elements cannot use the vsldb strided-block lowering support. Enable a + stable gather fallback or choose a block-aligned source_group_stride. +``` + +Required assignment rule: + +```text +This producer requires the S=32 block-fragment layout: + #pto.vmi.layout<deinterleaved = 4, block_elems = 8> + +It must not be unified with the contiguous-load S=32 plan from section 3.6: + #pto.vmi.layout<deinterleaved = 4, block_elems = 1> + +Both layouts are legal inputs to group_reduce_addf S=32, but they require +different producer materialization/lowering support. +``` + +### 3.28 `group_slot_load` `slots = 1` With Aligned Non-Unit Stride + +Section 3.16.1 diagnoses non-unit stride for the packed `slots = 8` plan. The +row-local `slots = 1` plan supports non-unit stride only when each one-lane +load can be issued as an aligned `vsldb`. In the current lowering this means +the stride is a positive compile-time constant and is divisible by the 32B +alignment expressed in source elements. 
+ +VMI input: + +```text +%c8 = arith.constant 8 : index +%rhs = pto.vmi.group_slot_load %rhs_base[%rhs_off], %c8 {num_groups = 8} + : !pto.ptr, index -> !pto.vmi.vreg<512xf32> +pto.vmi.group_store %rhs, %out[%group_off], %c8 {num_groups = 8} +``` + +Assigned layout: + +```text +%rhs: + !pto.vmi.vreg<512xf32, #pto.vmi.layout> +``` + +VPTO lowering result: + +```text +%one_b32 = pto.pge_b32 "PAT_VL1" + +// Emit this shape for r = 0..7. The address expression is scalar/index +// arithmetic outside the vector register layout. For f32, %c8 is 32B. +%addr_r = %rhs_base + %rhs_off + r * 8 +%rhs_r = pto.vsldb %addr_r, %c0_i16, %c0_i16, %one_b32 + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> + +pto.vsts %rhs_r, %out[%group_tile_off_r], %one_b32 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + out[group_off + r * 8] = rhs_base[rhs_off + r * 8] +``` + +Required assignment rule: + +```text +If a non-unit-stride group_slot_load has only slots=1 consumers and its stride +is a positive constant divisible by the element count of 32B, select +group_slot_load_slots1_row_local. Do not diagnose it using the slots=8 +unit-stride restriction. +``` + +Required diagnostic: + +```text +%c2 = arith.constant 2 : index +%bad = pto.vmi.group_slot_load %rhs_base[%rhs_off], %c2 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<512xf32> + +VMI-UNSUPPORTED: pto.vmi.group_slot_load + slots=1 group_slot_load currently lowers as one lane-0 vsldb per group and + requires constant positive source_group_stride divisible by 8 elements for + 32B load alignment; packed or unaligned scalar load lowering is not + implemented. 
+``` + +Dynamic stride has the same status until a stable gather or scalarized packed +load plan is designed: + +```text +%bad = pto.vmi.group_slot_load %rhs_base[%rhs_off], %runtime_stride + {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<512xf32> + +VMI-UNSUPPORTED: pto.vmi.group_slot_load + requires constant positive source_group_stride divisible by 8 elements. +``` + +### 3.29 One Semantic Mask With f32 And f16 Consumers + +One VMI mask may feed consumers with different physical predicate +granularities. Layout assignment must keep the semantic mask value single, but +materialize per-use physical masks after element type is known. + +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<128xf32> -> !pto.vmi.vreg<128xf32> +%mask = pto.vmi.create_mask %c96 + : index -> !pto.vmi.mask<128xpred> +pto.vmi.masked_store %x, %out32[%off], %mask +%h = pto.vmi.truncf %x + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> +pto.vmi.masked_store %h, %out16[%off], %mask +``` + +Assigned layouts: + +```text +%x: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%mask: + !pto.vmi.mask<128xb32, #pto.vmi.layout> + +%x_for_cast: + pto.vmi.ensure_layout %x + : #pto.vmi.layout -> #pto.vmi.layout + +%mask_for_h_store: + pto.vmi.create_mask %c96 + : index -> !pto.vmi.mask<128xb16, #pto.vmi.layout> + +%h: + !pto.vmi.vreg<128xf16, #pto.vmi.layout> +``` + +Physical mask materialization: + +```text +use at masked_store %x: + predicate granularity b32, PAT_VL96, layout contiguous + +use at vcvt %x -> %h: + predicate granularity b32, PAT_ALL. The cast may compute inactive lanes + because the following masked_store controls the external memory effect. 
+ +use at masked_store %h: + predicate granularity b16, PAT_VL96, layout contiguous +``` + +VPTO lowering result: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%mask32_0 = pto.pge_b32 "PAT_ALL" +%mask32_1 = pto.pge_b32 "PAT_VL32" + +%x0 = pto.vlds %base[%off] + : !pto.ptr, index -> !pto.vreg<64xf32> +%x1 = pto.vlds %base[%off_plus_64] + : !pto.ptr, index -> !pto.vreg<64xf32> + +pto.vsts %x0, %out32[%off], %mask32_0 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +pto.vsts %x1, %out32[%off_plus_64], %mask32_1 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + +%x_p0, %x_p1 = pto.vdintlv %x0, %x1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%h_even = pto.vcvt %x_p0, %all_b32 {part = "EVEN", rnd = "R", sat = "SAT"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +%h_odd = pto.vcvt %x_p1, %all_b32 {part = "ODD", rnd = "R", sat = "SAT"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> + +%all_b16 = pto.pset_b16 "PAT_ALL" +%h0 = pto.vor %h_even, %h_odd, %all_b16 + : !pto.vreg<128xf16> +%mask_b16, %scalar_out = pto.plt_b16 %c96_i32 + : i32 -> !pto.mask, i32 +pto.vsts %h0, %out16[%off], %mask_b16 {dist = "NORM_B16"} + : !pto.vreg<128xf16>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for i = 0..95: + out32[off + i] = base[off + i] + out16[off + i] = truncf(base[off + i]) + +for i = 96..127: + out32[off + i] is unchanged + out16[off + i] is unchanged +``` + +Required assignment rule: + +```text +`vmi-to-vpto` must not decide mask granularity by inspecting users. It consumes +the per-use typed mask materialization inserted by vmi-layout-assignment. For +a rematerializable `create_mask`, assignment may clone it as b32/b16 masks. For +a non-rematerializable mask producer, assignment must insert +`ensure_mask_granularity` or diagnose if no materialization support exists. 
+``` + +### 3.30 `masked_load` Tail Without Padding + +This case is the replacement for `vector.transfer_read` padding semantics in the +initial VMI surface. Tail lanes are expressed by a mask and a passthrough value; +there is no implicit padding constant in the load. The direct lowering is legal +only when every physical chunk read by `vlds` is memory-safe. + +VMI input: + +```text +%c100 = arith.constant 100 : index +%mask = pto.vmi.create_mask %c100 : index -> !pto.vmi.mask<100xpred> +%zero = pto.vmi.broadcast %c0_f32 : f32 -> !pto.vmi.vreg<100xf32> +%x = pto.vmi.masked_load %base[%c0], %mask, %zero + : memref<128xf32>, !pto.vmi.mask<100xpred>, !pto.vmi.vreg<100xf32> + -> !pto.vmi.vreg<100xf32> +pto.vmi.store %x, %out[%c0] +``` + +Assigned layouts: + +```text +%mask: + !pto.vmi.mask<100xb32, #pto.vmi.layout> + +%zero, %x: + !pto.vmi.vreg<100xf32, #pto.vmi.layout> +``` + +VPTO lowering result: + +```text +%m0 = pto.pge_b32 "PAT_ALL" +%m1 = pto.pge_b32 "PAT_VL36" + +%zero0 = pto.vdup %c0_f32, %m0 + : f32, !pto.mask -> !pto.vreg<64xf32> +%zero1 = pto.vdup %c0_f32, %m0 + : f32, !pto.mask -> !pto.vreg<64xf32> + +%l0 = pto.vlds %base[%c0] + : memref<128xf32> -> !pto.vreg<64xf32> +%x0 = pto.vsel %l0, %zero0, %m0 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +%l1 = pto.vlds %base[%c64] + : memref<128xf32> -> !pto.vreg<64xf32> +%x1 = pto.vsel %l1, %zero1, %m1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +pto.vsts %x0, %out[%c0], %m0 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, memref<128xf32>, !pto.mask +pto.vsts %x1, %out[%c64], %m1 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, memref<128xf32>, !pto.mask +``` + +Memory result: + +```text +for i = 0..99: + out[i] = base[i] + +for i = 100..127: + out[i] is unchanged +``` + +Required diagnostic when the source cannot prove a safe full-read footprint: + +```text +VMI-UNSUPPORTED: + pto.vmi.masked_load direct lowering requires a supported memory source, + contiguous 
result/passthru/mask layouts, and either full physical chunks or a + statically safe full-read footprint. Use a memref with enough static extent, + enable the future stable masked/gather load plan, or make the logical vector a + full physical chunk. +``` + +Required assignment rule: + +```text +`masked_load` requests contiguous result, passthru, and mask layouts. Padding +is not a layout decision; it is the explicit passthrough operand selected by the +user. +``` + +### 3.31 `f16 -> f32` Feeding Dense Store And S=16 Reduce + +This case proves that the `deinterleaved = 2` layout produced by widening +`f16 -> f32` is not just a store layout. It must also be a legal S=16 grouped +reduction input. Layout assignment must not force the reduce consumer to +`block_elems = 8` and then rematerialize the widened value. + +VMI input: + +```text +%x16 = pto.vmi.load %base[%off] + : memref<128xf16> -> !pto.vmi.vreg<128xf16> +%x32 = pto.vmi.extf %x16 + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> +%mask = pto.vmi.create_mask %c128 : index -> !pto.vmi.mask<128xpred> +%sum = pto.vmi.group_reduce_addf %x32, %mask {num_groups = 8} +pto.vmi.group_store %sum, %sum_out[%group_off], %c1 {num_groups = 8} +pto.vmi.store %x32, %dense_out[%off] +``` + +Assigned layouts: + +```text +%x16: + !pto.vmi.vreg<128xf16, #pto.vmi.layout> + +%x32: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%mask: + !pto.vmi.mask<128xb32, #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> +``` + +VPTO lowering result: + +```text +%all_b16 = pto.pge_b16 "PAT_ALL" +%all_b32 = pto.pge_b32 "PAT_ALL" +%sum_mask = pto.pge_b32 "PAT_VL8" + +%x16_0 = pto.vlds %base[%off] + : memref<128xf16> -> !pto.vreg<128xf16> +%x32_p0 = pto.vcvt %x16_0, %all_b16 {part = "EVEN"} + : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +%x32_p1 = pto.vcvt %x16_0, %all_b16 {part = "ODD"} + : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> + +%s0 = pto.vcgadd %x32_p0, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> 
!pto.vreg<64xf32> +%s1 = pto.vcgadd %x32_p1, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%sum_block = pto.vadd %s0, %s1, %sum_mask + : !pto.vreg<64xf32> + +pto.vsts %sum_block, %sum_out[%group_off], %sum_mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, memref<8xf32>, !pto.mask + +%dense0, %dense1 = pto.vintlv %x32_p0, %x32_p1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +pto.vsts %dense0, %dense_out[%off], %all_b32 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, memref<128xf32>, !pto.mask +pto.vsts %dense1, %dense_out[%off_plus_64], %all_b32 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, memref<128xf32>, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + sum_out[group_off + r] = + reduce(extf(base[off + r * 16 + 0 .. off + r * 16 + 15])) + +for i = 0..127: + dense_out[off + i] = extf(base[off + i]) +``` + +Required assignment rule: + +```text +When S=16 group_reduce consumes an existing `deinterleaved = 2` dense value, +the reduce plan must accept `block_elems = 1`. `block_elems = 8` is only a +producer-driven fast plan for block-fragment loads, not the semantic +requirement of S=16 reduction. +``` + +### 3.32 `f32` Feeding f8 Store And S=32 Reduce + +This is the `f32 -> f8` counterpart to section 3.31. A 256-lane f32 value can +serve both `truncf -> f8` and S=32 group reduction with the same +`deinterleaved = 4, block_elems = 1` layout. The value must not be forced to a +block-fragment `block_elems = 8` layout unless its producer requires that plan. 
+ +VMI input: + +```text +%x32 = pto.vmi.load %base[%off] + : memref<256xf32> -> !pto.vmi.vreg<256xf32> +%mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> +%sum = pto.vmi.group_reduce_addf %x32, %mask {num_groups = 8} +pto.vmi.group_store %sum, %sum_out[%group_off], %c1 {num_groups = 8} +%x8 = pto.vmi.truncf %x32 + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8> +pto.vmi.store %x8, %out8[%off] +``` + +Assigned layouts: + +```text +%x32: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +%mask: + !pto.vmi.mask<256xb32, #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +%x8: + !pto.vmi.vreg<256xf8, #pto.vmi.layout> +``` + +VPTO lowering result: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%sum_mask = pto.pge_b32 "PAT_VL8" + +%x_even_0, %x_odd_0 = pto.vldsx2 %base[%off], "DINTLV_B32" + : memref<256xf32>, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_even_1, %x_odd_1 = pto.vldsx2 %base[%off_plus_128], "DINTLV_B32" + : memref<256xf32>, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%x_p0, %x_p2 = pto.vdintlv %x_even_0, %x_even_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_p1, %x_p3 = pto.vdintlv %x_odd_0, %x_odd_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%s0 = pto.vcgadd %x_p0, %all_b32 : !pto.vreg<64xf32> +%s1 = pto.vcgadd %x_p1, %all_b32 : !pto.vreg<64xf32> +%s2 = pto.vcgadd %x_p2, %all_b32 : !pto.vreg<64xf32> +%s3 = pto.vcgadd %x_p3, %all_b32 : !pto.vreg<64xf32> +%s01 = pto.vadd %s0, %s1, %sum_mask : !pto.vreg<64xf32> +%s23 = pto.vadd %s2, %s3, %sum_mask : !pto.vreg<64xf32> +%sum_block = pto.vadd %s01, %s23, %sum_mask : !pto.vreg<64xf32> + +pto.vsts %sum_block, %sum_out[%group_off], %sum_mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, memref<8xf32>, !pto.mask + +%x8_p0 = pto.vcvt %x_p0, %all_b32 {part = "P0", rnd = "R", sat = "SAT"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8> +%x8_p1 = pto.vcvt %x_p1, %all_b32 {part = "P1", rnd = 
"R", sat = "SAT"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8> +%x8_p2 = pto.vcvt %x_p2, %all_b32 {part = "P2", rnd = "R", sat = "SAT"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8> +%x8_p3 = pto.vcvt %x_p3, %all_b32 {part = "P3", rnd = "R", sat = "SAT"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8> + +%x8_01 = pto.vor %x8_p0, %x8_p1, PAT_ALL_B8 + : !pto.vreg<256xf8> +%x8_23 = pto.vor %x8_p2, %x8_p3, PAT_ALL_B8 + : !pto.vreg<256xf8> +%x8_0 = pto.vor %x8_01, %x8_23, PAT_ALL_B8 + : !pto.vreg<256xf8> + +pto.vsts %x8_0, %out8[%off], PAT_ALL_B8 {dist = "NORM_B8"} + : !pto.vreg<256xf8>, memref<256xf8>, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + sum_out[group_off + r] = + reduce(base[off + r * 32 + 0 .. off + r * 32 + 31]) + +for i = 0..255: + out8[off + i] = truncf(base[off + i]) +``` + +Required assignment rule: + +```text +The common layout selected for `%x32` is +`#pto.vmi.layout<deinterleaved = 4, block_elems = 1>`. This satisfies both +`truncf f32 -> f8` and S=32 `group_reduce_addf`. A later strided block-load +producer may introduce `block_elems = 8`, but that is a different case and +requires an explicit materialization/rematerialization decision. + +When `%x32` is produced by a full contiguous `pto.vmi.load`, `vmi-to-vpto` +should not first materialize four contiguous f32 chunks and then run a full +four-op `vdintlv` tree. The load lowering should fold the first deinterleave +level into two `vldsx2 DINTLV_B32` operations and then run only the second +`vdintlv` level, as shown above. The layout remains just +`deinterleaved = 4, block_elems = 1`; it does not encode the fact that `vldsx2` +was used. +``` + +### 3.33 One Dense Value Feeding S=16 And S=32 Reduces + +This case is a pure layout-assignment conflict. 
The same logical +`256xf32` value is consumed by two legal reductions, but their efficient input +layouts are different: + +```text +S=16 reduce over 16 groups: + #pto.vmi.layout<deinterleaved = 2, block_elems = 1> + +S=32 reduce over 8 groups: + #pto.vmi.layout<deinterleaved = 4, block_elems = 1> +``` + +The program is semantically legal. Baseline layout assignment solves it by +inserting an explicit use-site `ensure_layout`. A later optimization pass may +clone or rematerialize the cheap load for one use. `vmi-to-vpto` must not +inspect both users and choose one locally. + +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<256xf32> -> !pto.vmi.vreg<256xf32> + +%mask16 = pto.vmi.create_group_mask %c16 {num_groups = 16, group_size = 16} + : index -> !pto.vmi.mask<256xpred> +%sum16 = pto.vmi.group_reduce_addf %x, %mask16 {num_groups = 16} +pto.vmi.group_store %sum16, %out16[%group_off16], %c1 {num_groups = 16} + +%mask32 = pto.vmi.create_group_mask %c32 {num_groups = 8, group_size = 32} + : index -> !pto.vmi.mask<256xpred> +%sum32 = pto.vmi.group_reduce_addf %x, %mask32 {num_groups = 8} +pto.vmi.group_store %sum32, %out32[%group_off32], %c1 {num_groups = 8} +``` + +Assigned layouts after rematerializing the load: + +```text +%x_s16: + !pto.vmi.vreg<256xf32, #pto.vmi.layout<deinterleaved = 2, block_elems = 1>> + +%mask16: + !pto.vmi.mask<256xpred, #pto.vmi.layout> + +%sum16: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +%x_s32: + !pto.vmi.vreg<256xf32, #pto.vmi.layout<deinterleaved = 4, block_elems = 1>> + +%mask32: + !pto.vmi.mask<256xpred, #pto.vmi.layout> + +%sum32: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> +``` + +VPTO lowering result: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%sum8_mask = pto.pge_b32 "PAT_VL8" + +// Rematerialized S=16 use. The first vldsx2 covers rows 0..7, the second +// covers rows 8..15. Each pair is deinterleaved by element parity. 
+%s16_p0, %s16_p1 = pto.vldsx2 %base[%tile_off_0], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%s16_p2, %s16_p3 = pto.vldsx2 %base[%tile_off_1], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%s16_0 = pto.vcgadd %s16_p0, %all_b32 : !pto.vreg<64xf32> +%s16_1 = pto.vcgadd %s16_p1, %all_b32 : !pto.vreg<64xf32> +%s16_2 = pto.vcgadd %s16_p2, %all_b32 : !pto.vreg<64xf32> +%s16_3 = pto.vcgadd %s16_p3, %all_b32 : !pto.vreg<64xf32> + +%sum16_lo = pto.vadd %s16_0, %s16_1, %sum8_mask + : !pto.vreg<64xf32> +%sum16_hi = pto.vadd %s16_2, %s16_3, %sum8_mask + : !pto.vreg<64xf32> + +pto.vsts %sum16_lo, %out16[%group_off16], %sum8_mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +pto.vsts %sum16_hi, %out16[%group_off16_plus_8], %sum8_mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + +// Rematerialized S=32 use. Two DINTLV loads plus one register deinterleave +// level produce mod-4 columns for rows 0..7. +%x_even_0, %x_odd_0 = pto.vldsx2 %base[%tile_off_0], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_even_1, %x_odd_1 = pto.vldsx2 %base[%tile_off_1], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%x_p0, %x_p2 = pto.vdintlv %x_even_0, %x_even_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_p1, %x_p3 = pto.vdintlv %x_odd_0, %x_odd_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%s32_0 = pto.vcgadd %x_p0, %all_b32 : !pto.vreg<64xf32> +%s32_1 = pto.vcgadd %x_p1, %all_b32 : !pto.vreg<64xf32> +%s32_2 = pto.vcgadd %x_p2, %all_b32 : !pto.vreg<64xf32> +%s32_3 = pto.vcgadd %x_p3, %all_b32 : !pto.vreg<64xf32> + +%s32_01 = pto.vadd %s32_0, %s32_1, %sum8_mask : !pto.vreg<64xf32> +%s32_23 = pto.vadd %s32_2, %s32_3, %sum8_mask : !pto.vreg<64xf32> +%sum32_block = pto.vadd %s32_01, %s32_23, %sum8_mask : !pto.vreg<64xf32> + +pto.vsts %sum32_block, 
%out32[%group_off32], %sum8_mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..15: + out16[group_off16 + r] = + reduce(base[off + r * 16 + 0 .. off + r * 16 + 15]) + +for r = 0..7: + out32[group_off32 + r] = + reduce(base[off + r * 32 + 0 .. off + r * 32 + 31]) +``` + +Required assignment rule: + +```text +Baseline assignment inserts `ensure_layout` at the mismatched use. A later +rematerialization pass may clone a cheap producer such as load and assign each +clone independently. If no deinterleaved=2 <-> deinterleaved=4 materialization +support exists, emit a layout-contract diagnostic naming both consumers and +both required layouts. +``` + +### 3.34 S=64 Group-Slot Result `f32 -> f16` Cast + +Section 3.13 rejects direct width-changing cast for packed `slots = 8` +group-slot values. This case is the positive counterpart for row-local +`slots = 1`: each group result is already lane 0 of its own physical vreg, so a +slot-preserving cast can lower one row-local result at a time. + +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<512xf32> -> !pto.vmi.vreg<512xf32> +%mask = pto.vmi.create_group_mask %c64 {num_groups = 8, group_size = 64} + : index -> !pto.vmi.mask<512xpred> +%sum32 = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +%sum16 = pto.vmi.truncf %sum32 + : !pto.vmi.vreg<512xf32> -> !pto.vmi.vreg<512xf16> +pto.vmi.group_store %sum16, %out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%x: + !pto.vmi.vreg<512xf32, #pto.vmi.layout> + +%sum32: + !pto.vmi.vreg<512xf32, #pto.vmi.layout> + +%sum16: + !pto.vmi.vreg<512xf16, #pto.vmi.layout> +``` + +VPTO lowering result for one full 8-row tile: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%block8 = pto.pge_b32 "PAT_VL8" +%one_b32 = pto.pge_b32 "PAT_VL1" +%one_b16 = pto.pge_b16 "PAT_VL1" + +// The compiler emits this row-local sequence for r = 0..7. 
+%x_r = pto.vlds %base[%row_off_r] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<64xf32> +%p_r = pto.vcgadd %x_r, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%sum32_r = pto.vcadd %p_r, %block8 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Only lane 0 is semantic. EVEN keeps f32 lane 0 in f16 lane 0; all other +// lanes are non-semantic for group_slots(num_groups=8, slots=1). +%sum16_r = pto.vcvt %sum32_r, %one_b32 {part = "EVEN", rnd = "R", sat = "SAT"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> + +pto.vsts %sum16_r, %out[%group_tile_off_r], %one_b16 {dist = "NORM_B16"} + : !pto.vreg<128xf16>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + out[group_off + r] = + truncf(reduce(base[off + r * 64 + 0 .. off + r * 64 + 63])) +``` + +Required assignment rule: + +```text +Group-slot casts are layout-specific. `slots = 1` may use a slot-preserving +row-local cast because each semantic scalar is lane 0 of its own physical vreg. +This does not legalize packed `slots = 8` casts from section 3.13. +``` + +### 3.35 `group_slots` Fanout To `group_store` And `group_broadcast` + +This case fixes the fanout rule for group-slot values. A `group_slots` value may +feed multiple group-aware consumers directly. Layout assignment must not +materialize it as dense just because one later use broadcasts it. 
+ +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<128xf32> -> !pto.vmi.vreg<128xf32> +%mask = pto.vmi.create_group_mask %c16 {num_groups = 8, group_size = 16} + : index -> !pto.vmi.mask<128xpred> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +pto.vmi.group_store %sum, %sum_out[%group_off], %c1 {num_groups = 8} + +%b = pto.vmi.group_broadcast %sum {num_groups = 8} +%y = pto.vmi.mulf %x, %b +%ysum = pto.vmi.group_reduce_addf %y, %mask {num_groups = 8} +pto.vmi.group_store %ysum, %out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%x: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%x_for_reduce: + !pto.vmi.vreg<128xf32, + #pto.vmi.layout> + +%mask_for_reduce: + !pto.vmi.mask<128xb32, + #pto.vmi.layout> + +%sum, %ysum: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%b, %y: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%y_for_reduce: + !pto.vmi.vreg<128xf32, + #pto.vmi.layout> +``` + +VPTO lowering result for one full 8-row tile: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%slot8 = pto.pge_b32 "PAT_VL8" + +%x0 = pto.vlds %base[%tile_off] + : !pto.ptr, index -> !pto.vreg<64xf32> +%x1 = pto.vlds %base[%tile_off_plus_64] + : !pto.ptr, index -> !pto.vreg<64xf32> + +// ensure_layout for the first group_reduce. +%x_lo, %x_hi = pto.vdintlv %x0, %x1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%lo_sum = pto.vcgadd %x_lo, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%hi_sum = pto.vcgadd %x_hi, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%sum_block = pto.vadd %lo_sum, %hi_sum, %slot8 : !pto.vreg<64xf32> + +// First group-slot consumer: store the group slots without changing layout. +pto.vsts %sum_block, %sum_out[%group_off], %slot8 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + +// Second group-slot consumer: materialize only this use as dense grouped data. 
+%broadcast_idx0 = compute index vector [0 repeated 16, 1 repeated 16, + 2 repeated 16, 3 repeated 16] + : !pto.vreg<64xi32> +%broadcast_idx1 = compute index vector [4 repeated 16, 5 repeated 16, + 6 repeated 16, 7 repeated 16] + : !pto.vreg<64xi32> +%b0 = pto.vselr %sum_block, %broadcast_idx0 + : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> +%b1 = pto.vselr %sum_block, %broadcast_idx1 + : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> + +%y0 = pto.vmul %x0, %b0, %all_b32 : !pto.vreg<64xf32> +%y1 = pto.vmul %x1, %b1, %all_b32 : !pto.vreg<64xf32> + +// ensure_layout for the second group_reduce. +%y_lo, %y_hi = pto.vdintlv %y0, %y1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%y_lo_sum = pto.vcgadd %y_lo, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%y_hi_sum = pto.vcgadd %y_hi, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%ysum_block = pto.vadd %y_lo_sum, %y_hi_sum, %slot8 : !pto.vreg<64xf32> + +pto.vsts %ysum_block, %out[%group_off], %slot8 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + s = reduce(row_r[0..15]) + sum_out[group_off + r] = s + out[group_off + r] = reduce_i(row_r[i] * s for i = 0..15) +``` + +Required assignment rule: + +```text +`%sum` keeps one assigned layout: + #pto.vmi.layout + +`group_store` consumes that group-slot layout directly. +`group_broadcast` is a use-site materialization to a dense layout. It must not +rewrite the defining `group_reduce` result or the sibling `group_store` use. +``` + +### 3.36 Same Scalar Source Materialized As `slots = 8` And `slots = 1` + +The same memory scalar stream may be used by both packed S=16 group-slot +compute and row-local S=64 group-slot compute. The two uses require different +logical vector shapes and different group-slot layouts, so the source must be +rematerialized as two VMI values. 
There is no single `group_slots` layout that +serves both uses. + +VMI input: + +```text +%rhs16 = pto.vmi.group_slot_load %rhs_base[%rhs_off], %c1 {num_groups = 8} + : !pto.ptr, index -> !pto.vmi.vreg<128xf32> +%x16 = pto.vmi.load %base16[%off16] + : memref<128xf32> -> !pto.vmi.vreg<128xf32> +%sum16 = pto.vmi.group_reduce_addf %x16, %mask16 {num_groups = 8} +%out16v = pto.vmi.addf %sum16, %rhs16 +pto.vmi.group_store %out16v, %out16[%group_off16], %c1 {num_groups = 8} + +%rhs64 = pto.vmi.group_slot_load %rhs_base[%rhs_off], %c1 {num_groups = 8} + : !pto.ptr, index -> !pto.vmi.vreg<512xf32> +%x64 = pto.vmi.load %base64[%off64] + : memref<512xf32> -> !pto.vmi.vreg<512xf32> +%sum64 = pto.vmi.group_reduce_addf %x64, %mask64 {num_groups = 8} +%out64v = pto.vmi.addf %sum64, %rhs64 +pto.vmi.group_store %out64v, %out64[%group_off64], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%rhs16, %sum16, %out16v: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%x16, %mask16: + #pto.vmi.layout + +%rhs64, %sum64, %out64v: + !pto.vmi.vreg<512xf32, #pto.vmi.layout> + +%x64, %mask64: + #pto.vmi.layout +``` + +VPTO lowering result: + +```text +// Packed S=16 RHS: one 32B scalar block in lanes 0..7. +%slot8 = pto.pge_b32 "PAT_VL8" +%one_b32 = pto.pge_b32 "PAT_VL1" +%rhs16_block = pto.vsldb %rhs_base[%rhs_off], %c0_i16, %c0_i16, %one_b32 + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> + +// S=16 reduction is the section 3.5.1 shape. 
+%x16_lo, %x16_hi = pto.vldsx2 %base16[%tile_off16], "BDINTLV" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%s16_lo = pto.vcgadd %x16_lo, PAT_ALL_B32 : !pto.vreg<64xf32> +%s16_hi = pto.vcgadd %x16_hi, PAT_ALL_B32 : !pto.vreg<64xf32> +%sum16_block = pto.vadd %s16_lo, %s16_hi, %slot8 : !pto.vreg<64xf32> +%out16_block = pto.vadd %sum16_block, %rhs16_block, %slot8 + : !pto.vreg<64xf32> +pto.vsts %out16_block, %out16[%group_off16], %slot8 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + +// Row-local S=64 RHS: a separate group_slot_load op produces one lane-0 +// value per physical row-local result. +%rhs64_r = pto.vsldb %rhs_base[%rhs_off_plus_r], %c0_i16, %c0_i16, %one_b32 + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> + +// Emit this row-local reduction/add/store shape for r = 0..7. +%x64_r = pto.vlds %base64[%row_off64_r] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<64xf32> +%p64_r = pto.vcgadd %x64_r, PAT_ALL_B32 : !pto.vreg<64xf32> +%sum64_r = pto.vcadd %p64_r, PAT_VL8_B32 : !pto.vreg<64xf32> +%out64_r = pto.vadd %sum64_r, %rhs64_r, %one_b32 : !pto.vreg<64xf32> +pto.vsts %out64_r, %out64[%group_off64_plus_r], %one_b32 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + out16[group_off16 + r] = reduce(base16[row_r, 0..15]) + rhs_base[rhs_off + r] + out64[group_off64 + r] = reduce(base64[row_r, 0..63]) + rhs_base[rhs_off + r] +``` + +Required assignment rule: + +```text +`group_slot_load` is a memory op, so the baseline rematerialization pass must +not clone it as a generic cheap producer. If two use sites need different +`group_slots` layouts, the legal first-stage shape is to write two explicit +`group_slot_load` ops, as above, or to introduce a future load-cloning +optimization with an explicit memory-safety proof. Do not invent a common +layout or make `vmi-to-vpto` inspect both users. 
+``` + +### 3.37 S=64 `group_store` With Non-Unit Output Stride + +Packed `slots = 8` stores currently require unit output stride. Row-local +`slots = 1` does not have that restriction because each group scalar is stored +by a separate lane-0 store. + +VMI input: + +```text +%row_stride = arith.index_cast %ld : i64 to index +%x = pto.vmi.load %base[%off] + : memref<512xf32> -> !pto.vmi.vreg<512xf32> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +pto.vmi.group_store %sum, %out[%group_off], %row_stride {num_groups = 8} +``` + +Assigned layouts: + +```text +%x: + !pto.vmi.vreg<512xf32, #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<512xf32, #pto.vmi.layout> +``` + +VPTO lowering result: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%block8 = pto.pge_b32 "PAT_VL8" +%one_b32 = pto.pge_b32 "PAT_VL1" + +// Emit this row-local sequence for r = 0..7. +%x_r = pto.vlds %base[%row_off_r] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<64xf32> +%p_r = pto.vcgadd %x_r, %all_b32 : !pto.vreg<64xf32> +%sum_r = pto.vcadd %p_r, %block8 : !pto.vreg<64xf32> + +%dst_r = %out + %group_off + r * %row_stride +pto.vsts %sum_r, %dst_r, %one_b32 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + out[group_off + r * row_stride] = reduce(row_r[0..63]) +``` + +Required assignment rule: + +```text +If `group_store` has non-unit row_stride and the source can legally use +`slots = 1`, assignment may select `slots = 1` to keep the store legal. If the +source is fixed to `slots = 8`, current target support must diagnose unless a +strided packed store materializer exists. +``` + +### 3.38 Multi-Tile S=32 `group_reduce` + +The S=32 plan is not only a one-tile special case. For more than eight groups, +layout assignment keeps the same layout and `vmi-to-vpto` emits the same +8-row tile lowering sequence for each physical tile. 
+ +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<512xf32> -> !pto.vmi.vreg<512xf32> +%mask = pto.vmi.create_group_mask %c32 {num_groups = 16, group_size = 32} + : index -> !pto.vmi.mask<512xpred> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 16} +pto.vmi.group_store %sum, %out[%group_off], %c1 {num_groups = 16} +``` + +Assigned layouts: + +```text +%x, %mask: + !pto.vmi.vreg<512xf32, + #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<512xf32, #pto.vmi.layout> +``` + +VPTO lowering result: + +```text +// Emit this shape for tile t = 0 and tile t = 1. +// Each tile covers eight 32-f32 rows. +%tile_base_t = %base + %off + t * 256 +%tile_out_t = %out + %group_off + t * 8 + +%x_even_0_t, %x_odd_0_t = pto.vldsx2 %tile_base_t[%c0], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_even_1_t, %x_odd_1_t = pto.vldsx2 %tile_base_t[%c128], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%x_p0_t, %x_p2_t = pto.vdintlv %x_even_0_t, %x_even_1_t + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_p1_t, %x_p3_t = pto.vdintlv %x_odd_0_t, %x_odd_1_t + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%s0_t = pto.vcgadd %x_p0_t, PAT_ALL_B32 : !pto.vreg<64xf32> +%s1_t = pto.vcgadd %x_p1_t, PAT_ALL_B32 : !pto.vreg<64xf32> +%s2_t = pto.vcgadd %x_p2_t, PAT_ALL_B32 : !pto.vreg<64xf32> +%s3_t = pto.vcgadd %x_p3_t, PAT_ALL_B32 : !pto.vreg<64xf32> +%s01_t = pto.vadd %s0_t, %s1_t, PAT_VL8_B32 : !pto.vreg<64xf32> +%s23_t = pto.vadd %s2_t, %s3_t, PAT_VL8_B32 : !pto.vreg<64xf32> +%sum_block_t = pto.vadd %s01_t, %s23_t, PAT_VL8_B32 + : !pto.vreg<64xf32> + +pto.vsts %sum_block_t, %tile_out_t, PAT_VL8_B32 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..15: + out[group_off + r] = + reduce(base[off + r * 32 + 0 .. 
off + r * 32 + 31]) +``` + +Required assignment rule: + +```text +For `group_slots(num_groups = 16, slots = 8)`, the physical arity is +`num_groups / slots = 2`. The type conversion must expose two packed result +blocks in group order. `group_store` stores both blocks with offsets +`group_off + 0` and `group_off + 8`. +``` + +### 3.39 Strided S=32 `group_load` Through Broadcast And Second Reduce + +Section 3.27 covers strided S=32 `group_load -> group_reduce -> group_store`. +This case adds the missing dense continuation. The important layout fact is +that a strided block load naturally produces +`deinterleaved = 4, block_elems = 8`; `group_broadcast` must materialize the +broadcast into that same block-fragment layout when the broadcast feeds +elementwise compute and another S=32 group reduction. + +VMI input: + +```text +%stride40 = arith.constant 40 : index +%x = pto.vmi.group_load %base[%off], %stride40 + {num_groups = 8, group_size = 32} + : !pto.ptr, index -> !pto.vmi.vreg<256xf32> +%mask = pto.vmi.create_group_mask %c32 {num_groups = 8, group_size = 32} + : index -> !pto.vmi.mask<256xpred> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +%b = pto.vmi.group_broadcast %sum {num_groups = 8} +%y = pto.vmi.mulf %x, %b +%ysum = pto.vmi.group_reduce_addf %y, %mask {num_groups = 8} +pto.vmi.group_store %ysum, %out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%x, %mask, %b, %y: + !pto.vmi.vreg<256xf32, + #pto.vmi.layout> + +%sum, %ysum: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> +``` + +VPTO lowering result: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%slot8 = pto.pge_b32 "PAT_VL8" +%stride_blocks = %c5_i16 // 40 f32 = 5 * 32B blocks. 
+ +%x_p0 = pto.vsldb %base_frag0, %stride_blocks, %c0_i16, %all_b32 + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> +%x_p1 = pto.vsldb %base_frag1, %stride_blocks, %c0_i16, %all_b32 + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> +%x_p2 = pto.vsldb %base_frag2, %stride_blocks, %c0_i16, %all_b32 + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> +%x_p3 = pto.vsldb %base_frag3, %stride_blocks, %c0_i16, %all_b32 + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> + +%s0 = pto.vcgadd %x_p0, %all_b32 : !pto.vreg<64xf32> +%s1 = pto.vcgadd %x_p1, %all_b32 : !pto.vreg<64xf32> +%s2 = pto.vcgadd %x_p2, %all_b32 : !pto.vreg<64xf32> +%s3 = pto.vcgadd %x_p3, %all_b32 : !pto.vreg<64xf32> +%s01 = pto.vadd %s0, %s1, %slot8 : !pto.vreg<64xf32> +%s23 = pto.vadd %s2, %s3, %slot8 : !pto.vreg<64xf32> +%sum_block = pto.vadd %s01, %s23, %slot8 : !pto.vreg<64xf32> + +%lane_id = pto.vci %c0_i32 : i32 -> !pto.vreg<64xi32> +%broadcast_idx = pto.vshrs %lane_id, %c3_i16, %all_b32 + : !pto.vreg<64xi32>, i16, !pto.mask -> !pto.vreg<64xi32> + +// Materialize the same per-row scalar into every 32B row fragment. The four +// bundle entries have the same lane contents, but the result layout remains +// deinterleaved=4, block_elems=8 because the consumer `%y = mulf %x, %b` +// operates on the block-fragment layout. 
+%b_p0 = pto.vselr %sum_block, %broadcast_idx + : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> +%b_p1 = pto.vselr %sum_block, %broadcast_idx + : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> +%b_p2 = pto.vselr %sum_block, %broadcast_idx + : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> +%b_p3 = pto.vselr %sum_block, %broadcast_idx + : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> + +%y_p0 = pto.vmul %x_p0, %b_p0, %all_b32 : !pto.vreg<64xf32> +%y_p1 = pto.vmul %x_p1, %b_p1, %all_b32 : !pto.vreg<64xf32> +%y_p2 = pto.vmul %x_p2, %b_p2, %all_b32 : !pto.vreg<64xf32> +%y_p3 = pto.vmul %x_p3, %b_p3, %all_b32 : !pto.vreg<64xf32> + +%ys0 = pto.vcgadd %y_p0, %all_b32 : !pto.vreg<64xf32> +%ys1 = pto.vcgadd %y_p1, %all_b32 : !pto.vreg<64xf32> +%ys2 = pto.vcgadd %y_p2, %all_b32 : !pto.vreg<64xf32> +%ys3 = pto.vcgadd %y_p3, %all_b32 : !pto.vreg<64xf32> +%ys01 = pto.vadd %ys0, %ys1, %slot8 : !pto.vreg<64xf32> +%ys23 = pto.vadd %ys2, %ys3, %slot8 : !pto.vreg<64xf32> +%ysum_block = pto.vadd %ys01, %ys23, %slot8 : !pto.vreg<64xf32> + +pto.vsts %ysum_block, %out[%group_off], %slot8 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + s = reduce(base[off + r * 40 + 0 .. off + r * 40 + 31]) + out[group_off + r] = + reduce_i(base[off + r * 40 + i] * s for i = 0..31) +``` + +Required assignment rule: + +```text +`block_elems` is part of dense layout compatibility. A broadcast result feeding +an elementwise op with `%x : deinterleaved=4, block_elems=8` must also be +assigned `deinterleaved=4, block_elems=8`. Reusing a +`deinterleaved=4, block_elems=1` broadcast would be a layout mismatch even +though both have four physical parts. +``` + +### 3.40 Scalar Broadcast Feeding Dense And Grouped Users + +This case fixes the rule for ordinary scalar broadcasts. A scalar broadcast is +not born with a physical layout. 
Baseline layout assignment assigns the +transfer-equivalent producer chain to the non-contiguous layout requested by the +grouped consumer and inserts an explicit materialization at the dense store use. +The later `vmi-layout-rematerialize` pass may replace that helper with a cloned +broadcast when profitable. + +VMI input: + +```text +%scale = pto.vmi.broadcast %scale_s + : f32 -> !pto.vmi.vreg<256xf32> +%x = pto.vmi.load %base[%off] + : memref<256xf32> -> !pto.vmi.vreg<256xf32> + +%copy = pto.vmi.addf %x, %scale +pto.vmi.store %copy, %copy_out[%off] + +%mask = pto.vmi.create_group_mask %c32 {num_groups = 8, group_size = 32} + : index -> !pto.vmi.mask<256xpred> +%prod = pto.vmi.mulf %x, %scale +%sum = pto.vmi.group_reduce_addf %prod, %mask {num_groups = 8} +pto.vmi.group_store %sum, %sum_out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%x, %scale, %copy, %prod: + !pto.vmi.vreg<256xf32, + #pto.vmi.layout> + +%copy_dense = pto.vmi.ensure_layout %copy: + #pto.vmi.layout + -> #pto.vmi.layout + +%mask: + !pto.vmi.mask<256xpred, + #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> +``` + +VPTO lowering result: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%slot8 = pto.pge_b32 "PAT_VL8" + +// The shared load is assigned deinterleaved=4, block_elems=8 because the +// grouped consumer dominates the useful compute layout. 
+%x0 = pto.vlds %base[%off] : !pto.ptr, index -> !pto.vreg<64xf32> +%x1 = pto.vlds %base[%off_plus_64] : !pto.ptr, index -> !pto.vreg<64xf32> +%x2 = pto.vlds %base[%off_plus_128] : !pto.ptr, index -> !pto.vreg<64xf32> +%x3 = pto.vlds %base[%off_plus_192] : !pto.ptr, index -> !pto.vreg<64xf32> + +%x01_lo, %x01_hi = pto.vdintlv %x0, %x1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x23_lo, %x23_hi = pto.vdintlv %x2, %x3 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_p0, %x_p2 = pto.vdintlv %x01_lo, %x23_lo + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_p1, %x_p3 = pto.vdintlv %x01_hi, %x23_hi + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%scale_p0 = pto.vdup %scale_s, %all_b32 : f32, !pto.mask -> !pto.vreg<64xf32> +%scale_p1 = pto.vdup %scale_s, %all_b32 : f32, !pto.mask -> !pto.vreg<64xf32> +%scale_p2 = pto.vdup %scale_s, %all_b32 : f32, !pto.mask -> !pto.vreg<64xf32> +%scale_p3 = pto.vdup %scale_s, %all_b32 : f32, !pto.mask -> !pto.vreg<64xf32> + +// Dense store use: compute in deinterleaved=4, then ensure_layout materializes +// the contiguous memory order for the external effect. 
+%copy_p0 = pto.vadd %x_p0, %scale_p0, %all_b32 : !pto.vreg<64xf32> +%copy_p1 = pto.vadd %x_p1, %scale_p1, %all_b32 : !pto.vreg<64xf32> +%copy_p2 = pto.vadd %x_p2, %scale_p2, %all_b32 : !pto.vreg<64xf32> +%copy_p3 = pto.vadd %x_p3, %scale_p3, %all_b32 : !pto.vreg<64xf32> + +%c01_lo, %c01_hi = pto.vintlv %copy_p0, %copy_p2 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%c23_lo, %c23_hi = pto.vintlv %copy_p1, %copy_p3 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%copy0, %copy1 = pto.vintlv %c01_lo, %c23_lo + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%copy2, %copy3 = pto.vintlv %c01_hi, %c23_hi + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +pto.vsts %copy0, %copy_out[%off], %all_b32 {dist = "NORM_B32"} +pto.vsts %copy1, %copy_out[%off_plus_64], %all_b32 {dist = "NORM_B32"} +pto.vsts %copy2, %copy_out[%off_plus_128], %all_b32 {dist = "NORM_B32"} +pto.vsts %copy3, %copy_out[%off_plus_192], %all_b32 {dist = "NORM_B32"} + +// Grouped use: reuse the same deinterleaved operands directly. 
+%prod_p0 = pto.vmul %x_p0, %scale_p0, %all_b32 : !pto.vreg<64xf32> +%prod_p1 = pto.vmul %x_p1, %scale_p1, %all_b32 : !pto.vreg<64xf32> +%prod_p2 = pto.vmul %x_p2, %scale_p2, %all_b32 : !pto.vreg<64xf32> +%prod_p3 = pto.vmul %x_p3, %scale_p3, %all_b32 : !pto.vreg<64xf32> + +%s0 = pto.vcgadd %prod_p0, %all_b32 : !pto.vreg<64xf32> +%s1 = pto.vcgadd %prod_p1, %all_b32 : !pto.vreg<64xf32> +%s2 = pto.vcgadd %prod_p2, %all_b32 : !pto.vreg<64xf32> +%s3 = pto.vcgadd %prod_p3, %all_b32 : !pto.vreg<64xf32> +%s01 = pto.vadd %s0, %s1, %slot8 : !pto.vreg<64xf32> +%s23 = pto.vadd %s2, %s3, %slot8 : !pto.vreg<64xf32> +%sum_block = pto.vadd %s01, %s23, %slot8 : !pto.vreg<64xf32> + +pto.vsts %sum_block, %sum_out[%group_off], %slot8 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for i = 0..255: + copy_out[off + i] = base[off + i] + scale_s + +for r = 0..7: + sum_out[group_off + r] = + reduce_i(base[off + r * 32 + i] * scale_s for i = 0..31) +``` + +Required assignment rule: + +```text +`broadcast` is layout-transparent and cheaply rematerializable by the optional +`vmi-layout-rematerialize` pass, but baseline assignment does not have to force +a separate contiguous broadcast just because a dense store exists. It may +choose a common deinterleaved compute layout for transfer-equivalent elementwise +ops and insert `ensure_layout` at the dense store. The required invariant is +that this choice is explicit in the assigned IR; `vmi-to-vpto` must not infer it +by inspecting both users. +``` + +### 3.41 Non-Rematerializable Value With Incompatible Users + +This is the non-cheap counterpart to section 3.18. A `masked_load` has explicit +mask and passthrough semantics, so layout assignment should not clone it as a +normal cheap load unless a dedicated rematerialization rule proves that clone +legal. The conflict is solved by inserting `ensure_layout` at one use site. 
+ +VMI input: + +```text +%mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> +%zero = pto.vmi.broadcast %c0_f32 : f32 -> !pto.vmi.vreg<256xf32> +%x = pto.vmi.masked_load %base[%off], %mask, %zero + : memref<256xf32>, !pto.vmi.mask<256xpred>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + +pto.vmi.store %x, %copy_out[%off] + +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +pto.vmi.group_store %sum, %sum_out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%x, %zero for masked_load/store: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +%mask for masked_load/store: + !pto.vmi.mask<256xpred, #pto.vmi.layout> + +%x_for_reduce = pto.vmi.ensure_layout %x + : #pto.vmi.layout + -> #pto.vmi.layout + +%mask_for_reduce = pto.vmi.ensure_mask_layout %mask + : #pto.vmi.layout + -> #pto.vmi.layout + +%sum: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> +``` + +VPTO lowering result: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%slot8 = pto.pge_b32 "PAT_VL8" + +%zero0 = pto.vdup %c0_f32, %all_b32 : f32, !pto.mask -> !pto.vreg<64xf32> +%zero1 = pto.vdup %c0_f32, %all_b32 : f32, !pto.mask -> !pto.vreg<64xf32> +%zero2 = pto.vdup %c0_f32, %all_b32 : f32, !pto.mask -> !pto.vreg<64xf32> +%zero3 = pto.vdup %c0_f32, %all_b32 : f32, !pto.mask -> !pto.vreg<64xf32> + +%l0 = pto.vlds %base[%off] : !pto.ptr, index -> !pto.vreg<64xf32> +%l1 = pto.vlds %base[%off_plus_64] : !pto.ptr, index -> !pto.vreg<64xf32> +%l2 = pto.vlds %base[%off_plus_128] : !pto.ptr, index -> !pto.vreg<64xf32> +%l3 = pto.vlds %base[%off_plus_192] : !pto.ptr, index -> !pto.vreg<64xf32> + +%x0 = pto.vsel %l0, %zero0, %all_b32 : !pto.vreg<64xf32> +%x1 = pto.vsel %l1, %zero1, %all_b32 : !pto.vreg<64xf32> +%x2 = pto.vsel %l2, %zero2, %all_b32 : !pto.vreg<64xf32> +%x3 = pto.vsel %l3, %zero3, %all_b32 : !pto.vreg<64xf32> + +pto.vsts %x0, %copy_out[%off], %all_b32 {dist = "NORM_B32"} +pto.vsts %x1, %copy_out[%off_plus_64], %all_b32 {dist = "NORM_B32"} +pto.vsts %x2, 
%copy_out[%off_plus_128], %all_b32 {dist = "NORM_B32"} +pto.vsts %x3, %copy_out[%off_plus_192], %all_b32 {dist = "NORM_B32"} + +// ensure_layout contiguous -> deinterleaved=4 at the reduce use. +%x01_lo, %x01_hi = pto.vdintlv %x0, %x1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x23_lo, %x23_hi = pto.vdintlv %x2, %x3 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_p0, %x_p2 = pto.vdintlv %x01_lo, %x23_lo + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_p1, %x_p3 = pto.vdintlv %x01_hi, %x23_hi + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%s0 = pto.vcgadd %x_p0, %all_b32 : !pto.vreg<64xf32> +%s1 = pto.vcgadd %x_p1, %all_b32 : !pto.vreg<64xf32> +%s2 = pto.vcgadd %x_p2, %all_b32 : !pto.vreg<64xf32> +%s3 = pto.vcgadd %x_p3, %all_b32 : !pto.vreg<64xf32> +%s01 = pto.vadd %s0, %s1, %slot8 : !pto.vreg<64xf32> +%s23 = pto.vadd %s2, %s3, %slot8 : !pto.vreg<64xf32> +%sum_block = pto.vadd %s01, %s23, %slot8 : !pto.vreg<64xf32> + +pto.vsts %sum_block, %sum_out[%group_off], %slot8 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for i = 0..255: + copy_out[off + i] = base[off + i] + +for r = 0..7: + sum_out[group_off + r] = + reduce(base[off + r * 32 + 0 .. off + r * 32 + 31]) +``` + +Required assignment rule: + +```text +For non-rematerializable producers, assignment must insert an explicit use-site +materialization helper, such as contiguous -> deinterleaved=4. If that helper +has no supported materialization, the layout gate must diagnose before +vmi-to-vpto. `vmi-to-vpto` must not clone the masked_load or choose a +materialization after seeing both users. +``` + +### 3.42 `group_slots` `scf.for` Loop-Carried Accumulator + +Section 3.22 covers dense loop-carried values. 
Group-slot values need a +separate case because the loop-carried block argument has no dense lane +semantics outside the live group slots. + +VMI input: + +```text +%acc0 = pto.vmi.group_slot_load %init[%group_off], %c1 {num_groups = 8} + : !pto.ptr, index -> !pto.vmi.vreg<128xf32> + +%acc = scf.for %k = %c0 to %steps step %c1 + iter_args(%arg = %acc0) -> !pto.vmi.vreg<128xf32> { + %x = pto.vmi.group_load %base[%tile_off_k], %c16 + {num_groups = 8, group_size = 16} + : !pto.ptr, index -> !pto.vmi.vreg<128xf32> + %mask = pto.vmi.create_group_mask %c16 {num_groups = 8, group_size = 16} + : index -> !pto.vmi.mask<128xpred> + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} + %next = pto.vmi.addf %arg, %sum + scf.yield %next : !pto.vmi.vreg<128xf32> +} + +pto.vmi.group_store %acc, %out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%acc0, %arg, %sum, %next, %acc: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%x: + !pto.vmi.vreg<128xf32, + #pto.vmi.layout> + +%mask: + !pto.vmi.mask<128xpred, + #pto.vmi.layout> +``` + +VPTO lowering result: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%slot8 = pto.pge_b32 "PAT_VL8" +%one_b32 = pto.pge_b32 "PAT_VL1" + +%acc0_block = pto.vsldb %init[%group_off], %c0_i16, %c0_i16, %one_b32 + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> + +%acc_block = scf.for %k = %c0 to %steps step %c1 + iter_args(%arg_block = %acc0_block) -> !pto.vreg<64xf32> { + %lo, %hi = pto.vldsx2 %base[%tile_off_k], "BDINTLV" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %lo_sum = pto.vcgadd %lo, %all_b32 : !pto.vreg<64xf32> + %hi_sum = pto.vcgadd %hi, %all_b32 : !pto.vreg<64xf32> + %sum_block = pto.vadd %lo_sum, %hi_sum, %slot8 : !pto.vreg<64xf32> + %next_block = pto.vadd %arg_block, %sum_block, %slot8 : !pto.vreg<64xf32> + scf.yield %next_block : !pto.vreg<64xf32> +} + +pto.vsts %acc_block, %out[%group_off], %slot8 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + 
+```text +for r = 0..7: + out[group_off + r] = + init[group_off + r] + + sum_k reduce(base[tile_k, row_r, 0..15]) +``` + +Required assignment rule: + +```text +Loop-carried `group_slots` values are valid. The iter_arg, body block +argument, yield operand, loop result, and final group_store operand all carry +the same `group_slots(num_groups=8, slots=8)` layout. Ordinary dense consumers +inside the loop still require an explicit `group_broadcast` or diagnostic. +``` + +### 3.43 Internal Function Argument Boundary Materialization + +Section 3.25 covers a private function returning a VMI value. A callee argument +is the other direction of the same ABI problem: the callee body may require a +layout that is different from the layout naturally produced at a call site. + +The current implementation keeps the internal function VMI signature +contiguous and makes the callee-entry materialization explicit with +`ensure_layout` / `ensure_mask_layout`. This is less aggressive than +specializing the VMI function signature to `deinterleaved = 4`, but it preserves +the same invariant: after layout assignment, `vmi-to-vpto` lowers only from +explicit type and helper information and does not inspect the callee body while +lowering a call. 
+ +VMI input: + +```text +func.func private @consume(%x: !pto.vmi.vreg<256xf32>, + %mask: !pto.vmi.mask<256xpred>, + %out: !pto.ptr, %group_off: index) { + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} + pto.vmi.group_store %sum, %out[%group_off], %c1 {num_groups = 8} + return +} + +func.func @caller(%base: !pto.ptr, %off: index, + %out: !pto.ptr, %group_off: index) { + %x = pto.vmi.load %base[%off] + : !pto.ptr, index -> !pto.vmi.vreg<256xf32> + %mask = pto.vmi.create_group_mask %c32 {num_groups = 8, group_size = 32} + : index -> !pto.vmi.mask<256xpred> + call @consume(%x, %mask, %out, %group_off) + : (!pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred>, + !pto.ptr, index) -> () + return +} +``` + +Assigned layouts: + +```text +@consume argument %x: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +@consume argument %mask: + !pto.vmi.mask<256xpred, #pto.vmi.layout> + +inside @consume: + %x_split = pto.vmi.ensure_layout %x + : #pto.vmi.layout + -> #pto.vmi.layout + + %mask_split = pto.vmi.ensure_mask_layout %mask + : #pto.vmi.layout + -> #pto.vmi.layout + +@caller %x and %mask: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> + !pto.vmi.mask<256xpred, #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> +``` + +VPTO lowering result for the function boundary: + +```text +func.func private @consume(%x_p0: !pto.vreg<64xf32>, + %x_p1: !pto.vreg<64xf32>, + %x_p2: !pto.vreg<64xf32>, + %x_p3: !pto.vreg<64xf32>, + %m0: !pto.mask, + %m1: !pto.mask, + %m2: !pto.mask, + %m3: !pto.mask, + %out: !pto.ptr, + %group_off: index) { + // Callee-entry lowering of ensure_layout contiguous -> deinterleaved=4, + // block_elems=8. 
+ %x01_lo, %x01_hi = pto.vdintlv %x_p0, %x_p1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %x23_lo, %x23_hi = pto.vdintlv %x_p2, %x_p3 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %x_d0, %x_d2 = pto.vdintlv %x01_lo, %x23_lo + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %x_d1, %x_d3 = pto.vdintlv %x01_hi, %x23_hi + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + + %m01_lo, %m01_hi = pto.pdintlv_b32 %m0, %m1 + : !pto.mask, !pto.mask -> !pto.mask, !pto.mask + %m23_lo, %m23_hi = pto.pdintlv_b32 %m2, %m3 + : !pto.mask, !pto.mask -> !pto.mask, !pto.mask + %m_d0, %m_d2 = pto.pdintlv_b32 %m01_lo, %m23_lo + : !pto.mask, !pto.mask -> !pto.mask, !pto.mask + %m_d1, %m_d3 = pto.pdintlv_b32 %m01_hi, %m23_hi + : !pto.mask, !pto.mask -> !pto.mask, !pto.mask + + %slot8 = pto.pge_b32 "PAT_VL8" + %s0 = pto.vcgadd %x_d0, %m_d0 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %s1 = pto.vcgadd %x_d1, %m_d1 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %s2 = pto.vcgadd %x_d2, %m_d2 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %s3 = pto.vcgadd %x_d3, %m_d3 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %s01 = pto.vadd %s0, %s1, %slot8 : !pto.vreg<64xf32> + %s23 = pto.vadd %s2, %s3, %slot8 : !pto.vreg<64xf32> + %sum_block = pto.vadd %s01, %s23, %slot8 : !pto.vreg<64xf32> + pto.vsts %sum_block, %out[%group_off], %slot8 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + return +} + +func.func @caller(...) { + // Caller keeps the load and group mask in the contiguous function ABI layout. 
+ %x0 = pto.vlds %base[%off] : !pto.ptr -> !pto.vreg<64xf32> + %x1 = pto.vlds %base[%off_plus_64] : !pto.ptr -> !pto.vreg<64xf32> + %x2 = pto.vlds %base[%off_plus_128] : !pto.ptr -> !pto.vreg<64xf32> + %x3 = pto.vlds %base[%off_plus_192] : !pto.ptr -> !pto.vreg<64xf32> + + %m0 = pto.pset_b32 "PAT_ALL" : !pto.mask + %m1 = pto.pset_b32 "PAT_ALL" : !pto.mask + %m2 = pto.pset_b32 "PAT_ALL" : !pto.mask + %m3 = pto.pset_b32 "PAT_ALL" : !pto.mask + + call @consume(%x0, %x1, %x2, %x3, %m0, %m1, %m2, %m3, %out, %group_off) + : (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.mask, !pto.mask, + !pto.mask, !pto.mask, !pto.ptr, index) -> () + return +} +``` + +Memory result: + +```text +for r = 0..7: + out[group_off + r] = + reduce(base[off + r * 32 + 0 .. off + r * 32 + 31]) +``` + +Required assignment rule: + +```text +Private function boundary layout is explicit in the assigned function type and +callee-entry helpers. The current endpoint chooses a contiguous VMI function +ABI and inserts callee-entry materialization for the grouped body requirement. +`vmi-to-vpto` does not inspect the callee body while lowering the call and does +not inspect callers while lowering the callee block argument. + +Future optimization may specialize private VMI function signatures directly to +`deinterleaved = 4, block_elems = 8` when all call sites agree. That +optimization must still be expressed in the assigned VMI function type before +`vmi-to-vpto` runs. 
+``` + +Runtime closure: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_call_argument_boundary.pto + test/lit/vmi/vmi_ptoas_call_boundary_vecscope.pto + +runtime SIM: + test/vpto/cases/vmi/private-call-argument-boundary-store + +ptoas pipeline: + vmi-layout-assignment inserts explicit callee-entry materialization + vmi-to-vpto physicalizes the call operands and callee body + ptoas then inlines the private physical helper before VPTO vecscope/backend + emission, so the backend never needs a physical VPTO vector function ABI +``` + +### 3.44 `masked_load` Grouped Tail Feeding S=32 Reduce + +This case connects the explicit `masked_load` tail model from section 3.30 with +grouped reduction. The load has no padding constant hidden in the op; inactive +lanes are provided by the passthrough value and excluded from the reduction by +the same grouped mask. + +VMI input: + +```text +%c25 = arith.constant 25 : index +%mask = pto.vmi.create_group_mask %c25 {num_groups = 8, group_size = 32} + : index -> !pto.vmi.mask<256xpred> +%zero = pto.vmi.broadcast %c0_f32 : f32 -> !pto.vmi.vreg<256xf32> +%x = pto.vmi.masked_load %base[%off], %mask, %zero + : memref<256xf32>, !pto.vmi.mask<256xpred>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +pto.vmi.group_store %sum, %out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%mask for masked_load: + !pto.vmi.mask<256xpred, #pto.vmi.layout> + +%zero, %x for masked_load: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +%x_for_reduce = pto.vmi.ensure_layout %x: + #pto.vmi.layout + -> #pto.vmi.layout + +%mask_for_reduce: + pto.vmi.create_group_mask %c25 {num_groups = 8, group_size = 32} + -> !pto.vmi.mask<256xpred, + #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> +``` + +Lowering: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%slot8 = pto.pge_b32 "PAT_VL8" + +// masked_load direct lowering stays contiguous. 
+%m0, %m1, %m2, %m3 = materialize contiguous create_group_mask(c25, S=32) +%z0, %z1, %z2, %z3 = vdup zero +%l0 = pto.vlds %base[%off] +%l1 = pto.vlds %base[%off_plus_64] +%l2 = pto.vlds %base[%off_plus_128] +%l3 = pto.vlds %base[%off_plus_192] +%x0 = pto.vsel %l0, %z0, %m0 : !pto.vreg<64xf32> +%x1 = pto.vsel %l1, %z1, %m1 : !pto.vreg<64xf32> +%x2 = pto.vsel %l2, %z2, %m2 : !pto.vreg<64xf32> +%x3 = pto.vsel %l3, %z3, %m3 : !pto.vreg<64xf32> + +// ensure_layout contiguous -> deinterleaved=4, block_elems=8. +%x01_lo, %x01_hi = pto.vdintlv %x0, %x1 +%x23_lo, %x23_hi = pto.vdintlv %x2, %x3 +%x_p0, %x_p2 = pto.vdintlv %x01_lo, %x23_lo +%x_p1, %x_p3 = pto.vdintlv %x01_hi, %x23_hi + +// The reduce-side grouped mask is not built by guessing the final group-slot +// predicate image. It is first materialized as the same contiguous grouped +// mask used by masked_load, then converted to the reduce layout with predicate +// deinterleave. This keeps predicate reordering identical to the data +// reordering above. +%rm0, %rm1, %rm2, %rm3 = materialize contiguous create_group_mask(c25, S=32) +%rm01_lo, %rm01_hi = pto.pdintlv_b32 %rm0, %rm1 +%rm23_lo, %rm23_hi = pto.pdintlv_b32 %rm2, %rm3 +%mask_p0, %mask_p2 = pto.pdintlv_b32 %rm01_lo, %rm23_lo +%mask_p1, %mask_p3 = pto.pdintlv_b32 %rm01_hi, %rm23_hi + +%s0 = pto.vcgadd %x_p0, %mask_p0 : !pto.vreg<64xf32> +%s1 = pto.vcgadd %x_p1, %mask_p1 : !pto.vreg<64xf32> +%s2 = pto.vcgadd %x_p2, %mask_p2 : !pto.vreg<64xf32> +%s3 = pto.vcgadd %x_p3, %mask_p3 : !pto.vreg<64xf32> +%s01 = pto.vadd %s0, %s1, %slot8 : !pto.vreg<64xf32> +%s23 = pto.vadd %s2, %s3, %slot8 : !pto.vreg<64xf32> +%sum_block = pto.vadd %s01, %s23, %slot8 : !pto.vreg<64xf32> + +pto.vsts %sum_block, %out[%group_off], %slot8 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + out[group_off + r] = + reduce(base[off + r * 32 + 0 .. 
off + r * 32 + 24]) +``` + +Required assignment rule: + +`masked_load` and `group_reduce` must use the same logical grouped mask, but +they do not have to share one physical mask layout. The +passthrough value defines inactive loaded lanes, while the reduce mask defines +participation. Assignment materializes two explicit mask values when needed: +one contiguous value for `masked_load`, and one deinterleaved value for +`group_reduce_addf`. `vmi-to-vpto` lowers the deinterleaved +`create_group_mask` by materializing the contiguous grouped predicate chunks +and then applying `pdintlv_b32` in the same tree shape as the data +`vdintlv`. It does not walk from `group_reduce_addf` to the mask producer to +choose or reject the support path. + +Assignment may select a deinterleaved S=32 load layout only when the rounded +physical reads are memory-safe; otherwise it must diagnose or use a future +stable gather fallback. + +Runtime coverage: + +```text +test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store +``` + +### 3.45 Dynamic S=32 `create_group_mask` + +This is the dynamic-shape form of section 3.44. The active column count is an +SSA `index`, not a constant. 
The semantic mask is still grouped: + +```text +lane i active iff (i % 32) < active_cols +``` + +VMI input: + +```text +%mask = pto.vmi.create_group_mask %active_cols + {num_groups = 8, group_size = 32} + : index -> !pto.vmi.mask<256xpred> +``` + +Assigned layouts: + +```text +%mask for masked_load: + !pto.vmi.mask<256xb32, #pto.vmi.layout> + +%mask for S=32 group_reduce: + !pto.vmi.mask<256xb32, + #pto.vmi.layout> +``` + +Contiguous VPTO lowering for one b32 physical chunk: + +```text +%active_i32 = arith.index_cast %active_cols : index to i32 +%active_nonneg = arith.maxsi %active_i32, %c0_i32 : i32 +%active_clamped = arith.minui %active_nonneg, %c32_i32 : i32 + +%all = pto.pset_b32 "PAT_ALL" : !pto.mask +%lane = pto.vci %c0_i32 : i32 -> !pto.vreg<64xi32> +%row = pto.vshrs %lane, %c5_i16, %all + : !pto.vreg<64xi32>, i16, !pto.mask -> !pto.vreg<64xi32> +%row_base = pto.vshls %row, %c5_i16, %all + : !pto.vreg<64xi32>, i16, !pto.mask -> !pto.vreg<64xi32> +%col = pto.vsub %lane, %row_base, %all + : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask + -> !pto.vreg<64xi32> +%m = pto.vcmps %col, %active_clamped, %all, "lt" + : !pto.vreg<64xi32>, i32, !pto.mask -> !pto.mask +``` + +For `deinterleaved = 4, block_elems = 8`, lowering first emits four contiguous +chunks with the sequence above, then applies the same predicate deinterleave +tree used by section 3.44: + +```text +%rm0, %rm1, %rm2, %rm3 = dynamic contiguous grouped masks +%rm01_lo, %rm01_hi = pto.pdintlv_b32 %rm0, %rm1 +%rm23_lo, %rm23_hi = pto.pdintlv_b32 %rm2, %rm3 +%mask_p0, %mask_p2 = pto.pdintlv_b32 %rm01_lo, %rm23_lo +%mask_p1, %mask_p3 = pto.pdintlv_b32 %rm01_hi, %rm23_hi +``` + +Current coverage validates both IR lowering and runtime behavior: + +```text +test/lit/vmi/vmi_layout_assignment_create_group_mask_s32_dynamic.pto +test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store +``` + +The runtime case passes `active_cols` as a kernel scalar argument and casts it +to `index` inside `pto.vecscope`. 
This keeps scalar materialization outside +`vmi-to-vpto`; the lowering pass only consumes the current +`create_group_mask` operand. + +### 3.46 `extf` Value And Derived Elementwise Value Both Stored + +This case fixes where contiguous materialization belongs when one widened value +is used directly by a store and also by a layout-transparent elementwise chain +that is stored. + +VMI input: + +```text +%a = pto.vmi.load %in[%off] + : memref<128xf16> -> !pto.vmi.vreg<128xf16> +%k = pto.vmi.broadcast %k1 + : f32 -> !pto.vmi.vreg<128xf32> + +%w = pto.vmi.extf %a + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> +%t1 = pto.vmi.mulf %w, %k + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + +pto.vmi.store %t1, %out1[%off] +pto.vmi.store %w, %out2[%off] +``` + +Hard-legalized assigned layouts: + +```text +%a: + !pto.vmi.vreg<128xf16, #pto.vmi.layout> + +%w, %k, %t1: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%t1_c = pto.vmi.ensure_layout %t1: + #pto.vmi.layout -> #pto.vmi.layout +pto.vmi.store %t1_c, %out1[%off] + +%w_c = pto.vmi.ensure_layout %w: + #pto.vmi.layout -> #pto.vmi.layout +pto.vmi.store %w_c, %out2[%off] +``` + +Baseline VPTO lowering result: + +```text +%a0 = pto.vlds %in[%off] {dist = "NORM"} + : !pto.ptr, index -> !pto.vreg<128xf16> + +%w_p0 = pto.vcvt %a0, PAT_ALL_B16 {part = "EVEN"} + : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +%w_p1 = pto.vcvt %a0, PAT_ALL_B16 {part = "ODD"} + : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> + +%k_p0 = pto.vdup %k1, PAT_ALL_B32 : f32, !pto.mask -> !pto.vreg<64xf32> +%k_p1 = pto.vdup %k1, PAT_ALL_B32 : f32, !pto.mask -> !pto.vreg<64xf32> + +%t1_p0 = pto.vmul %w_p0, %k_p0, PAT_ALL_B32 : !pto.vreg<64xf32> +%t1_p1 = pto.vmul %w_p1, %k_p1, PAT_ALL_B32 : !pto.vreg<64xf32> + +// ensure_layout for the first store. 
+%t1_0, %t1_1 = pto.vintlv %t1_p0, %t1_p1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> + -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +pto.vsts %t1_0, %out1[%off], %all_b32 {dist = "NORM_B32"} +pto.vsts %t1_1, %out1[%off_plus_64], %all_b32 {dist = "NORM_B32"} + +// ensure_layout for the second store. +%w_0, %w_1 = pto.vintlv %w_p0, %w_p1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> + -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +pto.vsts %w_0, %out2[%off], %all_b32 {dist = "NORM_B32"} +pto.vsts %w_1, %out2[%off_plus_64], %all_b32 {dist = "NORM_B32"} +``` + +Memory result: + +```text +for i = 0..127: + out1[off + i] = f32(in[off + i]) * k1 + out2[off + i] = f32(in[off + i]) +``` + +Optimization pass result: + +```text +// vmi-layout-fold may remove both ensure_layout ops if the target +// supports store lowering that consumes deinterleaved=2 and writes contiguous +// row-major memory. +pto.vmi.store %t1, %out1[%off] +pto.vmi.store %w, %out2[%off] +``` + +Optimized VPTO lowering result: + +```text +pto.vstsx2 %t1_p0, %t1_p1, %out1[%off], "INTLV_B32", PAT_ALL_B32 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.ptr, index, + !pto.mask + +pto.vstsx2 %w_p0, %w_p1, %out2[%off], "INTLV_B32", PAT_ALL_B32 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.ptr, index, + !pto.mask +``` + +Required assignment and optimization rule: + +```text +Hard legalization may always preserve `%w` and `%t1` in deinterleaved=2 and +insert use-site ensure_layout before ordinary stores. This is correct because +the layout change is explicit at the store use. + +Consumer folding is optional. It may remove the ensure_layout only when the +store itself can locally prove the same contiguous memory effect from the +source layout. vmi-to-vpto must not scan the `%w` producer or both store users +to decide this. +``` + +### 3.47 Type-Parametric Group Reduce Rule + +The group-reduce layout rule is parameterized by the element width, not by f32 +case names. 
+ +```text +E = sizeof(T) +VLaneElems = 32B / E +L = 256B / E +S = logical_lane_count / num_groups +``` + +The canonical grouped-reduce layouts are: + +```text +Packed group-slot rule: + K is the physical slot capacity of one packed group-result chunk. + For VCG-style packed reductions, K = 8. + G does not have to be divisible by K; the final chunk may be partial. + active_groups(chunk c) = min(K, G - c * K). + +S == VLaneElems: + source/mask layout = contiguous + result layout = group_slots(num_groups=G, slots=8) + +S == 2 * VLaneElems: + source/mask layout = deinterleaved=2 + result layout = group_slots(num_groups=G, slots=8) + +S == 4 * VLaneElems: + source/mask layout = deinterleaved=4 + result layout = group_slots(num_groups=G, slots=8) + +S >= L && S % L == 0: + source/mask layout = contiguous + result layout = group_slots(num_groups=G, slots=1) +``` + +Concrete shape table: + +```text +T VLaneElems L packed cases row-local cases +f32 8 64 S=8, S=16, S=32 S=64, S=128, ... +i32 8 64 S=8, S=16, S=32 S=64, S=128, ... +f16 16 128 S=16, S=32, S=64 S=128, S=256, ... +i16 16 128 S=16, S=32, S=64 S=128, S=256, ... +f8 32 256 cast to f32 before grouped reduce +i8 32 256 cast to i16/i32 before grouped reduce +``` + +These non-f32 cases are part of the type-generic layout/lowering design. If a +typed reduce op admits the element type and the target capability registry +accepts it, assignment must use the same `VLaneElems/L/S` formula instead of +adding per-type shape special cases. Any f32-only behavior in the current +implementation is staged implementation status, not the intended design limit. +For the current baseline, `f8/i8` are storage and cast-boundary types: they are +valid as load/store element types and as cast source/destination, but compute +ops such as group reduce consume the post-cast accumulator type. + +### 3.48 16-bit Typed Group Reduce, `S = VLaneElems = 16` + +This case covers both `f16` and `i16`. 
The element width is the same, so the +layout and VPTO instruction skeleton are identical. The VMI op name carries the +semantic difference: + +```text +f16: pto.vmi.group_reduce_addf ... {reassoc} +i16 storage: pto.vmi.extsi/extui ... -> i32 group_reduce_addi ... +``` + +VMI-shaped input: + +```text +// Floating form. +%xf = pto.vmi.load %base_f16[%off] + : memref<128xf16> -> !pto.vmi.vreg<128xf16> +%mf = pto.vmi.create_group_mask %c16 {num_groups = 8, group_size = 16} + : index -> !pto.vmi.mask<128xpred> +%sumf = pto.vmi.group_reduce_addf %xf, %mf {num_groups = 8, reassoc} +pto.vmi.group_store %sumf, %out_f16[%group_off], %c1 {num_groups = 8} + +// Integer form. +%xi = pto.vmi.load %base_i16[%off] + : memref<128xi16> -> !pto.vmi.vreg<128xi16> +%mi = pto.vmi.create_group_mask %c16 {num_groups = 8, group_size = 16} + : index -> !pto.vmi.mask<128xpred> +%sumi = pto.vmi.group_reduce_addi %xi, %mi {num_groups = 8} +pto.vmi.group_store %sumi, %out_i16[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%xf, %mf, %xi, %mi: + #pto.vmi.layout + +%sumf: + !pto.vmi.vreg<128xf16, #pto.vmi.layout> + +%sumi: + !pto.vmi.vreg<128xi16, #pto.vmi.layout> +``` + +VPTO lowering shape: + +```text +%x0 = pto.vlds %base[%off] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<128xT16> + +%all_b16 = pto.pge_b16 "PAT_ALL" +%slot8_b16 = pto.pge_b16 "PAT_VL8" + +%sum0 = pto.vcgadd %x0, %all_b16 + : !pto.vreg<128xT16>, !pto.mask -> !pto.vreg<128xT16> + +pto.vsts %sum0, %out[%group_off], %slot8_b16 {dist = "NORM_B16"} + : !pto.vreg<128xT16>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + out[group_off + r] = reduce_T16(base[off + r * 16 + 0 .. 15]) +``` + +### 3.49 16-bit Typed Group Reduce, `S = 2 * VLaneElems = 32` + +This case covers both `f16` and `i16`. Each logical row is 64B and must be +split into two 32B VLane fragments before `vcgadd`. 
+ +VMI-shaped input: + +```text +%x = pto.vmi.load %base[%off] + : memref<256xT16> -> !pto.vmi.vreg<256xT16> +%mask = pto.vmi.create_group_mask %c32 {num_groups = 8, group_size = 32} + : index -> !pto.vmi.mask<256xpred> +%sum = pto.vmi.group_reduce_add{f|i} %x, %mask {num_groups = 8} +pto.vmi.group_store %sum, %out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%x, %mask: + #pto.vmi.layout + +%sum: + !pto.vmi.vreg<256xT16, #pto.vmi.layout> +``` + +VPTO lowering shape: + +```text +%x_p0, %x_p1 = pto.vldsx2 %base[%off], "DINTLV_B16" + : !pto.ptr, index -> !pto.vreg<128xT16>, !pto.vreg<128xT16> + +%all_b16 = pto.pge_b16 "PAT_ALL" +%slot8_b16 = pto.pge_b16 "PAT_VL8" + +%s0 = pto.vcgadd %x_p0, %all_b16 + : !pto.vreg<128xT16>, !pto.mask -> !pto.vreg<128xT16> +%s1 = pto.vcgadd %x_p1, %all_b16 + : !pto.vreg<128xT16>, !pto.mask -> !pto.vreg<128xT16> +%sum0 = pto.vadd %s0, %s1, %slot8_b16 + : !pto.vreg<128xT16>, !pto.vreg<128xT16>, !pto.mask + -> !pto.vreg<128xT16> + +pto.vsts %sum0, %out[%group_off], %slot8_b16 {dist = "NORM_B16"} + : !pto.vreg<128xT16>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + out[group_off + r] = reduce_T16(base[off + r * 32 + 0 .. 31]) +``` + +### 3.50 16-bit Typed Group Reduce, `S = 4 * VLaneElems = 64` + +This is the four-fragment packed case for both `f16` and `i16`. 
+ +Assigned layouts: + +```text +%x, %mask: + #pto.vmi.layout + +%sum: + !pto.vmi.vreg<512xT16, #pto.vmi.layout> +``` + +VPTO lowering shape: + +```text +%x_p0, %x_p1, %x_p2, %x_p3 = materialize deinterleaved=4 input + : four !pto.vreg<128xT16> + +%all_b16 = pto.pge_b16 "PAT_ALL" +%slot8_b16 = pto.pge_b16 "PAT_VL8" + +%s0 = pto.vcgadd %x_p0, %all_b16 : !pto.vreg<128xT16> +%s1 = pto.vcgadd %x_p1, %all_b16 : !pto.vreg<128xT16> +%s2 = pto.vcgadd %x_p2, %all_b16 : !pto.vreg<128xT16> +%s3 = pto.vcgadd %x_p3, %all_b16 : !pto.vreg<128xT16> + +%s01 = pto.vadd %s0, %s1, %slot8_b16 : !pto.vreg<128xT16> +%s23 = pto.vadd %s2, %s3, %slot8_b16 : !pto.vreg<128xT16> +%sum0 = pto.vadd %s01, %s23, %slot8_b16 : !pto.vreg<128xT16> + +pto.vsts %sum0, %out[%group_off], %slot8_b16 {dist = "NORM_B16"} + : !pto.vreg<128xT16>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + out[group_off + r] = reduce_T16(base[off + r * 64 + 0 .. 63]) +``` + +#### 3.50.1 Partial Packed `S = 64` Reductions + +This is the same `S = 4 * VLaneElems` lowering family as section 3.50, but it +covers `G` values that do not fill every packed group-result chunk. The key +point is that `slots = 8` is a physical capacity, not a promise that every +chunk contains eight valid group results. + +The result layout remains: + +```text +!pto.vmi.vreg<(G * 64)xf16, #pto.vmi.layout> +``` + +The lowering computes per result chunk: + +```text +K = 8 +chunk c active groups A(c) = min(K, G - c * K) + +source active lanes per deinterleaved part for chunk c: + A(c) * VLaneElems = A(c) * 16 f16 lanes + +reduce input mask: + PAT_VL(A(c) * 16) + +combine/store mask: + PAT_VL(A(c)) +``` + +For full chunks, `A(c) = 8`, so the reduce input mask is `PAT_ALL` for f16 +and the combine/store mask is `PAT_VL8`. For partial chunks, masks are +required for correctness. 
The semantic source mask produced by +`pto.vmi.create_group_mask` must also materialize only the valid source lanes; +the reduce lowering should not treat padding lanes as active data. + +##### `G = 4`: `256xf16, num_groups = 4` + +VMI-shaped input: + +```text +%x = pto.vmi.load %base[%off] + : memref<256xf16> -> !pto.vmi.vreg<256xf16> +%mask = pto.vmi.create_group_mask %c64 {num_groups = 4, group_size = 64} + : index -> !pto.vmi.mask<256xpred> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 4, reassoc} +pto.vmi.group_store %sum, %out[%group_off], %c1 {num_groups = 4} +``` + +Assigned layouts: + +```text +%x, %mask: + #pto.vmi.layout + +%sum: + !pto.vmi.vreg<256xf16, #pto.vmi.layout> +``` + +VPTO lowering shape for the only result chunk: + +```text +%x_p0, %x_p1, %x_p2, %x_p3 = materialize deinterleaved=4, block_elems=8 input + : four !pto.vreg<128xf16> + +%lane64_b16 = pto.pge_b16 "PAT_VL64" // A * 16 = 4 * 16 +%slot4_b16 = pto.pge_b16 "PAT_VL4" + +%s0 = pto.vcgadd %x_p0, %lane64_b16 : !pto.vreg<128xf16> +%s1 = pto.vcgadd %x_p1, %lane64_b16 : !pto.vreg<128xf16> +%s2 = pto.vcgadd %x_p2, %lane64_b16 : !pto.vreg<128xf16> +%s3 = pto.vcgadd %x_p3, %lane64_b16 : !pto.vreg<128xf16> + +%s01 = pto.vadd %s0, %s1, %slot4_b16 : !pto.vreg<128xf16> +%s23 = pto.vadd %s2, %s3, %slot4_b16 : !pto.vreg<128xf16> +%sum0 = pto.vadd %s01, %s23, %slot4_b16 : !pto.vreg<128xf16> + +pto.vsts %sum0, %out[%group_off], %slot4_b16 {dist = "NORM_B16"} + : !pto.vreg<128xf16>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..3: + out[group_off + r] = reduce_f16(base[off + r * 64 + 0 .. 63]) + +sum0 lanes 4..127 are not semantic for this VMI result. +``` + +##### `G = 8`: full packed chunk + +This is section 3.50. 
There is one result chunk with `A = 8`: + +```text +source mask = PAT_ALL // 8 * 16 = 128 f16 lanes +combine/store = PAT_VL8 +result layout = group_slots(num_groups=8, slots=8) +``` + +##### `G = 12`: full chunk plus partial chunk + +This case needs two packed result chunks: + +```text +result layout = group_slots(num_groups=12, slots=8) +result arity = ceil(12 / 8) = 2 +``` + +Chunk 0 handles groups `0..7`: + +```text +A(0) = 8 +source mask = PAT_ALL +combine/store = PAT_VL8 +``` + +Chunk 1 handles groups `8..11`: + +```text +A(1) = 4 +source mask = PAT_VL64 +combine/store = PAT_VL4 +``` + +Implementation checklist for this family: + +```text +layout attr: + slots=8 should be legal even when num_groups is not divisible by 8. + slot_block(g) = g / 8 and slot_lane(g) = g % 8 are still well-defined. + +layout assignment: + packed VCG-style group_reduce results keep slots=8. + +mask materialization: + create_group_mask must not activate padding lanes in partial chunks. + For chunk c, source active lanes are A(c) * VLaneElems. + +vmi-to-vpto group_reduce: + use A(c) from result layout slots and num_groups. + combine masks use PAT_VL(A(c)). + input vcgadd consumes the physical mask parts, which must already encode + PAT_VL(A(c) * VLaneElems) for all-true grouped masks. + +vmi-to-vpto group_store: + use A(c) to build the store predicate. + output group offset for chunk c is c * slots. +``` + +### 3.51 16-bit Typed Group Reduce, `S = L = 128` + +This is the first row-local full-physical-chunk case for both `f16` and `i16`. +The canonical result is row-local `slots = 1`, not packed `slots = 8`. 
+ +VMI-shaped input: + +```text +%x = pto.vmi.load %base[%off] + : memref<1024xT16> -> !pto.vmi.vreg<1024xT16> +%mask = pto.vmi.create_group_mask %c128 {num_groups = 8, group_size = 128} + : index -> !pto.vmi.mask<1024xpred> +%sum = pto.vmi.group_reduce_add{f|i} %x, %mask {num_groups = 8} +pto.vmi.group_store %sum, %out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%x, %mask: + #pto.vmi.layout + +%sum: + !pto.vmi.vreg<1024xT16, #pto.vmi.layout> +``` + +VPTO lowering shape: + +```text +%all_b16 = pto.pge_b16 "PAT_ALL" +%slot8_b16 = pto.pge_b16 "PAT_VL8" +%slot1_b16 = pto.pge_b16 "PAT_VL1" + +// Repeated for r = 0..7. +%x_r = pto.vlds %base[%row_off_r] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<128xT16> +%partial_r = pto.vcgadd %x_r, %all_b16 + : !pto.vreg<128xT16>, !pto.mask -> !pto.vreg<128xT16> +%sum_r = pto.vcadd %partial_r, %slot8_b16 + : !pto.vreg<128xT16>, !pto.mask -> !pto.vreg<128xT16> + +pto.vsts %sum_r, %out[%group_off_plus_r], %slot1_b16 {dist = "NORM_B16"} + : !pto.vreg<128xT16>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + out[group_off + r] = reduce_T16(base[off + r * 128 + 0 .. 127]) +``` + +### 3.52 32-bit Typed Group Reduce + +This case covers both `f32` and `i32`. The element width is the same, so +`VLaneElems = 8` and `L = 64` for both. Floating-point uses +`group_reduce_addf` with `reassoc`; integer uses `group_reduce_addi`. 
+ +Example for `S = 2 * VLaneElems = 16`: + +```text +%x: + !pto.vmi.vreg<128xT32, #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<128xT32, #pto.vmi.layout> +``` + +VPTO lowering shape: + +```text +%x_p0, %x_p1 = pto.vldsx2 %base[%off], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xT32>, !pto.vreg<64xT32> + +%all_b32 = pto.pge_b32 "PAT_ALL" +%slot8_b32 = pto.pge_b32 "PAT_VL8" + +%s0 = pto.vcgadd %x_p0, %all_b32 + : !pto.vreg<64xT32>, !pto.mask -> !pto.vreg<64xT32> +%s1 = pto.vcgadd %x_p1, %all_b32 + : !pto.vreg<64xT32>, !pto.mask -> !pto.vreg<64xT32> +%sum0 = pto.vadd %s0, %s1, %slot8_b32 + : !pto.vreg<64xT32>, !pto.vreg<64xT32>, !pto.mask + -> !pto.vreg<64xT32> + +pto.vsts %sum0, %out[%group_off], %slot8_b32 {dist = "NORM_B32"} + : !pto.vreg<64xT32>, !pto.ptr, !pto.mask +``` + +The same formula gives: + +```text +S=8: + contiguous, slots=8, one vcgadd. + +S=32: + deinterleaved=4, slots=8, four vcgadd plus vadd tree. + +S=64: + contiguous, slots=1, row-local vcgadd plus vcadd. + +S=128: + contiguous, slots=1, row-local multi-chunk accumulation. 
+``` + +### 3.53 Integer Semantics And Invalid Typed Reductions + +Integer group reduction is not a variant of `group_reduce_addf`; it requires a +typed integer op: + +```text +%sum = pto.vmi.group_reduce_addi %x, %mask {num_groups = G} +``` + +Required semantics: + +```text +inactive lanes contribute integer zero +addition uses the target's normal integer add behavior +wrap/saturating variants must be represented by distinct ops if both are needed +signedness does not affect add, but does affect future max/min integer reduces +``` + +Required invalid cases: + +```text +pto.vmi.group_reduce_addf with integer element type -> verifier error +pto.vmi.group_reduce_addi with floating-point element type -> verifier error +pto.vmi.group_reduce_addi i8 -> invalid direct 8-bit accumulator reduce; + cast to i16/i32 first unless target exposes i8 vcgadd +S not in {VLaneElems, 2*VLaneElems, 4*VLaneElems} and not a full-chunk multiple + -> layout-contract diagnostic +``` + +### 3.54 8-bit Floating Group Reduce + +There is no direct f8 `vcgadd` grouped reduction in the current target model, +but f8 supports cast to an accumulator type. The semantic path is: + +```text +f8 storage -> cast/extf to f32 accumulator -> group_reduce_addf on f32 +``` + +Here `f8` is only the cast source and the memory element type. The reduction +itself is an f32 accumulator operation. + +The group size remains a logical-lane property. For example, reducing eight +rows of 32 f8 elements produces the same logical result as reducing eight rows +of 32 f32 accumulator elements after extension. 
+ +VMI-shaped input: + +```text +%x8 = pto.vmi.load %base_f8[%off] + : memref<256xf8> -> !pto.vmi.vreg<256xf8> +%x32 = pto.vmi.extf %x8 + : !pto.vmi.vreg<256xf8> -> !pto.vmi.vreg<256xf32> +%mask = pto.vmi.create_group_mask %c32 {num_groups = 8, group_size = 32} + : index -> !pto.vmi.mask<256xpred> +%sum = pto.vmi.group_reduce_addf %x32, %mask {num_groups = 8, reassoc} +pto.vmi.group_store %sum, %out_f32[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%x8: + !pto.vmi.vreg<256xf8, #pto.vmi.layout> + +%x32, %mask: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> + !pto.vmi.mask<256xb32, #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> +``` + +VPTO lowering shape: + +```text +%x8_packed = pto.vlds %base_f8[%off] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<256xf8> + +%all_b8 = pto.pge_b8 "PAT_ALL" +%all_b32 = pto.pge_b32 "PAT_ALL" +%slot8_b32 = pto.pge_b32 "PAT_VL8" + +%x32_p0 = pto.vcvt %x8_packed, %all_b8 {part = "P0"} + : !pto.vreg<256xf8>, !pto.mask -> !pto.vreg<64xf32> +%x32_p1 = pto.vcvt %x8_packed, %all_b8 {part = "P1"} + : !pto.vreg<256xf8>, !pto.mask -> !pto.vreg<64xf32> +%x32_p2 = pto.vcvt %x8_packed, %all_b8 {part = "P2"} + : !pto.vreg<256xf8>, !pto.mask -> !pto.vreg<64xf32> +%x32_p3 = pto.vcvt %x8_packed, %all_b8 {part = "P3"} + : !pto.vreg<256xf8>, !pto.mask -> !pto.vreg<64xf32> + +%s0 = pto.vcgadd %x32_p0, %all_b32 : !pto.vreg<64xf32> +%s1 = pto.vcgadd %x32_p1, %all_b32 : !pto.vreg<64xf32> +%s2 = pto.vcgadd %x32_p2, %all_b32 : !pto.vreg<64xf32> +%s3 = pto.vcgadd %x32_p3, %all_b32 : !pto.vreg<64xf32> +%s01 = pto.vadd %s0, %s1, %slot8_b32 : !pto.vreg<64xf32> +%s23 = pto.vadd %s2, %s3, %slot8_b32 : !pto.vreg<64xf32> +%sum0 = pto.vadd %s01, %s23, %slot8_b32 : !pto.vreg<64xf32> + +pto.vsts %sum0, %out_f32[%group_off], %slot8_b32 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + out_f32[group_off + r] = + reduce_f32(f32(base_f8[off + r * 32 + 0 .. 
31])) +``` + +Direct f8 grouped reduction is invalid: + +```text +pto.vmi.group_reduce_addf %x8, %mask + : !pto.vmi.vreg<256xf8>, !pto.vmi.mask<256xpred> + -> verifier or layout-contract diagnostic +``` + +### 3.55 8-bit Integer Group Reduce + +The current target model has no i8 `vcgadd`. It does have widening `vcadd` for +full-vector reductions, but grouped reduction needs one partial result per +32B VLane. Since 8-bit integers support cast to wider integer types, the +baseline grouped path casts before reducing: + +```text +i8/i16 storage -> signed/unsigned cast to i32 accumulator + -> group_reduce_addi on the accumulator type +``` + +Here `i8`/`i16` are only cast sources and memory element types. The reduction +itself is an i32 accumulator operation, with signedness handled by the cast. + +The integer cast operation must carry signedness. This document uses +`extsi/extui` as the widening spelling and `trunci` as the narrowing spelling: + +```text +%x32 = pto.vmi.extsi %x8 : !pto.vmi.vreg -> !pto.vmi.vreg +%x32 = pto.vmi.extui %x8 : !pto.vmi.vreg -> !pto.vmi.vreg +%x8 = pto.vmi.trunci %x32 : !pto.vmi.vreg -> !pto.vmi.vreg +``` + +The last form is unsigned i8 on the current VPTO target: VISA exposes +VCVTII.s322u8/u322u8 for 32-bit to 8-bit narrowing, not a signed-i8 +destination form. 
+ +VMI-shaped input: + +```text +%x8 = pto.vmi.load %base_i8[%off] + : memref<256xi8> -> !pto.vmi.vreg<256xi8> +%x32 = pto.vmi.extsi %x8 + : !pto.vmi.vreg<256xi8> -> !pto.vmi.vreg<256xi32> +%mask = pto.vmi.create_group_mask %c32 {num_groups = 8, group_size = 32} + : index -> !pto.vmi.mask<256xpred> +%sum = pto.vmi.group_reduce_addi %x32, %mask {num_groups = 8} +pto.vmi.group_store %sum, %out_i32[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%x8: + !pto.vmi.vreg<256xi8, #pto.vmi.layout> + +%x32, %mask: + !pto.vmi.vreg<256xi32, #pto.vmi.layout> + !pto.vmi.mask<256xb32, #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<256xi32, #pto.vmi.layout> +``` + +VPTO lowering shape after integer cast materialization: + +```text +%x32_p0, %x32_p1, %x32_p2, %x32_p3 = + materialize signed cast i8 -> i32 with deinterleaved=4 layout + : four !pto.vreg<64xi32> + +%all_b32 = pto.pge_b32 "PAT_ALL" +%slot8_b32 = pto.pge_b32 "PAT_VL8" + +%s0 = pto.vcgadd %x32_p0, %all_b32 : !pto.vreg<64xi32> +%s1 = pto.vcgadd %x32_p1, %all_b32 : !pto.vreg<64xi32> +%s2 = pto.vcgadd %x32_p2, %all_b32 : !pto.vreg<64xi32> +%s3 = pto.vcgadd %x32_p3, %all_b32 : !pto.vreg<64xi32> +%s01 = pto.vadd %s0, %s1, %slot8_b32 : !pto.vreg<64xi32> +%s23 = pto.vadd %s2, %s3, %slot8_b32 : !pto.vreg<64xi32> +%sum0 = pto.vadd %s01, %s23, %slot8_b32 : !pto.vreg<64xi32> + +pto.vsts %sum0, %out_i32[%group_off], %slot8_b32 {dist = "NORM_B32"} + : !pto.vreg<64xi32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + out_i32[group_off + r] = + reduce_i32(sign_extend(base_i8[off + r * 32 + 0 .. 
31])) +``` + +Direct i8 grouped reduction without the cast is invalid: + +```text +pto.vmi.group_reduce_addi %x8, %mask + : !pto.vmi.vreg<256xi8>, !pto.vmi.mask<256xpred> + -> verifier or layout-contract diagnostic +``` + +An optimized row-local i8 full-chunk lowering path may be added later for +`S = 256` by using widening `vcadd`, but that requires a widening +`group_slots` result contract and must not change the baseline cast-to-accumulator +semantics above. + +If the final memory result is i8, narrowing is a separate cast after the +accumulator computation: + +```text +%sum32 = pto.vmi.group_reduce_addi %x32, %mask {num_groups = 8} +%sum8 = pto.vmi.trunci %sum32 +pto.vmi.group_store %sum8, %out_i8[%group_off], %c1 {num_groups = 8} +``` + +That packed group-slot `trunci` path is not baseline lowering support yet; the +implementation must either define slot-wise VCVTII lowering support or diagnose at +layout assignment. + +### 3.56 Full 256-Bin Distribution Histogram + +Histogram is not modeled as `group_reduce`. A group reduce maps source lanes to +result slots by lane/group position. A histogram maps each active source lane +to a result bin by the source value itself. 
+ +VMI-shaped input: + +```text +%src = pto.vmi.load %src_base[%src_off] + : memref<Nxui8> -> !pto.vmi.vreg<Nxui8> +%mask = pto.vmi.create_mask %active_lanes + : index -> !pto.vmi.mask<Nxpred> +%acc = pto.vmi.load %acc_base[%acc_off] + : memref<256xui16> -> !pto.vmi.vreg<256xui16> +%hist = pto.vmi.dhist %acc, %src, %mask + : !pto.vmi.vreg<256xui16>, !pto.vmi.vreg<Nxui8>, + !pto.vmi.mask<Nxpred> -> !pto.vmi.vreg<256xui16> +pto.vmi.store %hist, %out[%out_off] +``` + +Logical semantics: + +```text +for b = 0..255: + hist[b] = acc[b] + +for i = 0..N-1: + if mask[i]: + hist[src[i]] += 1 +``` + +Assigned layouts: + +```text +%src: + !pto.vmi.vreg<Nxui8, #pto.vmi.layout<contiguous>> + +%mask: + !pto.vmi.mask<Nxb8, #pto.vmi.layout<contiguous>> + +%acc, %hist: + !pto.vmi.vreg<256xui16, #pto.vmi.layout<contiguous>> +``` + +The `256xui16` accumulator/result is one logical VMI value but two physical +VPTO vector registers: + +```text +physical result part0 = logical bins 0..127 +physical result part1 = logical bins 128..255 +``` + +For `N = 256`, VPTO lowering shape: + +```text +%src0 = pto.vlds %src_base[%src_off] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<256xui8> + +%acc_lo = pto.vlds %acc_base[%acc_off + 0] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<128xui16> +%acc_hi = pto.vlds %acc_base[%acc_off + 128] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<128xui16> + +%hist_lo = pto.dhistv2 %acc_lo, %src0, %mask0, %bin0 + : !pto.vreg<128xui16>, !pto.vreg<256xui8>, !pto.mask, i32 + -> !pto.vreg<128xui16> +%hist_hi = pto.dhistv2 %acc_hi, %src0, %mask0, %bin1 + : !pto.vreg<128xui16>, !pto.vreg<256xui8>, !pto.mask, i32 + -> !pto.vreg<128xui16> + +pto.vsts %hist_lo, %out[%out_off + 0], %all_b16 {dist = "NORM_B16"} +pto.vsts %hist_hi, %out[%out_off + 128], %all_b16 {dist = "NORM_B16"} +``` + +Memory result: + +```text +for b = 0..127: + out[out_off + b] = acc_base[acc_off + b] + + count(i where mask[i] && src_base[src_off + i] == b) + +for b = 128..255: + out[out_off + b] = acc_base[acc_off + b] + + count(i where mask[i] && src_base[src_off + i] == b) +``` + +For `N > 256`, the source is processed in 
contiguous 256-lane chunks. The two +histogram accumulator parts are carried through all chunks: + +```text +%lo = %acc_lo +%hi = %acc_hi + +for source chunk c in logical order: + %chunk_mask = mask chunk c + if c is the final partial chunk: + %chunk_mask = %chunk_mask & valid-lane-prefix-for-this-chunk + + %lo = pto.dhistv2 %lo, %src_c, %chunk_mask, %bin0 + %hi = pto.dhistv2 %hi, %src_c, %chunk_mask, %bin1 + +result physical parts = [%lo, %hi] +``` + +Tail source lanes are expressed only through the b8 mask. Padding lanes in the +last physical source chunk must be masked off before `pto.dhistv2`; tail +handling relies on the mask, not on padding values, because every ui8 value +would be counted in some bin. + +The VMI op does not expose `#bin`. `#bin` is a VPTO range selector forced by +the physical result width: + +```text +ui8 value domain = 256 bins +complete histogram = 256 x ui16 = 512B +one VPTO vreg result = 128 x ui16 = 256B +``` + +Therefore VMI represents one logical `256xui16` result and `vmi-to-vpto` +locally emits the low-range and high-range VPTO histogram updates. + +### 3.57 Full 256-Bin Cumulative Histogram + +The desired VMI surface shape mirrors `dhist`: + +```text +%hist = pto.vmi.chist %acc, %src, %mask + : !pto.vmi.vreg<256xui16>, !pto.vmi.vreg<Nxui8>, + !pto.vmi.mask<Nxpred> -> !pto.vmi.vreg<256xui16> +``` + +The intended logical semantics is a full cumulative histogram: + +```text +dist[b] = count(i where mask[i] && src[i] == b) + +hist[0] = acc[0] + dist[0] +for b = 1..255: + hist[b] = acc[b] + dist[0] + dist[1] + ... + dist[b] +``` + +The current VPTO/VISA documentation only states that `CHISTv2` computes a +`uint16 Cumulative histogram` over the selected bin range. It does not state +whether the high-range call with `#bin = 1` returns: + +```text +global cumulative: + result[j] = count(src <= 128 + j) + +or range-local cumulative: + result[j] = count(128 <= src <= 128 + j) +``` + +These two interpretations have different VMI lowerings. 
If the hardware result +is global cumulative, the full VMI lowering is the same low/high split as +`dhist`, replacing `pto.dhistv2` with `pto.chistv2`. If the hardware result is +range-local cumulative, the high half also needs the total low-half count added +to every high-half bin: + +```text +%lo = pto.chistv2 %acc_lo, %src0, %mask0, %bin0 +%hi_local = pto.chistv2 %acc_hi, %src0, %mask0, %bin1 + +%low_total = materialize count(src <= 127) from the low-half result +%low_total_vec = broadcast %low_total to every high-half bin +%hi = pto.vadd %hi_local, %low_total_vec, %all_b16 +``` + +That correction path also requires a designed way to materialize and broadcast +the low-half total. Since baseline VMI does not support arbitrary vector +extract, the range-local CHISTv2 interpretation remains unsupported until that +materialization path is explicit. + +The baseline design therefore treats `pto.vmi.chist` as a semantic op whose +exact lowering is gated by a target semantic capability: + +```text +if target documents or validation proves CHISTv2 high range is global: + lower as two pto.chistv2 calls +elif target documents or validation proves CHISTv2 high range is range-local: + lower as pto.chistv2 low/high plus explicit high-half correction only after + low-total materialization support is designed +else: + VMI-UNSUPPORTED: pto.vmi.chist requires a verified CHISTv2 range semantics contract +``` + +This boundary is deliberate. `pto.vmi.dhist` is fully defined because +distribution bins are independent across the low/high split. `pto.vmi.chist` +has cross-range prefix semantics, so VMI must not guess the high-half behavior +from the VPTO op name alone. 
diff --git a/docs/designs/vmi-layout-relation-rematerialization-design.md b/docs/designs/vmi-layout-relation-rematerialization-design.md new file mode 100644 index 0000000000..8e59b91b65 --- /dev/null +++ b/docs/designs/vmi-layout-relation-rematerialization-design.md @@ -0,0 +1,236 @@ +# VMI Layout Relation-Aware Rematerialization Design + +本文描述 VMI layout optimization 中 relation-aware rematerialization 的设计。 +目标是让 `vmi-layout-assignment` 只产生 legal baseline IR,把跨 layout +relation 的优化放到显式 `ensure_layout` 上完成。 + +## 1. Motivation + +`vmi-layout-assignment` 已经负责三件 hard legalization 工作: + +```text +1. 为每个 VMI value 选择 concrete layout +2. 在不匹配的 use-site 插入 ensure_layout / ensure_mask_layout +3. 保证 vmi-to-vpto 只需要 local lowering information +``` + +对 `ext` 这类 width-changing op,assignment 的 baseline 可以保守选择: + +```text +ext f16 -> f32: + source = contiguous + result = deinterleaved=2 +``` + +如果下游 `truncf f32 -> f8` 要求 source 为 `deinterleaved=4`,assignment 会 +显式插入: + +```text +%e = pto.vmi.extf %x + : !vreg<..., layout<contiguous>> + -> !vreg<..., layout<deinterleaved=2>> + +%e4 = pto.vmi.ensure_layout %e + : !vreg<..., layout<deinterleaved=2>> + -> !vreg<..., layout<deinterleaved=4>> +``` + +这个 IR 已经合法,但不是最优。优化 pass 可以从显式 helper 出发,把 relation +应用到 producer: + +```text +ensure_layout(ext(src), resultLayout) + => ext(ensure_layout(src, derivedSourceLayout)) +``` + +这样 assignment 不需要做 consumer-driven global propagation,也不需要在多 +consumer 冲突时引入 cost model。 + +## 2. Goals + +```text +1. assignment 保持 hard legalization baseline,不做 ext relation propagation。 +2. relation-aware optimization 从显式 ensure_layout 出发。 +3. 多 consumer 冲突由 use-site helper + rematerialization 解决。 +4. vmi-to-vpto 仍只消费当前 op 的 operand/result layout,不扫描上下文。 +5. 变换必须是局部、确定、可验证的 IR rewrite。 +``` + +非目标: + +```text +1. 不做 ComputeY1 专用 pattern。 +2. 不在 assignment 中实现全局 cost model。 +3. 不通过 vmi-to-vpto 猜 producer/consumer relation。 +4. 第一阶段不做 trunc/narrow relation remat。 +``` + +## 3. 
Optimization Model + +relation-aware remat 以 `ensure_layout` 为唯一触发点: + +```text +%wanted = pto.vmi.ensure_layout %source : sourceLayout -> targetLayout +``` + +如果 `%source` 的 producer 可以在 `targetLayout` 或 relation 派生出的 operand +layout 下重新创建等价结果,则用 cloned producer 替换 helper。 + +### 3.1 Layout-Transparent Producer Remat + +对 layout-transparent elementwise op: + +```text +ensure_layout(op(a, b), L) + => op(ensure_layout(a, L), ensure_layout(b, L)) +``` + +适用对象包括纯 elementwise data ops: + +```text +addf/addi/subf/subi/mulf/muli/divf/minf/maxf +andi/ori/xori/shli/shrui +negf/absf/absi/sqrt/exp/ln/relu/not +fma +select, when data operands and mask layout requirements can be kept explicit +``` + +第一阶段可以先覆盖 ComputeY1 需要的 `mulf`,但实现形态应按 op family 泛化。 + +### 3.2 Widen Ext Relation Remat + +对 widening `ext`: + +```text +ensure_layout(ext(src), resultLayout) + => ext(ensure_layout(src, sourceLayout)) +``` + +其中: + +```text +resultFactor = sourceFactor * widenFactor +``` + +例子: + +```text +ext f16 -> f32, widenFactor = 2 +target result layout = deinterleaved=4 +derived source layout = deinterleaved=2 +``` + +`deinterleaved=1` 等价于 contiguous。 + +### 3.3 Producer Fold After Remat + +relation remat 可能暴露 producer-side helper: + +```text +ensure_layout(load(...), deinterleaved=2) +``` + +这类 helper 应由 `vmi-layout-fold` 吸收到 producer 或 consumer: + +```text +load contiguous + ensure_layout to deinterleaved=2 + => load result deinterleaved=2 +``` + +因此推荐优化 pipeline 在 remat 后再次运行 fold: + +```text +vmi-layout-assignment + -> canonicalize/cse + -> vmi-layout-rematerialize + -> canonicalize/cse + -> vmi-layout-fold + -> canonicalize/cse + -> vmi-layout-sink-materialization + -> canonicalize/cse +``` + +## 4. 
Multi-Consumer Conflict + +如果一个 `ext` result 有两个 consumer: + +```text +consumer A requires deinterleaved=2 +consumer B requires deinterleaved=4 +``` + +assignment 不需要判断哪个更优。它可以选择稳定 baseline,例如 `deinterleaved=2`, +并为另一个 use 插入 helper: + +```text +%e2 = pto.vmi.extf %x : contiguous -> deinterleaved=2 +consumer_a(%e2) + +%e4 = pto.vmi.ensure_layout %e2 : deinterleaved=2 -> deinterleaved=4 +consumer_b(%e4) +``` + +remat 再把第二个 use 优化成 cloned producer: + +```text +%x2 = pto.vmi.ensure_layout %x : contiguous -> deinterleaved=2 +%e4 = pto.vmi.extf %x2 : deinterleaved=2 -> deinterleaved=4 +consumer_b(%e4) +``` + +原 `%e2` 仍服务 `consumer_a`。这样不需要 assignment 做全局 cost selection。 + +## 5. ComputeY1 Shape + +baseline assignment 可能产生: + +```text +%x32 = extf %x16 // result deinterleaved=2 +%s32 = extf %scale16 // result deinterleaved=2 +%m = mulf %x32, %s32 // result deinterleaved=2 +%m4 = ensure_layout %m // deinterleaved=2 -> deinterleaved=4 +%y = truncf %m4 +``` + +remat/fold 后目标 IR: + +```text +%x16_d2 = load ... // folded deinterleaved=2 load +%x32_d4 = extf %x16_d2 // deinterleaved=2 -> deinterleaved=4 + +%scale16_d2 = group_broadcast_load ... // folded/assigned deinterleaved=2 +%scale32_d4 = extf %scale16_d2 // deinterleaved=2 -> deinterleaved=4 + +%m4 = mulf %x32_d4, %scale32_d4 +%y = truncf %m4 +``` + +关键点: + +```text +1. truncf 只通过 ensure_layout 表达自己的 source layout requirement。 +2. remat 不需要识别 quant 语义。 +3. ext relation 是 local rule。 +4. load/group_broadcast_load 的物理优化由 fold 或 producer capability 处理。 +``` + +## 6. Lowering Contract + +`vmi-to-vpto` 的 contract 不变: + +```text +1. 不扫描 ext 的 users。 +2. 不扫描 producer chain 来猜 layout。 +3. 只根据当前 op 的 operand/result layout lower。 +``` + +relation-aware remat 必须在 `vmi-to-vpto` 前把 IR 显式改写为: + +```text +%x = pto.vmi.load ... 
+ -> !vreg<..., layout<deinterleaved=2>> +%e = pto.vmi.extf %x + : !vreg<..., layout<deinterleaved=2>> + -> !vreg<..., layout<deinterleaved=4>> +``` + +之后 lowering 只消费这个 local shape。 diff --git a/docs/designs/vmi-layout-relation-rematerialization-implementation.md b/docs/designs/vmi-layout-relation-rematerialization-implementation.md new file mode 100644 index 0000000000..540d7bc5e2 --- /dev/null +++ b/docs/designs/vmi-layout-relation-rematerialization-implementation.md @@ -0,0 +1,406 @@ +# VMI Layout Relation-Aware Rematerialization Implementation Plan + +本文是 `vmi-layout-relation-rematerialization-design.md` 的实现计划。目标是 +扩展现有 `vmi-layout-rematerialize` / `vmi-layout-fold` 优化,让 assignment +保持 legal baseline,并从显式 `ensure_layout` 中恢复更好的 producer layout。 + +## 1. Current Baseline + +当前 pipeline 中相关 pass: + +```text +vmi-layout-assignment: + chooses concrete layouts + inserts ensure_layout / ensure_mask_layout / ensure_mask_granularity + +vmi-layout-fold: + folds selected ensure_layout helpers into layout-aware producers/consumers + current coverage includes store-side fold, load -> ensure_layout producer + fold, and inverse nested ensure_layout fold + +vmi-layout-rematerialize: + replaces ensure_* around cheap construction producers + current data coverage: splat constant, broadcast, iota + current mask coverage: create_mask, create_group_mask, constant_mask + +vmi-layout-sink-materialization: + sinks matching operand-side helpers through pure elementwise ops + it does not currently rewrite result-side ensure_layout(op(...), L) +``` + +ComputeY1-like IR currently remains suboptimal because assignment emits: + +```text +ensure_layout(mulf(ext(...), ext(...)), deinterleaved=4) +``` + +but remat does not yet: + +```text +1. hoist result-side ensure_layout through mulf +2. rematerialize ext under a requested result layout +3. expose foldable load/group_broadcast_load helpers +``` + +## 2. Support APIs + +Add support-layer helpers in `VMILayoutSupport`. 
+ +### 2.1 Widen Relation Query + +```cpp +FailureOr<VMILayoutAttr> getWidenSourceLayoutForResultLayout( + VMIVRegType sourceType, + VMIVRegType resultType, + VMILayoutAttr requestedResultLayout, + std::string *reason = nullptr) const; +``` + +Semantics: + +```text +1. source/result lane count must match. +2. result element width must be an integer multiple of source element width. +3. first implementation supports widenFactor 2 and 4. +4. requestedResultLayout must be contiguous or deinterleaved(block_elems=1). +5. requested result factor F must be divisible by widenFactor K. +6. derived source factor is F / K. +7. derived source factor 1 means contiguous. +8. derived source/result layout pair must be accepted by ext support gates. +``` + +Examples: + +```text +f16 -> f32, requested result deinterleaved=4 + => source deinterleaved=2 + +f16 -> f32, requested result deinterleaved=2 + => source contiguous + +f8 -> f32, requested result deinterleaved=4 + => source contiguous +``` + +### 2.2 Ext Support Gates + +Update `getExtFSupport`, `getExtSISupport`, and `getExtUISupport` so they accept +relation-rematerialized local shapes: + +```text +source layout: + contiguous or deinterleaved(S, block_elems=1) + +result layout: + deinterleaved(S * widenFactor, block_elems=1) +``` + +Keep group_slots integer extension behavior unchanged. + +Reject: + +```text +1. result layout that is not deinterleaved for dense ext. +2. block_elems != 1 in this first implementation. +3. source/result arity that does not satisfy resultArity = factor * sourceArity. +4. unsupported element width relation. +``` + +`vmi-to-vpto` ext lowering already works from physical source/result arity. If +support admits `source deinterleaved=2 -> result deinterleaved=4`, lowering must +be covered by tests. + +## 3. Rematerialize Pass Changes + +Extend `VMILayoutRematerialize.cpp` around `VMIEnsureLayoutOp`. 
+ +Recommended ordering for one helper: + +```text +try relation-aware ext remat +try result-side layout-transparent producer remat +try existing cheap construction remat +``` + +The pass should use a helper worklist. When one rewrite creates new +`ensure_layout` helpers, enqueue them so the same pass can continue locally. + +### 3.1 Ext Remat Pattern + +Match: + +```text +%wanted = pto.vmi.ensure_layout %old +%old = pto.vmi.extf %src +``` + +where `%wanted` has `requestedResultType`. + +Rewrite: + +```text +derivedSourceLayout = + support.getWidenSourceLayoutForResultLayout(srcType, requestedResultType, + requestedResultLayout) + +%src2 = materialize source to derivedSourceLayout +%new = pto.vmi.extf %src2 : derivedSourceType -> requestedResultType +replace %wanted with %new +``` + +Equivalent patterns are needed for: + +```text +pto.vmi.extf +pto.vmi.extsi +pto.vmi.extui +``` + +The source materialization step should: + +```text +1. reuse %src if it already has derivedSourceLayout. +2. create pto.vmi.ensure_layout otherwise. +3. enqueue the new helper for further remat/fold opportunities. +``` + +### 3.2 Layout-Transparent Result Helper Remat + +Match: + +```text +%wanted = pto.vmi.ensure_layout %old +%old = pto.vmi.mulf %lhs, %rhs +``` + +Rewrite: + +```text +%lhs2 = ensure_layout %lhs : lhsLayout -> requestedLayout +%rhs2 = ensure_layout %rhs : rhsLayout -> requestedLayout +%new = pto.vmi.mulf %lhs2, %rhs2 : requestedLayout +replace %wanted with %new +``` + +Initial op coverage: + +```text +mulf +addf/addi/subf/subi/muli/divf/minf/maxf +andi/ori/xori/shli/shrui +negf/absf/absi/sqrt/exp/ln/relu/not +fma +``` + +Optional later coverage: + +```text +cmpf/cmpi: + result is mask, so this belongs with ensure_mask_layout. + +select: + requires coordinated data and mask layout/granularity handling. +``` + +The pass must preserve op attributes exactly. 
+ +### 3.3 Existing Cheap Producer Remat + +Keep current behavior for: + +```text +splat pto.vmi.constant +pto.vmi.broadcast +pto.vmi.iota +pto.vmi.create_mask +pto.vmi.create_group_mask +pto.vmi.constant_mask +``` + +These remain direct remat cases and do not require relation queries. + +## 4. Fold Pass Interaction + +Relation remat may create producer-side helpers: + +```text +ensure_layout(load(...), deinterleaved=2) +ensure_layout(group_broadcast_load(...), deinterleaved=2) +``` + +`vmi-layout-fold` should absorb these when the producer can directly materialize +the requested layout. + +Existing load fold should use producer capability, not helper materialization +capability. A load may directly produce a requested contiguous or +deinterleaved=2/4 block_elems=1 result layout even when the helper conversion +from the old load layout to the requested layout would not be a legal register +materialization. + +Add fold coverage if missing for: + +```text +group_broadcast_load result layout requested as deinterleaved=2/block_elems=1 +group_slot_broadcast_load result layout requested as deinterleaved=2/block_elems=1 +``` + +The fold pass must still be local: + +```text +load/group_broadcast_load + ensure_layout + => cloned/retyped producer with requested result layout +``` + +It must not inspect downstream `ext` or `trunc`. + +## 5. Pipeline + +Use a pipeline with fold after remat: + +```text +vmi-layout-assignment + -> canonicalize/cse + -> vmi-layout-rematerialize + -> canonicalize/cse + -> vmi-layout-fold + -> canonicalize/cse + -> vmi-layout-sink-materialization + -> canonicalize/cse + -> pto-validate-vmi-layout-ir + -> vmi-to-vpto +``` + +The single fold handles helpers already emitted by assignment. It also 
+ +If later result-side remat and operand-side sink need to alternate for longer +chains, the driver may repeat: + +```text +vmi-layout-rematerialize +canonicalize/cse +vmi-layout-fold +canonicalize/cse +``` + +Keep the first implementation single-pass unless tests prove a fixed point is +needed. + +## 6. Tests + +Add focused lit tests. + +### 6.1 Direct Ext Remat + +Input shape: + +```text +load f16 +extf f16 -> f32 +ensure_layout ext result deinterleaved=2 -> deinterleaved=4 +truncf f32 -> f8 +``` + +Check after: + +```text +vmi-layout-rematerialize +``` + +```text +extf source is deinterleaved=2 +extf result is deinterleaved=4 +old ensure_layout is gone +``` + +Check after: + +```text +vmi-layout-rematerialize -vmi-layout-fold -vmi-to-vpto +``` + +```text +load uses deinterleaved load lowering when fold is available +extf lowers from local source/result arity +``` + +### 6.2 Elementwise Result Helper Remat + +Input shape: + +```text +extf lhs -> deinterleaved=2 +extf rhs -> deinterleaved=2 +mulf lhs, rhs -> deinterleaved=2 +ensure_layout mulf result -> deinterleaved=4 +truncf +``` + +Check: + +```text +mulf is cloned/rebuilt with deinterleaved=4 operands/results +each ext is rematerialized as source deinterleaved=2 -> result deinterleaved=4 +no ensure_layout remains between mulf and truncf +``` + +### 6.3 Multi-Consumer Conflict + +Input shape: + +```text +ext result deinterleaved=2 +consumer A uses deinterleaved=2 +consumer B has ensure_layout to deinterleaved=4 +``` + +Check: + +```text +original ext remains for consumer A +new cloned ext feeds consumer B +no global layout selection is required +``` + +### 6.4 ComputeY1 + +Run: + +```text +pto-test-opt compute_y1_to_fp8_fp16_vmi.pto \ + -vmi-layout-assignment \ + -vmi-layout-rematerialize \ + -vmi-layout-fold \ + -vmi-to-vpto +``` + +Expected: + +```text +x load can become deinterleaved=2 and lower through deinterleaved load support +scale path can keep the E2B-compatible deinterleaved layout +mulf/truncf path 
has no deinterleaved=2 -> deinterleaved=4 helper immediately +before truncf +``` + +## 7. Non-Goals And Follow-Ups + +Do not implement in this change: + +```text +1. assignment relation propagation. +2. global layout cost model. +3. trunc/narrow relation remat. +4. cloning memory loads in remat without going through explicit fold support. +5. context-sensitive vmi-to-vpto lowering. +``` + +Follow-ups: + +```text +1. Add narrow relation remat for selected trunc patterns after widen is stable. +2. Add select/cmp mask-aware result helper remat. +3. Consider a fixed-point layout optimization pipeline if long chains need it. +4. Move repeated op-family cloning utilities into a shared helper if the pass + grows beyond the first ext/elementwise implementation. +``` diff --git a/docs/designs/vmi-layout-request-propagation.md b/docs/designs/vmi-layout-request-propagation.md new file mode 100644 index 0000000000..c59477d2e9 --- /dev/null +++ b/docs/designs/vmi-layout-request-propagation.md @@ -0,0 +1,1390 @@ +# VMI Layout Request Propagation + +This document describes VMI layout request propagation for layout assignment +and local layout rewrites. It is intentionally independent from any single +optimization case such as `group_reduce -> truncf -> group_broadcast`. + +The propagator answers this question: + +```text +Given one or more requested layouts for VMI values, can the surrounding IR be +rewritten so those semantic values are available in those layouts? +``` + +It must not work by clearing existing layout attributes. Existing layout +attributes are the current IR state. A pass expresses desired changes by adding +value-layout requests. + +## Legacy Propagation Model + +Before the shared propagator, VMI layout propagation was split between layout +assignment and post-assignment rematerialization. + +`vmi-layout-assignment` used an equivalence-class model. The pass collected +values whose layouts must be identical, then assigned one layout to each class. 
+This worked for layout-transparent relations: + +```text +source/result of elementwise ops +bitcast-like ops +control-flow forwarded values +mask/data same-layout relations +``` + +In that model, assignment propagation means "same layout flows through this +relation". It does not cross relations where the connected values naturally +have different layouts. Casts, reductions, channel transforms, broadcasts, and +other width- or shape-changing operations cannot be represented as a simple +equivalence-class union. When the requested layout on one side did not match +the layout chosen for the other side, assignment had to connect the two sides +with `ensure_layout` materialization. + +`vmi-layout-rematerialize` then performed a second, optimization-oriented +propagation after assignment. It started from explicit `ensure_layout` ops, +looked at the producer feeding the `ensure_layout`, and decided whether the +layout conversion could be pushed through or removed by recomputing the +producer. This was intentionally local and peephole-like: + +```text +ensure_layout result requested as layout B + inspect producer in layout A + ask whether this producer can be recomputed for B cheaply + clone/recompute the producer or one of its operands + replace this one use with the recomputed value + continue recursively while the local cost model allows it +``` + +This pass therefore acted like an instcombine-style rematerialization pass over +layout helper IR. It did not build a graph-level layout solution first. The +propagation happened as a sequence of one-by-one IR rewrites and replacements, +driven by the `ensure_layout` currently being optimized. + +## Refactor Motivation + +The split model caused the same layout knowledge to appear in two places. +Assignment needed rules for natural/preferred layouts and same-layout +constraints. 
Rematerialization also needed to know which non-equivalence +operations could be crossed, and what layout should appear on the other side +after crossing. As more cast and reduction cases were added, the duplicated +parts became the risky part: + +```text +vcvt/ext/trunc layout facts +group-value cast layout facts +reduce input/result layout facts +broadcast/group-slot layout support checks +``` + +Adding another rematerialization case would have required copying more of the +assignment-side rule set into the peephole optimizer. That makes the result +order-sensitive and hard to reason about: assignment may choose one layout, +rematerialization may rediscover a related layout later, and the two passes can +silently disagree about whether the same op relation is legal. + +The refactor centralizes the reusable part as op-local transfer relations and a +value-layout propagator: + +```text +propagate(value, layout) + record the requested layout for the semantic SSA value + run transfer relations through defining ops and users + derive uniquely implied layouts on connected values + record conflicts when a second layout is needed + materialize unresolved conflicts with ensure_layout during apply +``` + +Assignment remains responsible for policy: it chooses the initial seeds from +true layout decision points such as ext/trunc/reduce results, loads, stores, and +target-specific boundary requirements. Same-layout ops are relations, not +decision roots. Consumer operand requests are not producer facts unless a +support-checked direct producer can really adopt that layout. + +Rematerialization should then consume the same transfer/support helpers instead +of owning a second copy of cast/reduce layout logic. Its job becomes local IR +cleanup after a consistent layout table exists: folding helper IR, cloning cheap +producers when that is profitable, or removing redundant `ensure_layout` chains. 
+It should not be the only place where non-equivalence layout propagation is +defined. + +## Core Assignments + +The propagator is a pass-local object. It does not rewrite IR while requests +are being added. Its core assignment table is value-centric: + +```text +assignments: + Value -> VMIValueLayoutAssignment + +worklist: + Primary Value/layout facts that were just added and still need to be + propagated through defining and user op transfers. Conflict layouts do not + enter this worklist unless a later rematerialization/fork planner explicitly + creates a new producer instance for that alternate layout. +``` + +The assignment uses short field names: + +```text +VMIValueLayoutAssignment: + layout: + The layout that apply must make available for this SSA value. + + conflicts: + Extra layouts required for this SSA value. There are two forms: + def-side, meaning the SSA value itself must also be available in another + layout; and use-side, meaning one operand use must see the value through + another layout. + +VMILayoutConflict: + operand: + Empty for a def-side conflict. Otherwise, the OpOperand to rewrite during + apply if the conflict remains material. Do not create a fake operand for a + def-side conflict. + + layout: + The extra layout required for this SSA value or operand use. +``` + +The implementation can represent this directly as: + +```text +VMILayoutConflict: + OpOperand *operand // nullptr means def-side + VMILayoutAttr layout +``` + +Conflict uniqueness is checked by conflict form: + +```text +def-side: + key = layout + duplicate layout is a no-op + +use-side: + key = operand + duplicate operand with the same layout is a no-op + duplicate operand with a different layout is a hard conflict +``` + +`assignment.layout` is not a promise that the defining op can directly produce +that layout. It is the layout that apply must make available for the semantic +value. 
Apply implements it by rewriting the source value's VMI type when the +value is type-rewriteable inside the current rewrite scope, or by creating a +primary `ensure_layout` value at a boundary when the source value's type cannot +be rewritten. + +A conflict does not overwrite `assignment.layout`. It records an additional +fork requirement: + +```text +def-side conflict: + if assignment.layout differs from conflict.layout: + materialize ensure_layout assignment.layout -> conflict.layout at the + def/boundary + +use-side conflict: + if assignment.layout differs from conflict.layout: + insert ensure_layout assignment.layout -> conflict.layout before + conflict.operand + replace only conflict.operand + else: + the conflict is an identity fork and is dropped +``` + +This keeps the `assignment.layout` model value-bound while still representing +multiple layout requirements. Operand layout requirements are not stored in a +separate global table and are not encoded inside `VMILayoutAttr`. + +The propagator does not rank competing layouts for a value. The caller decides +the initial request order. The first layout accepted for a value becomes +`assignment.layout`; subsequent different layouts become conflicts. Only +`assignment.layout` is a producer-transfer fact. A conflict means an alternate +layout must be materialized as a fork from the primary value; it is not a claim +that the original defining op result has that alternate layout. Cost-based +layout choice is outside this first implementation and would change only the +merge policy, not transfer relations. + +Do not pre-build separate tables for current layouts, requests, +materializations, rewrites, or conflicts: + +```text +current layout: + Read from value.getType() on demand. If the type has no layout, the current + layout is unknown, not contiguous. + +request: + Immediately merge into assignments. If `assignment.layout` is newly added, + enqueue the value/layout fact. 
Conflict layouts are recorded in + `assignment.conflicts` but are not propagated through producers by this + utility. The worklist de-duplicates by `(Value, VMILayoutAttr)` so the same + primary fact is propagated once. The first implementation does not keep a + separate request log. + +materialization: + Derive during apply by comparing each value assignment's layout with its + conflicts and with the current IR type. + +rewrite: + Derive during apply from assignments by finding defining ops whose result + types need to change. + +hard conflict: + Detect while merging one operand's required layout. Def-side conflicting + value requests are represented as conflicts; same-operand conflicting + requests still fail. +``` + +## API Shape + +The basic API is: + +```text +request(value, layout): + request that the SSA value be available in layout. If this differs from the + existing `assignment.layout`, record a def-side conflict. + +request(operand, layout): + request that operand.get() be available in layout for this operand only. This + does not create the primary assignment for the source value. If the source + value's `assignment.layout` is absent or different, record a use-side + requirement on that operand. 
+ +run(): + propagate until the worklist is empty + +apply(): + rewrite in-scope VMI value types to their assigned layouts + materialize boundary values and conflicts with ensure_layout forks +``` + +The merge operation is the only place that mutates the core table: + +```text +request(value, layout): + if assignments[value].layout exists and differs from layout: + add assignments[value].conflicts[{def, layout}] + else if assignments[value].layout is absent: + assignments[value].layout = layout + push (value, layout) to worklist + +request(operand, layout): + value = operand.get() + if assignments[value].layout exists and differs from layout: + if operand already has a different conflict layout: + report hard conflict + add or update assignments[value].conflicts[operand] = layout + else if assignments[value].layout exists and equals layout: + no-op + else: + create assignments[value] if needed + add assignments[value].conflicts[operand] = layout +``` + +`request(operand, layout)` must not delegate to `request(value, layout)`. A +consumer operand requirement is not a producer fact. It becomes a local +materialization request unless a separate value request or transfer relation +chooses the same primary layout for the source value. + +Transfer relations call the API that matches the derived fact. For +layout-transparent ops, an operand value fact can derive the op result primary +layout, while the other operands receive operand-local layout requirements. +A result value fact similarly derives operand-local requirements. Cast inverse +relations also call `request(operand, layout)` because a result layout +determines the layout needed by that cast operand, not necessarily the +producer's global primary layout. 
The first implementation may promote an +operand-local request into a producer seed in two narrow cases: + +- the source value is defined by a direct layout producer, such as a supported + `load`, splat-like `constant`, `broadcast`, `iota`, or `group_broadcast` + whose source group-slot layout and requested result layout pass the + target-support query; +- the source value is defined by a single-use layout-transparent op, every data + operand can directly produce the requested layout, and neither the value nor + any data operand already has a different primary assignment or current layout. + +The second case is a local rematerialization-style choice for layout-free IR. +It does not chase arbitrary producer chains, and it must not override an +existing primary assignment, an assignment seed, or an explicit current layout. +Assignment must add producer value seeds before consumer operand requests so a +consumer request cannot steal the primary layout from an ext/trunc/reduce value +that already has a preferred layout. If an alternate layout is cheap but +requires cloning or sinking an existing producer chain, a later +rematerialization/fold pass may remove the inserted `ensure_layout` by cloning +or folding the producer at the use site. + +If a subsequent request asks for a different extra layout, that request is +recorded as a def-side or use-side +conflict depending on which overload created it. + +Existing explicit type layouts are part of the current IR state, not the +request input. A pass should not first write an explicit layout and then ask +the propagator to rediscover it. Instead, the pass requests the desired +layout directly. If an explicit layout already exists, the propagator reads it +on demand as the current state when deciding whether materialization is needed. + +An existing explicit layout is not a lock. 
If a request asks for a layout that +differs from the current IR type or `assignment.layout`, the propagator +records the requested value/layout fact and continues propagation. During +apply, the propagator decides whether the defining op can be rewritten to that +layout by ordinary in-scope type rewrite because sourceValue is +type-rewriteable, or whether a primary `ensure_layout` fork is needed at a +boundary. + +## Implementation Shape + +The first implementation should be a small utility, not a replacement for every +layout pass at once: + +```text +include/PTO/Transforms/VMILayoutPropagation.h +lib/PTO/Transforms/VMILayoutPropagation.cpp +``` + +The public type should expose value layout requests, operand-local value layout +requests, fixed-point propagation, and final IR application: + +```text +class VMILayoutPropagator { + LogicalResult request(Value value, VMILayoutAttr layout); + LogicalResult request(OpOperand &operand, VMILayoutAttr layout); + LogicalResult run(); + LogicalResult apply(RewriterBase &rewriter); + + VMILayoutAttr getRequestedOrCurrentLayout(Value value) const; + VMILayoutAttr getRequestedLayout(Value value) const; +}; +``` + +`getRequestedLayout` returns only `assignment.layout`, not conflict layouts. +`getRequestedOrCurrentLayout` reads `assignment.layout` first, then the current +VMI type layout. It returns an empty layout when neither exists. It must not +default to contiguous during propagation. Passes that need to materialize or +inspect extra layouts should use the value assignment during apply instead of +overloading this singular accessor. + +The current assignment pass already has pieces that map directly onto this +utility: + +```text +LayoutSolver::setNaturalLayout + becomes request(value, layout). + +LayoutSolver::requestDataUse + becomes request(operand, layout). If the request conflicts with the source + value's `assignment.layout`, the propagator records a use-side conflict on that + value. 
+ +LayoutSolver::getExplicitDataLayout + becomes a current-layout read from the value type or assigned equivalence. + +LayoutSolver::getDataLayout + currently defaults unknown to contiguous. The propagator must not do that + until finalization/apply. + +LayoutSolver::applyConsumerDrivenDataLayouts + is removed. Consumer operand requirements are represented as use-side + conflicts and do not become natural layouts. + +LayoutSolver::rewriteDataTypes and insertDataUseMaterializations + provide the first implementation material for VMILayoutPropagator::apply in + vmi-layout-assignment. +``` + +This means the first change can be staged: + +```text +1. Add VMILayoutPropagator with assignments, per-value def/use conflicts, + value-layout fact worklist, and strict same-operand conflict checking. + +2. Move layout relation helpers into it. Run mask granularity assignment/split + before layout propagation, and keep the layout propagator focused on layout + only. Do not move all materialization logic in the first step. + +3. Let vmi-layout-assignment call the propagator for the relations that need + order-independent propagation. + +4. Factor existing assignment finalization into VMILayoutPropagator::apply + once the assignment table is authoritative. 
+``` + +## Propagation + +The propagator's loop is deterministic: + +```text +while worklist is not empty: + pop changed value/layout fact + inspect the defining op transfer relation, if the value is an op result + inspect each user op transfer relation + operand/source value layout known -> derive result value layouts + result value layout known -> derive operand source value layouts when + inverse is legal +``` + +The propagator should work with generic relation records, not with hard-coded +patterns: + +```text +same-layout relation: + elementwise ops, bitcast, select-compatible values + +cast width relation: + source layout known -> derive result + result layout known -> derive source when inverse is legal + +channel split/merge relation: + source/input side layout known -> derive result side layout, and vice versa + when the inverse is unique + +control-flow relation: + CFG branch/yield/region layout consistency + call/return only when function/call boundary rewrite is enabled +``` + +Each relation should use the same helper that support checks use. For example, +a cast width relation should know only about source/result layout relations of +cast ops; it must not know about `truncf -> group_broadcast` as a pattern. + +### Transfer Relation Details + +The core asset is the static transfer relation: + +```text +known value + known layout + op-local rule + -> zero or more uniquely derived layouts for connected source/result values +``` + +This is different from assignment-specific policy such as "natural layout" or +"consumer request". Those policies decide where the first request comes from. +The transfer relation only answers whether a known layout on one connected +value uniquely determines layouts on other connected values. + +The first useful transfer relation set is: + +```text +same-layout relation: + For layout-transparent ops, if one operand/result gets a requested layout, + require the same layout on the connected values. 
Existing assignment + `unite` logic is a source of the supported op list. + +cast width relation: + Use VMILayoutSupport cast fact helpers. Dense and group-value casts should + both generate VMICastLayoutFact pairs. Group-value casts need a support + helper that derives: + narrow: result.LS = source.LS * width_ratio + widen: source.LS = result.LS * width_ratio + +channel split/merge relation: + Split and merge have fixed source/result layout equations once the channel + count is known. +``` + +The relation providers should call `VMILayoutSupport` for legality instead of +duplicating target checks. Missing support helpers should be added there, not +open-coded in the propagator. + +### Existing Assignment Transfer Assets By Op + +`vmi-layout-assignment` already contains several transfer relations. They +should be extracted by op or op family. Each op family may expose more than +one transfer rule, but every rule must still be local to that op. + +```text +VMIAddF/AddI/SubF/SubI/MulF/MulI/DivF/MinF/MaxF/AndI/OrI/XOrI: + Existing code: + constrainElementwiseBinary(...) + + Transfer rules: + lhs layout -> rhs layout and result layout, same layout + rhs layout -> lhs layout and result layout, same layout + result layout -> lhs layout and rhs layout, same layout + + Notes: + existing code has fallback handling for unsupported group_broadcast result + layouts; that fallback is assignment policy, not the transfer rule. 
+ +VMIFma: + Existing code: + unite(lhs, rhs), unite(lhs, acc), unite(lhs, result) + + Transfer rules: + any of lhs/rhs/acc/result layout + -> all of lhs/rhs/acc/result use the same layout + +VMINegF/AbsF/AbsI/Sqrt/Exp/Ln/Relu/Not: + Existing code: + unite(source, result) + + Transfer rules: + source layout -> result layout, same layout + result layout -> source layout, same layout + +VMIFPToSI/VMISIToFP: + Existing code: + unite(source, result) + + Transfer rules: + source layout -> result layout, same layout + result layout -> source layout, same layout + +VMICmpF/VMICmpI: + Existing code: + unite(lhs, rhs) + + Transfer rules: + lhs layout -> rhs layout, same layout + rhs layout -> lhs layout, same layout + lhs/rhs layout -> mask result layout, same layout + mask result layout -> lhs/rhs layout, same layout when unique + +VMISelect: + Existing code: + unite(trueValue, falseValue), unite(trueValue, result) + + Transfer rules: + any of trueValue/falseValue/result layout + -> all of trueValue/falseValue/result use the same layout + selected value/result layout -> mask operand layout, same layout + mask operand layout -> selected value/result layout, same layout when + unique + +VMIBitcast: + Existing code: + unite(source, result) + + Transfer rules: + source layout -> result layout, same layout + result layout -> source layout, same layout + +Widen cast transfer (VMIExtF/ExtSI/ExtUI): + Existing code: + ExtF uses getPreferredCastLayoutFact(...) for dense widening. + ExtSI/ExtUI have a source group_slots slots=8 branch. + getPreferredCastLayoutFact(...) + + Transfer rules: + These three ops are the same widening transfer relation at the layout + level. Float/integer signedness affects element-type legality and lowering, + not the layout algebra. 
+ dense source layout -> dense result layout when the cast fact is + unique and supported + dense result layout -> dense source layout by matching the same cast facts + group-value source layout -> group-value result layout for supported + widening + group-value result layout -> group-value source layout when the matching + cast fact is unique + + Extraction work: + ExtF does not currently have the group-value branch that ExtSI/ExtUI have. + This is a legacy assignment-framework artifact. Add one widening-cast fact + generator used by all three ops and let VMILayoutSupport decide which + element types are legal. + +VMITruncF: + Existing code: + VMILayoutSupport::getPreferredCastLayoutFact(...) + TruncF source group_slots slots=1 case + + Transfer rules: + dense source layout -> dense result layout when the cast fact is + unique and supported + dense result layout -> dense source layout when the matching cast fact is + unique and supported + group-value source layout -> group-value result layout for supported + narrowing + group-value result layout -> group-value source layout when the matching + cast fact is unique + + Extraction work: + current group-value branch supports only slots=1 and is source-driven. + Add dense trunc inverse generation to the shared cast fact helper. + +VMITruncI: + Existing code: + VMILayoutSupport::getPreferredCastLayoutFact(...) + TruncI source group_slots slots=1/8 branch + + Transfer rules: + dense source layout -> dense result layout when the cast fact is + unique and supported + dense result layout -> dense source layout when the matching cast fact is + unique and supported + group-value source layout -> group-value result layout for supported + narrowing + group-value result layout -> group-value source layout when the matching + cast fact is unique + + Extraction work: + current group-value branch is source-driven. The narrow4 slots=8 case + already computes lane_stride=4; move that equation into the shared cast + fact helper. 
+ +VMIChannelSplit: + Existing code: + VMIChannelSplitOp case in addConstraints() + + Transfer rules: + source deinterleaved=channel_count -> each result contiguous + result contiguous on every result -> source deinterleaved=channel_count + +VMIChannelMerge: + Existing code: + VMIChannelMergeOp case in addConstraints() + + Transfer rules: + every input contiguous -> result deinterleaved=channel_count + result deinterleaved=channel_count -> every input contiguous + +control-flow ops: + Existing code: + addIfConstraints, addYieldConstraints, addExecuteRegionConstraints, + addIndexSwitchConstraints, addWhileConstraints, addForConstraints, + addBranchConstraints, addReturnConstraints, addCallConstraints + + Transfer rules: + any equivalent incoming/yield/result/call value layout + -> same layout on every value in that equivalence group + +mask ops: + Existing code: + uniteMask(...) + + Transfer rules: + same-layout mask propagation mirrors data same-layout propagation +``` + +Mask layout propagation is not a separate assignment flow. Mask values +participate in the same propagator/worklist as VMI data values. Data-producing +or data-consuming ops drive their mask operands/results through same-layout +relations: + +```text +data layout L -> mask layout L +mask layout L -> data layout L when the relation is unique +``` + +Mask granularity assignment is separate from layout propagation and runs before +layout assignment. Different granularities represent different mask values, so +the granularity pass should split or materialize mask values before the layout +propagator sees them. After that split, each mask SSA value has one fixed +granularity, and the layout propagator only assigns its layout. + +Granularity must not become a second independent request dimension in the +layout worklist, and different granularity requirements must not be represented +as layout conflicts on one mask value. 
+ +Some assignment logic is not a transfer relation and should not be moved into +the propagator as if it were one: + +```text +producer-only layout choice: + group_reduce/group_load/group_slot_load/group_broadcast_load choosing an + initial result layout is assignment policy. It can seed assignments, but it + is not derived from another operand layout. + +store-only requirements: + store/group_store/masked_store have no result value to infer. They seed an + operand layout request for their value operand when the store form requires a + concrete input layout. They are not bidirectional transfer relations. + +group_broadcast pair support: + VMILayoutSupport can validate a source/result pair, but source layout alone + does not always uniquely choose a dense result layout. Treat it as a support + check and source requirement unless another value request makes the result + layout concrete. +``` + +### Relation Mechanics + +Each op does not own a persistent layout table. The only persistent table is +the propagator's global `assignments`. + +An op relation is implemented by a transfer object: + +```text +class VMILayoutTransfer: + propagate(op, changedValue, changedLayout, propagator) +``` + +`propagate` is an op-local fact propagation method. It does not rewrite IR. +It requests layouts for connected source and result values. 
It should be a +thin wrapper around the op family's pure relation evaluator: + +```text +derive(op, changedPort, changedLayout, assignmentView) + -> zero or more derived port/layout facts + +propagate(op, changedValue, changedLayout, propagator): + inspect op operands/results, attrs, element types, and VMILayoutSupport + facts = derive(op, changedPort, changedLayout, propagator.assignments) + for each result fact: + call request(resultValue, derivedLayout) + for each operand fact: + call request(operand, derivedLayout) +``` + +The evaluator is query-like in the ordinary sense: it is pure, it does not +mutate `assignments`, it does not enqueue work, and it does not rewrite IR. It +is not a public propagator `query` API because the propagator API that changes +state is still `request`. This keeps the mutation point explicit while letting +propagation and apply-time validation consume the same op relation. + +When the connected value is reached through an operand, `propagate` calls the +operand overload so a merge conflict can become a use-side conflict on the +source value. + +This is the mechanism that propagates layout information. When one connected +value receives a layout, the op relation may infer layouts for other connected +values: + +```text +same-layout op: + any operand source/result layout -> all connected source/result values get + the same layout + +cast op: + source value layout -> result value layout using width ratio + result value layout -> source value layout when the inverse relation is legal + +channel split/merge: + channel count plus one side layout -> the other side layout when the inverse + is unique +``` + +Not every relation is symmetric, and not every input layout determines every +other operand. If the relation cannot derive a unique supported layout, it +emits nothing. If a requested layout differs from the source value's +`assignment.layout`, the request is recorded as a conflict on that value. 
The +value overload records a def-side conflict; the operand overload records a +use-side conflict. If the same operand records two different conflict layouts, +strict propagation fails. + +Block arguments are also worklist values. They are not `OpResult`s and do not +have `getDefiningOp()`, but the propagator can still process them through a +boundary transfer: + +```text +process(value, layout): + if value is an OpResult: + process the defining op result port + if value is a BlockArgument: + process the block/function boundary port + process all uses of value +``` + +The block/function boundary transfer is separate from ordinary op-result +transfer: + +```text +ordinary op result: + defining op result <-> defining op operands/results + +block argument: + block argument <-> predecessor terminator successor operands + +function argument: + function argument <-> function signature / call operands, when the pass owns + signature or interprocedural rewrite +``` + +CFG block arguments require a same-transfer. A block argument and each +predecessor terminator successor operand represent the same semantic stream, so +layout requests must propagate in both directions: + +```text +block argument layout L -> request predecessor successor operand layout L +predecessor successor operand layout L -> request block argument layout L +``` + +If a CFG edge cannot satisfy the same layout as the block argument, the +terminator operand request becomes a use-side conflict on the predecessor +source value and apply materializes that edge operand before the terminator. + +Function signatures and call sites are a separate boundary. The first +implementation can leave function/call boundary transfer out if it only rewrites +inside one function and does not update function signatures or call sites. In +that mode, function arguments are boundary source values: they propagate to +their users, and primary boundary materialization is handled by apply. 
+ +### Cast Layout Facts + +Width-changing casts should keep source/result layout information paired. Do +not encode individual concrete vector cases directly in `propagate`, such as +`64xf16 -> 64xf32`. Generate layout facts from the source/result element widths +and the known anchor layout: + +```text +VMICastLayoutFact: + sourceLayout + resultLayout +``` + +`propagate` uses those facts mechanically: + +```text +if changedValue is the cast source: + for each fact whose sourceLayout == changedLayout: + request result value layout = fact.resultLayout + +if changedValue is the cast result: + for each fact whose resultLayout == changedLayout: + call request(sourceOperand, fact.sourceLayout) +``` + +The support layer should provide a fact generator shaped like: + +```text +getCastLayoutFacts(sourceType, resultType, anchorSide, anchorLayout) + -> zero or more VMICastLayoutFact +``` + +The `anchorSide` is source or result. Together, `anchorSide` and +`anchorLayout` limit generation to the small set of facts that can match the +currently propagated layout. + +For width-changing dense casts, fact generation is keyed by the width ratio and +the known layout family. For a narrowing cast with: + +```text +R = source_bits / result_bits +``` + +a deinterleaved source layout maps as: + +```text +source deinterleaved = F + if F % R == 0: + result factor = F / R + result factor 1 means contiguous +``` + +Examples: + +```text +f32 -> f8, R=4: + source 256xf32 deinterleaved=4 + -> result 256xf8 contiguous + +f32 -> f16, R=2: + source 256xf32 deinterleaved=4 + -> result 256xf16 deinterleaved=2 +``` + +The inverse direction matches the same facts: + +```text +f8 result contiguous + -> f32 source deinterleaved=4 + +f16 result deinterleaved=2 + -> f32 source deinterleaved=4 +``` + +The propagator should accept an inverse only when fact generation returns a +single supported matching fact for the anchor side. 
If several facts could +satisfy the same side, the relation must not guess; it should emit nothing +unless another request makes the choice concrete. + +Group-value casts use a different static relation: + +```text +narrow by R: + result.LS = source.LS * R + +widen by R: + source.LS = result.LS * R +``` + +These fact generators belong in shared support helpers so assignment, +rematerialization, validation, and lowering agree on the same relation. + +## Apply + +Apply derives concrete IR actions from `assignments` and the current IR. +These actions do not need to be stored in separate propagator tables before +apply. + +Apply must not run a second propagation-style relation query over results or +operands. By the time apply starts, `run()` has already reached a fixed point: + +```text +assignment.layout records the layout to make available for each value +assignment.conflicts records def-side and use-side alternate layouts +every def/user relation has already propagated its required layouts +every use that cannot consume assignment.layout has already become a conflict +``` + +Therefore apply does not ask the op relation again. 
It writes the layouts +already recorded in `assignments` into IR: + +```text +apply must not create new layout requests +apply must not discover new layout conflicts +apply only consumes assignment.layout and assignment.conflicts +``` + +```text +sourceValue: + the original SSA value before apply + +currentLayout: + the layout carried by sourceValue's current VMI type + +assignedLayout: + assignment.layout for sourceValue + +assignedValue: + the SSA value that carries assignedLayout after apply + if sourceValue is type-rewriteable, assignedValue is sourceValue after its + VMI type is rewritten to assignedLayout + otherwise assignedValue is ensure_layout sourceValue : + currentLayout -> assignedLayout when currentLayout != assignedLayout + +def-side conflict: + materialize an extra SSA value from assignedValue to conflict.layout near + the def or boundary + +use-side conflict: + materialize an extra SSA value from assignedValue to conflict.layout before + that use +``` + +Primary type rewrite is not a producer-specific optimization. It is the normal +way assignment becomes explicit in IR. It does not choose a layout and does not +ask whether the producer relation supports the layout; propagation already did +that. It only updates the VMI type to `assignment.layout`. + +Primary type rewrite is available for op results whose defining op is inside +the rewrite scope and whose result type can be changed without rewriting an +external ABI boundary. Multi-result ops should be rewritten as one op update +using the final assignments for all assigned results. Block arguments, +function arguments, values defined outside the rewrite scope, and ABI boundary +values use `ensure_layout` materialization instead. 
+ +In this document, `type-rewriteable` means exactly: + +```text +the value is an OpResult +the defining op is inside the current rewrite scope +changing the result VMI type does not rewrite an external ABI boundary +``` + +It does not mean the defining op was queried again for a preferred layout. + +Primary materialization is not conflict-driven. A value can have no conflicts +and still need `ensure_layout` when its current IR type cannot be rewritten to +`assignment.layout`: + +```text +function argument current layout A +only in-scope use requests layout B + +assignment.layout = B +assignment.conflicts = empty + +if the function signature is not rewritten: + arg_B = ensure_layout arg : A -> B + use(arg_B) +``` + +Conflicts only describe extra layouts besides `assignment.layout`. They do not +replace the primary action that makes `assignment.layout` available. + +Producer-specific improvements, such as folding a fallback +`load -> ensure_layout` into a load with the requested layout or rematerializing +a cast across an `ensure_layout`, are not part of apply. They should run as +ordinary layout-fold/rematerialization over explicit helper IR. + +```text +1. Make assignment.layout available. + Let sourceValue be the original SSA value, currentLayout be the layout on + sourceValue's current VMI type, and assignedLayout be assignment.layout. + If sourceValue is type-rewriteable, rewrite its VMI type to assignedLayout + and use sourceValue as assignedValue. Otherwise, if + currentLayout == assignedLayout, assignedValue is sourceValue. Otherwise + insert ensure_layout sourceValue : currentLayout -> assignedLayout and use + its result as assignedValue. If that ensure_layout is not supported by + VMILayoutSupport, apply fails. + +2. Materialize def-side conflicts. + For each def-side conflict layout, if assignedValue already has that layout, + reuse assignedValue. Otherwise insert ensure_layout near the value's + definition or boundary. 
If that ensure_layout is not supported by + VMILayoutSupport, apply fails. The ensure_layout result is the materialized + SSA value for that layout; no persistent container is needed to represent it. + +3. Rewrite non-conflicting uses. + Uses in the rewrite scope are redirected to assignedValue unless a use-side + conflict records a different layout for that operand. Do not implement this + as an unconditional replace-all-uses. Iterate the original uses and skip + operands recorded in use-side conflicts. Do not query the user op relation + here; non-conflicting uses were already accepted during propagation. + +4. Materialize use-side conflicts. + For each use-side conflict, insert ensure_layout from assignedValue to + conflict.layout before conflict.operand and replace only that operand. + If that ensure_layout is not supported by VMILayoutSupport, apply fails. Do + not special-case the old producer layout here. Redundant chains such as + l1 -> l2 -> l1 are folded by a separate layout-fold/rematerialization pass. +``` + +Concrete insertion points: + +```text +op result: + normally rewrite the result VMI type in place. If the result is outside the + rewrite scope or crosses an ABI boundary, insert the fallback primary + ensure_layout immediately after the defining op. + +block argument: + insert the fallback primary ensure_layout at the first legal insertion point + of the owning block. + +function argument: + insert the fallback primary ensure_layout at the first legal insertion point + of the entry block. + +def-side conflict: + insert ensure_layout after assignedValue is available, using the same + def/boundary placement as the primary materialization. + +use-side conflict: + insert ensure_layout immediately before conflict.operand.getOwner() and + replace only conflict.operand. +``` + +Apply should keep a local materialization map keyed by `(Value, Layout, +placement)` so the same required layout at the same placement is not emitted +twice. 
Different use-side conflicts may still materialize separately when a +single def-side value would not dominate all uses. + +The source for conflict materialization is always `assignedValue`, not the +operand's old value: + +```text +def-side conflict layout C: + c = ensure_layout assignedValue : assignment.layout -> C + +use-side conflict operand op.i requiring layout C: + c = ensure_layout assignedValue : assignment.layout -> C + op.i = c +``` + +If `C == assignment.layout`, the conflict is an identity and no +`ensure_layout` is inserted. + +A def-side conflict by itself does not replace arbitrary uses of the original +SSA value. It only materializes another layout view at the definition or +boundary because a value-level request asked for that layout. Use-side +conflicts are still materialized locally from assignedValue. + +For example, a block argument with current layout `A` and requested +`assignment.layout` `B` keeps its original type unless the rewrite scope allows +changing the boundary. Apply inserts `ensure_layout A -> B` near the boundary +and rewires in-scope uses to the materialized value, except for operands that +have explicit use-side conflicts. 
+ +For a normal defining op inside the rewrite scope, apply rewrites the result +type in place: + +```text +before: + a = producer() : layout A + use(a) + +after assignment.layout = B: + a = producer() : layout B + use(a) +``` + +For a value that cannot be rewritten in place, apply materializes the assigned +layout with `ensure_layout`: + +```text +before: + a0 = boundary_value : layout A + use(a0) + +after assignment.layout = B: + a0 = boundary_value : layout A + a = ensure_layout a0 : A -> B + use(a) +``` + +If one use still requires layout `A`, apply emits the local materialization +from the assigned value: + +```text +after assignment.layout = B, with one use-side conflict requiring A: + a = producer() : layout B + c = ensure_layout a : B -> A + use(c) +``` + +For the non-rewrite fallback, the same conflict may produce an `A -> B -> A` +chain. That chain is not a special case in apply. A separate +layout-fold/rematerialization pass may fold it back to the original boundary +value when legal. + +For `vmi-layout-assignment`, the existing apply path is already usable: + +```text +rewriteDataTypes: + Sets VMI value types to `assignment.layout`. + +insertDataUseMaterializations: + Inserts pto.vmi.ensure_layout before operand uses whose requested layout does + not match the source value type. + +rewriteMaskTypes / insertMaskUseMaterializations: + Reused after first-phase mask granularity assignment/split and mask layout + propagation. +``` + +For post-assignment optimization passes, apply must be more conservative: + +```text +rewritable value: + The value is an op result inside the rewrite scope and changing the VMI type + does not cross an external ABI boundary. + +non-rewritable value: + Function/block arguments, external boundaries, or values outside the pass + rewrite scope keep their original type. 
If they have a requested + assignment.layout different from the current explicit layout, apply inserts + a def-side ensure_layout inside the rewrite scope and rewires in-scope uses + to the materialized value. +``` + +Whether a value is rewritable is derived from the IR and the caller's rewrite +scope. It is not stored in `assignments`. + +## Conflicts + +The propagator distinguishes representable conflicts from hard conflicts: + +```text +def-side conflict: + A source value's assignment.layout differs from another value-level request. + Record VMILayoutConflict{def, layout}. Apply will materialize it as an + ensure_layout fork near the value definition or boundary if it remains + different. + +use-side conflict: + A source value's assignment.layout differs from one operand's required layout. + Record VMILayoutConflict{operand, layout}. Apply will materialize it as an + ensure_layout fork if it remains different. + +hard operand conflict: + The same operand is requested as two different layouts. Do not create two + forks for one operand. The first implementation fails the current + propagation request. +``` + +The initial conflict policy should be strict inside the propagator: + +```text +same value requested as two different layouts: + keep the first layout as assignment.layout, record subsequent layouts as + def-side conflicts. Propagate only the primary assignment.layout fact. + +same operand requested as two different layouts: + fail the propagation request with a diagnostic at the requesting operation. +``` + +Value-level and operand-level layout differences are not hard conflicts when +they can be represented by an unambiguous fork. They are recorded in the +source value's `conflicts` list and are materialized by +VMILayoutPropagator::apply. 
+ +## Assignment Shape + +`vmi-layout-assignment` can use the propagator as: + +```text +collect op layout constraints and relations +request natural layouts for producers that choose concrete layouts +request layouts required by consumers +propagate the value-layout table through op relations +apply ensure_layout / ensure_mask_layout materialization +validate assigned VMI IR +``` + +Later layout optimization passes should follow the same model: + +```text +request a new layout for one or more anchor values +propagate the value-layout table through registered op relations +materialize mismatches at values or uses +run layout-fold/rematerialization to remove redundant helper IR +``` + +Manual layout clearing is unsafe because it loses boundary contracts and can +turn an already validated assigned IR back into an ambiguous pre-assignment IR. + +## First Implementation Boundary + +The first implementation is deliberately limited: + +```text +included: + data value layout propagation + value-level requests + def-side and operand-level conflicts stored inside each value assignment + strict same-operand conflict diagnostics + same-layout data op transfer + mask granularity assignment/split before layout propagation + mask layout propagation + CFG/block-argument same-transfer for control-flow values + cast width relations needed by group-value cast/broadcast + reuse of existing assignment type rewrite and data ensure_layout insertion + +not included: + function signature and call-site interprocedural rewrite + a separate operand-request table outside value assignments + cost-based layout choice + best-effort request dropping + global replacement of fold/rematerialize/sink passes +``` + +This boundary makes the design implementable without forcing all existing VMI +layout passes to move at once. + +Acceptance requirement: + +```text +Existing VMI lit and simulator regression outcomes must not regress after the +refactor. 
Any test that passed before the propagator refactor must still pass +after it. If a test's expected IR shape changes because the new propagation is +more canonical, update the expectation only with an explicit before/after +reason in the change description. +``` + +## Anti-Specialization Rules + +The propagator must not grow optimization-pattern-specific boundary checks. +These forms are not acceptable: + +```text +if producer is group_reduce and op is truncf and user is group_broadcast: + choose layout X + +if value is function argument and consumer is some specific op: + insert special materialization Y +``` + +Boundary handling should be expressed through generic materialization support, +not producer-specific pattern checks inside the propagator: + +```text +canMaterializeLayout(sourceType, resultType): + delegate to VMILayoutSupport::getDataLayoutMaterializationSupport +``` + +Op-specific logic is still necessary, but it must be local to one op and one +role. Do not combine several ops into one pattern: + +```text +cast transfer: + source/result width relation only + +channel split/merge transfer: + channel-count layout equation only + +group_broadcast support/request: + source operand group-value requirement and source/result support only + +group_reduce seed: + initial group-value result layout choice only + +store request/support: + required operand layout and store support only +``` + +With this structure, `group_reduce -> truncf -> group_broadcast` works because +independent op-local rules compose through `assignments`, not because the +propagator recognizes that whole chain. 
+ +## Example: Group-Value Cast + +For a group-value cast relation: + +```text +if the source value is requested/propagated as group-value: + derive result group-value layout + request the result value layout + +if the result value is requested/propagated as group-value: + derive inverse source value layout + request the source value layout through the source operand overload +``` + +If the derived source layout conflicts with the source value's existing +`assignment.layout`, the operand overload records a use-side conflict in the +source value's `assignment.conflicts`. Apply materializes it with +`ensure_layout` if the layouts still differ. That conflict is not propagated +back through the producer. Propagating an alternate layout through the producer +would mean rematerializing or cloning that producer for the alternate layout, +which is a separate optimization and must create a real forked value. + +## Pre-Refactor Regression Baseline + +Baseline captured on 2026-07-02 before introducing the shared VMI layout +propagator implementation. + +```text +git HEAD: + 4a2a100f + +working tree notes: + unrelated pre-existing changes were present in 3rdparty/PTO-Gym and + docs/designs/vmi-layout-lowering-cases.md. + untracked local investigation files were also present and are not part of + this baseline. 
+``` + +VMI lit baseline: + +```bash +export PATH="/home/mouliangyu/projects/github.com/vpto-dev/llvm-project/build-shared/bin:$PATH" +python3 /home/mouliangyu/projects/github.com/vpto-dev/llvm-project/llvm/utils/lit/lit.py \ + -v -j16 build/test/lit/vmi +``` + +Result: + +```text +Total Discovered Tests: 379 +Passed: 379 +Failed: 0 +``` + +VMI simulator baseline: + +```bash +WORK_SPACE=/tmp/ptoas-vmi-baseline-latest/sim \ +CASE_PREFIX='vmi/' \ +JOBS=16 \ +test/vpto/scripts/run_host_vpto_validation_parallel.sh +``` + +Result: + +```text +Total cases: 85 +PASS: 81 +FAIL: 4 +``` + +Existing simulator failures in this baseline: + +```text +vmi/group-reduce-s16-truncf-broadcast-store +vmi/group-reduce-s64-slot-add-store +vmi/group-reduce-s64-broadcast-reduce-store +vmi/group-reduce-s64-truncf-store +``` + +The refactor acceptance point is equality-or-better against this baseline: +all 379 VMI lit tests must keep passing, and the VMI simulator run must not add +new failing cases or turn any of the 81 passing cases into failures. 
diff --git a/docs/designs/vmi-mxfp8-32x32-expected-lowering.md b/docs/designs/vmi-mxfp8-32x32-expected-lowering.md new file mode 100644 index 0000000000..5130da3ec6 --- /dev/null +++ b/docs/designs/vmi-mxfp8-32x32-expected-lowering.md @@ -0,0 +1,236 @@ +# VMI MXFP8 32x32 Expected VPTO Lowering + +本文记录 `test/vpto/cases/vmi/kernels/tquant-mxfp8-32x32-nd/kernel.pto` +的预期 VPTO lower 结果。输入 VMI case 在 `vecscope` 内按 8 行一组循环, +每次处理一个 `256xf32` tile,也就是 8 行 x 32 列。 + +这里写的是设计目标,不是当前 `--emit-vpto` 的实际输出。重点是把 E8M0 +scale 的内存效果写明确:每个 8x32 chunk 产生 8 个 scale byte。lowering +按 CCE 风格先写到 32B 对齐的 padded UB slot,再通过 UB->GM copy 的 +`src_stride=32B, dst_stride=8B` 消除 UB padding,使 GM 端仍然连续。 + +## Complete Expected PTO File + +```mlir +module attributes {pto.backend = "vpto", pto.target_arch = "a5"} { + module attributes {pto.backend = "vpto", pto.kernel_kind = #pto.kernel_kind, pto.target_arch = "a5"} { + func.func @vmi_tquant_mxfp8_32x32_nd_kernel(%src_gm: !pto.ptr, + %out_fp8_gm: !pto.ptr, + %out_e8m0_gm: !pto.ptr) attributes {pto.kernel} { + %false = arith.constant false + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %c4 = arith.constant 4 : index + %c5 = arith.constant 5 : index + %c6 = arith.constant 6 : index + %c7 = arith.constant 7 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c24 = arith.constant 24 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c192 = arith.constant 192 : index + + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %c2_i32 = arith.constant 2 : i32 + %c3_i32 = arith.constant 3 : i32 + %c4_i32 = arith.constant 4 : i32 + %c5_i32 = arith.constant 5 : i32 + %c6_i32 = arith.constant 6 : i32 + %c7_i32 = arith.constant 7 : i32 + %c8_i32 = arith.constant 8 : i32 + %c23_i32 = arith.constant 23 : i32 + %c24_i32 = arith.constant 24 : i32 + %c40_i32 = arith.constant 
40 : i32 + %c48_i32 = arith.constant 48 : i32 + %c56_i32 = arith.constant 56 : i32 + %c254_i32 = arith.constant 254 : i32 + %c2139095040_i32 = arith.constant 2139095040 : i32 + + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c4_i64 = arith.constant 4 : i64 + %c8_i64 = arith.constant 8 : i64 + %c32_i64 = arith.constant 32 : i64 + %c256_i64 = arith.constant 256 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c12288_i64 = arith.constant 12288 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out_fp8_u8 = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_out_fp8_f8 = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_out_e8m0 = pto.castptr %c12288_i64 : i64 -> !pto.ptr + + pto.copy_gm_to_ubuf %src_gm, %ub_src, %c0_i64, %c1_i64, %c4096_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c4096_i64, %c4096_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.copy_gm_to_ubuf %out_fp8_gm, %ub_out_fp8_u8, %c0_i64, %c1_i64, %c1024_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c1024_i64, %c1024_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_flag[, , ] + pto.wait_flag[, , ] + + pto.vecscope { + scf.for %row = %c0 to %c32 step %c8 { + %elem_off = arith.muli %row, %c32 : index + %elem_off_64 = arith.addi %elem_off, %c64 : index + %elem_off_128 = arith.addi %elem_off, %c128 : index + %elem_off_192 = arith.addi %elem_off, %c192 : index + + %x0 = pto.vlds %ub_src[%elem_off] : !pto.ptr -> !pto.vreg<64xf32> + %x1 = pto.vlds %ub_src[%elem_off_64] : !pto.ptr -> !pto.vreg<64xf32> + %x2 = pto.vlds %ub_src[%elem_off_128] : !pto.ptr -> !pto.vreg<64xf32> + %x3 = pto.vlds %ub_src[%elem_off_192] : !pto.ptr -> !pto.vreg<64xf32> + + %d0, %d1 = pto.vdintlv %x0, %x1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %d2, %d3 = pto.vdintlv %x2, %x3 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, 
!pto.vreg<64xf32> + %d4, %d5 = pto.vdintlv %d0, %d2 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %d6, %d7 = pto.vdintlv %d1, %d3 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + + %all_b32 = pto.pset_b32 "PAT_ALL" : !pto.mask + %slot8_b32 = pto.pge_b32 "PAT_VL8" : !pto.mask + %vl8_b32 = pto.pset_b32 "PAT_VL8" : !pto.mask + %vl16_b32 = pto.pset_b32 "PAT_VL16" : !pto.mask + %vl24_b32, %unused24 = pto.plt_b32 %c24_i32 : i32 -> !pto.mask, i32 + %vl32_b32 = pto.pset_b32 "PAT_VL32" : !pto.mask + %vl40_b32, %unused40 = pto.plt_b32 %c40_i32 : i32 -> !pto.mask, i32 + %vl48_b32, %unused48 = pto.plt_b32 %c48_i32 : i32 -> !pto.mask, i32 + %vl56_b32, %unused56 = pto.plt_b32 %c56_i32 : i32 -> !pto.mask, i32 + + %abs0 = pto.vabs %d4, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %abs1 = pto.vabs %d6, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %abs2 = pto.vabs %d5, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %abs3 = pto.vabs %d7, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + + %g0 = pto.vcgmax %abs0, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %g1 = pto.vcgmax %abs1, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %g2 = pto.vcgmax %abs2, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %g3 = pto.vcgmax %abs3, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %g01 = pto.vmax %g0, %g1, %slot8_b32 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %g23 = pto.vmax %g2, %g3, %slot8_b32 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %amax = pto.vmax %g01, %g23, %slot8_b32 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + + %amax_i32 = pto.vbitcast %amax : !pto.vreg<64xf32> -> !pto.vreg<64xi32> + %exp_mask = pto.vdup %c2139095040_i32, %all_b32 : i32, !pto.mask -> !pto.vreg<64xi32> + %shift = pto.vdup %c23_i32, %all_b32 
: i32, !pto.mask -> !pto.vreg<64xi32> + %emax = pto.vdup %c8_i32, %all_b32 : i32, !pto.mask -> !pto.vreg<64xi32> + %scale_exp_bias = pto.vdup %c254_i32, %all_b32 : i32, !pto.mask -> !pto.vreg<64xi32> + %exp_bits = pto.vand %amax_i32, %exp_mask, %all_b32 : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> + %exp = pto.vshr %exp_bits, %shift, %all_b32 : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> + %e8m0_payload_i32 = pto.vsub %exp, %emax, %all_b32 : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> + + %idx0 = pto.vdup %c0_i32, %all_b32 : i32, !pto.mask -> !pto.vreg<64xi32> + %idx1 = pto.vdup %c1_i32, %all_b32 : i32, !pto.mask -> !pto.vreg<64xi32> + %idx2 = pto.vdup %c2_i32, %all_b32 : i32, !pto.mask -> !pto.vreg<64xi32> + %idx3 = pto.vdup %c3_i32, %all_b32 : i32, !pto.mask -> !pto.vreg<64xi32> + %idx4 = pto.vdup %c4_i32, %all_b32 : i32, !pto.mask -> !pto.vreg<64xi32> + %idx5 = pto.vdup %c5_i32, %all_b32 : i32, !pto.mask -> !pto.vreg<64xi32> + %idx6 = pto.vdup %c6_i32, %all_b32 : i32, !pto.mask -> !pto.vreg<64xi32> + %idx7 = pto.vdup %c7_i32, %all_b32 : i32, !pto.mask -> !pto.vreg<64xi32> + + %not_vl8 = pto.pnot %vl8_b32, %all_b32 : !pto.mask, !pto.mask -> !pto.mask + %range_8_15 = pto.pand %vl16_b32, %not_vl8, %all_b32 : !pto.mask, !pto.mask, !pto.mask -> !pto.mask + %broadcast_idx_1 = pto.vsel %idx1, %idx0, %range_8_15 : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> + %not_vl16 = pto.pnot %vl16_b32, %all_b32 : !pto.mask, !pto.mask -> !pto.mask + %range_16_23 = pto.pand %vl24_b32, %not_vl16, %all_b32 : !pto.mask, !pto.mask, !pto.mask -> !pto.mask + %broadcast_idx_2 = pto.vsel %idx2, %broadcast_idx_1, %range_16_23 : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> + %not_vl24 = pto.pnot %vl24_b32, %all_b32 : !pto.mask, !pto.mask -> !pto.mask + %range_24_31 = pto.pand %vl32_b32, %not_vl24, %all_b32 : !pto.mask, !pto.mask, !pto.mask -> !pto.mask + %broadcast_idx_3 
= pto.vsel %idx3, %broadcast_idx_2, %range_24_31 : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> + %not_vl32 = pto.pnot %vl32_b32, %all_b32 : !pto.mask, !pto.mask -> !pto.mask + %range_32_39 = pto.pand %vl40_b32, %not_vl32, %all_b32 : !pto.mask, !pto.mask, !pto.mask -> !pto.mask + %broadcast_idx_4 = pto.vsel %idx4, %broadcast_idx_3, %range_32_39 : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> + %not_vl40 = pto.pnot %vl40_b32, %all_b32 : !pto.mask, !pto.mask -> !pto.mask + %range_40_47 = pto.pand %vl48_b32, %not_vl40, %all_b32 : !pto.mask, !pto.mask, !pto.mask -> !pto.mask + %broadcast_idx_5 = pto.vsel %idx5, %broadcast_idx_4, %range_40_47 : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> + %not_vl48 = pto.pnot %vl48_b32, %all_b32 : !pto.mask, !pto.mask -> !pto.mask + %range_48_55 = pto.pand %vl56_b32, %not_vl48, %all_b32 : !pto.mask, !pto.mask, !pto.mask -> !pto.mask + %broadcast_idx_6 = pto.vsel %idx6, %broadcast_idx_5, %range_48_55 : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> + %not_vl56 = pto.pnot %vl56_b32, %all_b32 : !pto.mask, !pto.mask -> !pto.mask + %range_56_63 = pto.pand %all_b32, %not_vl56, %all_b32 : !pto.mask, !pto.mask, !pto.mask -> !pto.mask + %broadcast_idx = pto.vsel %idx7, %broadcast_idx_6, %range_56_63 : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> + + %scale_u16 = pto.vpack %e8m0_payload_i32, "LOWER" : !pto.vreg<64xi32> -> !pto.vreg<128xui16> + %scale_u8 = pto.vpack %scale_u16, "LOWER" : !pto.vreg<128xui16> -> !pto.vreg<256xui8> + %scale_slot = arith.divui %row, %c8 : index + %scale_ub_off = arith.muli %scale_slot, %c32 : index + %scale8_b8 = pto.pge_b8 "PAT_VL8" : !pto.mask + pto.vsts %scale_u8, %ub_out_e8m0[%scale_ub_off], %scale8_b8 {dist = "NORM_B8"} : !pto.vreg<256xui8>, !pto.ptr, !pto.mask + + %scale_exp = pto.vsub %scale_exp_bias, %e8m0_payload_i32, %all_b32 : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> 
!pto.vreg<64xi32> + %scale_bits = pto.vshl %scale_exp, %shift, %all_b32 : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> + %scale_f32 = pto.vbitcast %scale_bits : !pto.vreg<64xi32> -> !pto.vreg<64xf32> + %scale_vec = pto.vselr %scale_f32, %broadcast_idx : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> + + %m0 = pto.vmul %d4, %scale_vec, %all_b32 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %m1 = pto.vmul %d6, %scale_vec, %all_b32 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %m2 = pto.vmul %d5, %scale_vec, %all_b32 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %m3 = pto.vmul %d7, %scale_vec, %all_b32 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + + %i0, %i1 = pto.vintlv %m0, %m2 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %i2, %i3 = pto.vintlv %m1, %m3 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %i4, %i5 = pto.vintlv %i0, %i2 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %i6, %i7 = pto.vintlv %i1, %i3 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %r0, %r1 = pto.vdintlv %i4, %i5 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %r2, %r3 = pto.vdintlv %i6, %i7 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %r4, %r5 = pto.vdintlv %r0, %r2 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %r6, %r7 = pto.vdintlv %r1, %r3 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + + %all_b8 = pto.pset_b8 "PAT_ALL" : !pto.mask + %q0 = pto.vcvt %r4, %all_b32 {part = "P0", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8E4M3FN> + %q1 = pto.vcvt %r6, %all_b32 {part = "P1", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8E4M3FN> + 
%q2 = pto.vcvt %r5, %all_b32 {part = "P2", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8E4M3FN> + %q3 = pto.vcvt %r7, %all_b32 {part = "P3", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8E4M3FN> + %q01 = pto.vor %q0, %q1, %all_b8 : !pto.vreg<256xf8E4M3FN>, !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<256xf8E4M3FN> + %q012 = pto.vor %q01, %q2, %all_b8 : !pto.vreg<256xf8E4M3FN>, !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<256xf8E4M3FN> + %q = pto.vor %q012, %q3, %all_b8 : !pto.vreg<256xf8E4M3FN>, !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<256xf8E4M3FN> + pto.vsts %q, %ub_out_fp8_f8[%elem_off], %all_b8 : !pto.vreg<256xf8E4M3FN>, !pto.ptr, !pto.mask + } + } + + pto.set_flag[, , ] + pto.wait_flag[, , ] + pto.copy_ubuf_to_gm %ub_out_fp8_u8, %out_fp8_gm, %c0_i64, %c1_i64, %c1024_i64, %c0_i64, %c1024_i64, %c1024_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.copy_ubuf_to_gm %ub_out_e8m0, %out_e8m0_gm, %c0_i64, %c4_i64, %c8_i64, %c0_i64, %c8_i64, %c32_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier + return + } + } +} +``` + +## Scale Store Contract + +上面的 lower 对每次循环执行一条 `NORM_B8` store,写到 32B 对齐的 UB +slot: + +```text +row = 0 -> UB[0..7], UB[8..31] padding +row = 8 -> UB[32..39], UB[40..63] padding +row = 16 -> UB[64..71], UB[72..95] padding +row = 24 -> UB[96..103], UB[104..127] padding +``` + +最终 copy-out 只搬每个 slot 的前 8B: + +```text +copy len = 8B +repeat = 4 +source stride = 32B +destination stride = 8B +``` + +因此 GM 端效果仍然是连续 scale 输出: + +```text +GM[0..7] <- UB[0..7] +GM[8..15] <- UB[32..39] +GM[16..23] <- UB[64..71] +GM[24..31] <- UB[96..103] +``` diff --git a/docs/isa/micro-isa/10-reduction-ops.md b/docs/isa/micro-isa/10-reduction-ops.md index ecae818f2c..2129f91ce0 100644 --- a/docs/isa/micro-isa/10-reduction-ops.md +++ b/docs/isa/micro-isa/10-reduction-ops.md @@ -206,7 +206,9 @@ VLane 4: [32..39] VLane 5: [40..47] VLane 6: [48..55] VLane 7: [56..63] - 
**syntax:** `%result = pto.vcgadd %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` - **A5 types:** i16-i32, f16, f32 -- **semantics:** Sum within each VLane. 8 results at indices 0, 8, 16, 24, 32, 40, 48, 56 (for f32). +- **semantics:** Sum within each 32-byte VLane. The 8 VLane results are written + contiguously to the low lanes of the destination vector. For f32, results are + at indices 0, 1, 2, 3, 4, 5, 6, 7. ```c int K = N / 8; // elements per VLane @@ -214,17 +216,17 @@ for (int g = 0; g < 8; g++) { T sum = 0; for (int i = 0; i < K; i++) sum += src[g*K + i]; - dst[g*K] = sum; - for (int i = 1; i < K; i++) - dst[g*K + i] = 0; + dst[g] = sum; } -// For f32: results at dst[0], dst[8], dst[16], dst[24], dst[32], dst[40], dst[48], dst[56] +for (int i = 8; i < N; i++) + dst[i] = 0; +// For f32: results at dst[0], dst[1], ..., dst[7]. ``` - **inputs:** `%input` is the source vector and `%mask` selects participating lanes. - **outputs:** `%result` contains one sum per 32-byte VLane group, written - contiguously into the low slot of each group. + contiguously to the low lanes of the destination vector. - **constraints and limitations:** This is a per-32-byte VLane-group reduction. Inactive lanes are treated as zero. 
@@ -242,10 +244,10 @@ for (int g = 0; g < 8; g++) { T mx = -INF; for (int i = 0; i < K; i++) if (src[g*K + i] > mx) mx = src[g*K + i]; - dst[g*K] = mx; - for (int i = 1; i < K; i++) - dst[g*K + i] = 0; + dst[g] = mx; } +for (int i = 8; i < N; i++) + dst[i] = 0; ``` - **inputs:** `%input` is the source vector and `%mask` selects participating @@ -268,10 +270,10 @@ for (int g = 0; g < 8; g++) { T mn = INF; for (int i = 0; i < K; i++) if (src[g*K + i] < mn) mn = src[g*K + i]; - dst[g*K] = mn; - for (int i = 1; i < K; i++) - dst[g*K + i] = 0; + dst[g] = mn; } +for (int i = 8; i < N; i++) + dst[i] = 0; ``` - **inputs:** `%input` is the source vector and `%mask` selects participating @@ -320,7 +322,7 @@ for (int i = 1; i < N; i++) // Row-wise sum using vcgadd (for 8-row tile) %row_sums = pto.vcgadd %tile, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> -// Results at indices 0, 8, 16, 24, 32, 40, 48, 56 +// Results at indices 0, 1, 2, 3, 4, 5, 6, 7 // Full vector sum for normalization %total = pto.vcadd %values, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> diff --git a/include/PTO/IR/PTOAttrs.td b/include/PTO/IR/PTOAttrs.td index 429b5e232e..ed8166ad12 100644 --- a/include/PTO/IR/PTOAttrs.td +++ b/include/PTO/IR/PTOAttrs.td @@ -38,6 +38,8 @@ class PTO_Attr traits = []> let mnemonic = attrMnemonic; } +include "PTO/IR/VMIAttrs.td" + //===----------------------------------------------------------------------===// // Address Space //===----------------------------------------------------------------------===// diff --git a/include/PTO/IR/PTOOps.td b/include/PTO/IR/PTOOps.td index 97d6ff6a9b..ce2a780edd 100644 --- a/include/PTO/IR/PTOOps.td +++ b/include/PTO/IR/PTOOps.td @@ -76,6 +76,7 @@ class PTO_DpsOp traits = []> class PTO_Op traits = []> : Op; +include "PTO/IR/VMIOps.td" include "PTO/IR/VPTOOps.td" //===----------------------------------------------------------------------===// diff --git a/include/PTO/IR/PTOTypeDefs.td 
b/include/PTO/IR/PTOTypeDefs.td index 5fbe9d8d45..a6ac0ad106 100644 --- a/include/PTO/IR/PTOTypeDefs.td +++ b/include/PTO/IR/PTOTypeDefs.td @@ -338,4 +338,5 @@ def F4E2M1x2Type : TypeDef { + let summary = "VMI logical vector register layout"; + let parameters = (ins + StringRefParameter<"layout kind">:$kind, + "int64_t":$factor, + "int64_t":$blockElems, + "int64_t":$slots, + "int64_t":$laneStride + ); + let hasCustomAssemblyFormat = 1; + let genVerifyDecl = 1; + + let extraClassDeclaration = [{ + static VMILayoutAttr getContiguous(::mlir::MLIRContext *context, + int64_t laneStride = 1); + static VMILayoutAttr getDeinterleaved(::mlir::MLIRContext *context, + int64_t factor, + int64_t blockElems = 1, + int64_t laneStride = 1); + static VMILayoutAttr getGroupSlots(::mlir::MLIRContext *context, + int64_t numGroups, + int64_t slots = 0, + int64_t laneStride = 1); + + bool isContiguous() const { return getKind() == "contiguous"; } + bool isDeinterleaved() const { return getKind() == "deinterleaved"; } + bool isGroupSlots() const { return getKind() == "num_groups"; } + bool isDense() const { return isContiguous() || isDeinterleaved(); } + int64_t getNumGroups() const { return getFactor(); } + bool hasDenseLaneStride() const { + return isDense() && getLaneStride() != 1; + } + bool hasGroupSlotLaneStride() const { + return isGroupSlots() && getLaneStride() != 1; + } + bool hasLaneStride() const { return getLaneStride() != 1; } + }]; +} + +#endif // MLIR_DIALECT_PTO_IR_VMIATTRS diff --git a/include/PTO/IR/VMIOps.td b/include/PTO/IR/VMIOps.td new file mode 100644 index 0000000000..152b712339 --- /dev/null +++ b/include/PTO/IR/VMIOps.td @@ -0,0 +1,742 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. 
You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +//===- VMIOps.td - PTO VMI semantic operations -------------*- tablegen -*-===// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_PTO_IR_VMIOPS +#define MLIR_DIALECT_PTO_IR_VMIOPS + +include "mlir/IR/OpBase.td" +include "mlir/Interfaces/SideEffectInterfaces.td" + +def VMI_VRegTypeConstraint : Type< + CPred<"::llvm::isa<::mlir::pto::VMIVRegType>($_self)">, + "VMI logical vector register type">; + +def VMI_MaskTypeConstraint : Type< + CPred<"::llvm::isa<::mlir::pto::VMIMaskType>($_self)">, + "VMI logical mask type">; + +def VMI_ValueTypeConstraint : Type< + CPred<"::llvm::isa<::mlir::pto::VMIVRegType, ::mlir::pto::VMIMaskType>($_self)">, + "VMI logical vector or mask type">; + +def PTO_PhysicalVRegTypeConstraint : Type< + CPred<"::llvm::isa<::mlir::pto::VRegType>($_self)">, + "PTO physical vector register type">; + +def PTO_PhysicalMaskTypeConstraint : Type< + CPred<"::llvm::isa<::mlir::pto::MaskType>($_self)">, + "PTO physical mask type">; + +def PTO_PhysicalVMIPartTypeConstraint : AnyTypeOf< + [PTO_PhysicalVRegTypeConstraint, PTO_PhysicalMaskTypeConstraint], + "PTO physical vector register or mask type">; + +class VMI_Op traits = []> + : PTO_Op<"vmi." 
# mnemonic, traits>; + +def VMIConstantOp : VMI_Op<"constant", [Pure]> { + let summary = "VMI logical vector constant"; + let arguments = (ins AnyAttr:$value); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; +} + +def VMIBroadcastOp : VMI_Op<"broadcast", [Pure]> { + let summary = "Broadcast one scalar or 1-lane VMI vector to a VMI logical vector"; + let arguments = (ins AnyType:$value); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$value attr-dict `:` type($value) `->` type($result)"; +} + +def VMIIotaOp : VMI_Op<"iota", [Pure]> { + let summary = "Create a VMI logical index vector from a scalar base"; + let arguments = (ins + AnyTypeOf<[AnyInteger, AnyFloat], "integer/float scalar">:$base, + OptionalAttr:$order + ); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$base attr-dict `:` type($base) `->` type($result)"; +} + +def VMICreateMaskOp : VMI_Op<"create_mask", [Pure]> { + let summary = "Create a VMI logical prefix predicate mask"; + let arguments = (ins Index:$active_lanes); + let results = (outs VMI_MaskTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$active_lanes attr-dict `:` type($active_lanes) `->` type($result)"; +} + +def VMICreateGroupMaskOp : VMI_Op<"create_group_mask", [Pure]> { + let summary = "Create a VMI logical grouped predicate mask"; + let description = [{ + Creates a mask where lane i is active iff + `(i % group_size) < active_elems_per_group`. 
+ }]; + let arguments = (ins + Index:$active_elems_per_group, + I64Attr:$num_groups, + I64Attr:$group_size + ); + let results = (outs VMI_MaskTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$active_elems_per_group attr-dict `:` type($active_elems_per_group) `->` type($result)"; +} + +def VMIConstantMaskOp : VMI_Op<"constant_mask", [Pure]> { + let summary = "VMI logical predicate mask constant"; + let arguments = (ins AnyAttr:$value); + let results = (outs VMI_MaskTypeConstraint:$result); + let hasVerifier = 1; +} + +def VMIMaskAndOp : VMI_Op<"mask_and", [Pure]> { + let summary = "VMI logical predicate mask and"; + let arguments = (ins VMI_MaskTypeConstraint:$lhs, VMI_MaskTypeConstraint:$rhs); + let results = (outs VMI_MaskTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; +} + +def VMIMaskOrOp : VMI_Op<"mask_or", [Pure]> { + let summary = "VMI logical predicate mask or"; + let arguments = (ins VMI_MaskTypeConstraint:$lhs, VMI_MaskTypeConstraint:$rhs); + let results = (outs VMI_MaskTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; +} + +def VMIMaskXOrOp : VMI_Op<"mask_xor", [Pure]> { + let summary = "VMI logical predicate mask xor"; + let arguments = (ins VMI_MaskTypeConstraint:$lhs, VMI_MaskTypeConstraint:$rhs); + let results = (outs VMI_MaskTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; +} + +def VMIMaskNotOp : VMI_Op<"mask_not", [Pure]> { + let summary = "VMI logical predicate mask not"; + let arguments = (ins VMI_MaskTypeConstraint:$source); + let results = (outs VMI_MaskTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; +} + +def VMIAddFOp : VMI_Op<"addf", [Pure]> { + let 
summary = "VMI floating-point elementwise add"; + let arguments = (ins VMI_VRegTypeConstraint:$lhs, VMI_VRegTypeConstraint:$rhs); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; +} + +def VMIAddIOp : VMI_Op<"addi", [Pure]> { + let summary = "VMI integer elementwise add"; + let arguments = (ins VMI_VRegTypeConstraint:$lhs, VMI_VRegTypeConstraint:$rhs); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; +} + +def VMISubFOp : VMI_Op<"subf", [Pure]> { + let summary = "VMI floating-point elementwise subtract"; + let arguments = (ins VMI_VRegTypeConstraint:$lhs, VMI_VRegTypeConstraint:$rhs); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; +} + +def VMISubIOp : VMI_Op<"subi", [Pure]> { + let summary = "VMI integer elementwise subtract"; + let arguments = (ins VMI_VRegTypeConstraint:$lhs, VMI_VRegTypeConstraint:$rhs); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; +} + +def VMIMulFOp : VMI_Op<"mulf", [Pure]> { + let summary = "VMI floating-point elementwise multiply"; + let arguments = (ins VMI_VRegTypeConstraint:$lhs, VMI_VRegTypeConstraint:$rhs); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; +} + +def VMIMulIOp : VMI_Op<"muli", [Pure]> { + let summary = "VMI integer elementwise multiply"; + let arguments = (ins VMI_VRegTypeConstraint:$lhs, VMI_VRegTypeConstraint:$rhs); + let results = (outs VMI_VRegTypeConstraint:$result); + let 
hasVerifier = 1; + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; +} + +def VMIFmaOp : VMI_Op<"fma", [Pure]> { + let summary = "VMI fused floating-point multiply-add"; + let arguments = (ins VMI_VRegTypeConstraint:$lhs, VMI_VRegTypeConstraint:$rhs, + VMI_VRegTypeConstraint:$acc); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` type($lhs) `,` type($rhs) `,` type($acc) `->` type($result)"; +} + +def VMIDivFOp : VMI_Op<"divf", [Pure]> { + let summary = "VMI floating-point elementwise divide"; + let arguments = (ins VMI_VRegTypeConstraint:$lhs, VMI_VRegTypeConstraint:$rhs); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; +} + +def VMIMinFOp : VMI_Op<"minf", [Pure]> { + let summary = "VMI floating-point elementwise minimum"; + let arguments = (ins VMI_VRegTypeConstraint:$lhs, VMI_VRegTypeConstraint:$rhs); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; +} + +def VMIMaxFOp : VMI_Op<"maxf", [Pure]> { + let summary = "VMI floating-point elementwise maximum"; + let arguments = (ins VMI_VRegTypeConstraint:$lhs, VMI_VRegTypeConstraint:$rhs); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; +} + +def VMINegFOp : VMI_Op<"negf", [Pure]> { + let summary = "VMI floating-point elementwise negate"; + let arguments = (ins VMI_VRegTypeConstraint:$source); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; +} + +def VMIAbsFOp : VMI_Op<"absf", [Pure]> { + let 
summary = "VMI floating-point elementwise absolute value"; + let arguments = (ins VMI_VRegTypeConstraint:$source); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; +} + +def VMIAbsIOp : VMI_Op<"absi", [Pure]> { + let summary = "VMI integer elementwise absolute value"; + let arguments = (ins VMI_VRegTypeConstraint:$source); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; +} + +def VMISqrtOp : VMI_Op<"sqrt", [Pure]> { + let summary = "VMI floating-point elementwise square root"; + let arguments = (ins VMI_VRegTypeConstraint:$source); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; +} + +def VMIExpOp : VMI_Op<"exp", [Pure]> { + let summary = "VMI floating-point elementwise exponential"; + let arguments = (ins VMI_VRegTypeConstraint:$source); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; +} + +def VMILnOp : VMI_Op<"ln", [Pure]> { + let summary = "VMI floating-point elementwise natural logarithm"; + let arguments = (ins VMI_VRegTypeConstraint:$source); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; +} + +def VMIReluOp : VMI_Op<"relu", [Pure]> { + let summary = "VMI floating-point elementwise ReLU"; + let arguments = (ins VMI_VRegTypeConstraint:$source); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; +} + +def VMIAndIOp : VMI_Op<"andi", [Pure]> { + let summary = "VMI integer elementwise bitwise and"; + let 
arguments = (ins VMI_VRegTypeConstraint:$lhs, VMI_VRegTypeConstraint:$rhs); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; +} + +def VMIOrIOp : VMI_Op<"ori", [Pure]> { + let summary = "VMI integer elementwise bitwise or"; + let arguments = (ins VMI_VRegTypeConstraint:$lhs, VMI_VRegTypeConstraint:$rhs); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; +} + +def VMIXOrIOp : VMI_Op<"xori", [Pure]> { + let summary = "VMI integer elementwise bitwise xor"; + let arguments = (ins VMI_VRegTypeConstraint:$lhs, VMI_VRegTypeConstraint:$rhs); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; +} + +def VMIShLIOp : VMI_Op<"shli", [Pure]> { + let summary = "VMI integer elementwise left shift"; + let arguments = (ins VMI_VRegTypeConstraint:$lhs, VMI_VRegTypeConstraint:$rhs); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; +} + +def VMIShRUIOp : VMI_Op<"shrui", [Pure]> { + let summary = "VMI unsigned integer elementwise right shift"; + let arguments = (ins VMI_VRegTypeConstraint:$lhs, VMI_VRegTypeConstraint:$rhs); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; +} + +def VMINotOp : VMI_Op<"not", [Pure]> { + let summary = "VMI integer elementwise bitwise not"; + let arguments = (ins VMI_VRegTypeConstraint:$source); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source attr-dict `:` type($source) 
`->` type($result)"; +} + +def VMICmpFOp : VMI_Op<"cmpf", [Pure]> { + let summary = "VMI floating-point elementwise compare"; + let arguments = (ins StrAttr:$predicate, VMI_VRegTypeConstraint:$lhs, VMI_VRegTypeConstraint:$rhs); + let results = (outs VMI_MaskTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$predicate `,` $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; +} + +def VMICmpIOp : VMI_Op<"cmpi", [Pure]> { + let summary = "VMI integer elementwise compare"; + let arguments = (ins StrAttr:$predicate, VMI_VRegTypeConstraint:$lhs, VMI_VRegTypeConstraint:$rhs); + let results = (outs VMI_MaskTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$predicate `,` $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; +} + +def VMISelectOp : VMI_Op<"select", [Pure]> { + let summary = "VMI elementwise select"; + let arguments = (ins VMI_MaskTypeConstraint:$mask, VMI_VRegTypeConstraint:$true_value, + VMI_VRegTypeConstraint:$false_value); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$mask `,` $true_value `,` $false_value attr-dict `:` type($mask) `,` type($true_value) `,` type($false_value) `->` type($result)"; +} + +def VMIActivePrefixIndexOp : VMI_Op<"active_prefix_index"> { + let summary = "VMI per-lane active-prefix index from a predicate mask"; + let arguments = (ins VMI_MaskTypeConstraint:$mask); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$mask attr-dict `:` type($mask) `->` type($result)"; +} + +def VMICompressOp : VMI_Op<"compress"> { + let summary = "VMI compact active source lanes according to a predicate mask"; + let arguments = (ins VMI_VRegTypeConstraint:$source, VMI_MaskTypeConstraint:$mask); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source `,` $mask attr-dict `:` type($source) `,` type($mask) `->` 
type($result)"; +} + +def VMICompressStoreOp : VMI_Op<"compress_store", [DeclareOpInterfaceMethods]> { + let summary = "VMI store active source lanes contiguously according to a predicate mask"; + let arguments = (ins VMI_VRegTypeConstraint:$value, PtrOrMemRef:$destination, + Index:$offset, VMI_MaskTypeConstraint:$mask); + let results = (outs); + let hasVerifier = 1; + let assemblyFormat = "$value `,` $destination `[` $offset `]` `,` $mask attr-dict `:` type($value) `,` type($destination) `,` type($mask)"; +} + +def VMIReduceAddIOp : VMI_Op<"reduce_addi"> { + let summary = "VMI masked integer add reduction with a 1-lane vector init"; + let arguments = (ins VMI_VRegTypeConstraint:$source, + VMI_VRegTypeConstraint:$init, + VMI_MaskTypeConstraint:$mask); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source `,` $init `,` $mask attr-dict `:` type($source) `,` type($init) `,` type($mask) `->` type($result)"; +} + +def VMIReduceAddFOp : VMI_Op<"reduce_addf"> { + let summary = "VMI masked floating-point add reduction with explicit reassociation permission"; + let arguments = (ins VMI_VRegTypeConstraint:$source, + VMI_VRegTypeConstraint:$init, + VMI_MaskTypeConstraint:$mask, + OptionalAttr:$reassoc); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source `,` $init `,` $mask attr-dict `:` type($source) `,` type($init) `,` type($mask) `->` type($result)"; +} + +def VMIReduceMaxFOp : VMI_Op<"reduce_maxf"> { + let summary = "VMI masked floating-point maximum reduction with a 1-lane vector init"; + let arguments = (ins VMI_VRegTypeConstraint:$source, + VMI_VRegTypeConstraint:$init, + VMI_MaskTypeConstraint:$mask); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source `,` $init `,` $mask attr-dict `:` type($source) `,` type($init) `,` type($mask) `->` type($result)"; +} + +def VMIReduceMinFOp : 
VMI_Op<"reduce_minf"> { + let summary = "VMI masked floating-point minimum reduction with a 1-lane vector init"; + let arguments = (ins VMI_VRegTypeConstraint:$source, + VMI_VRegTypeConstraint:$init, + VMI_MaskTypeConstraint:$mask); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source `,` $init `,` $mask attr-dict `:` type($source) `,` type($init) `,` type($mask) `->` type($result)"; +} + +def VMIGroupReduceAddFOp : VMI_Op<"group_reduce_addf"> { + let summary = "VMI masked floating-point add reduction within fixed logical groups"; + let arguments = (ins VMI_VRegTypeConstraint:$source, + VMI_MaskTypeConstraint:$mask, + I64Attr:$num_groups, + OptionalAttr:$reassoc); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source `,` $mask attr-dict `:` type($source) `,` type($mask) `->` type($result)"; +} + +def VMIGroupReduceMaxFOp : VMI_Op<"group_reduce_maxf"> { + let summary = "VMI masked floating-point maximum reduction within fixed logical groups"; + let arguments = (ins VMI_VRegTypeConstraint:$source, + VMI_MaskTypeConstraint:$mask, + I64Attr:$num_groups); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source `,` $mask attr-dict `:` type($source) `,` type($mask) `->` type($result)"; +} + +def VMIGroupReduceAddIOp : VMI_Op<"group_reduce_addi"> { + let summary = "VMI masked integer add reduction within fixed logical groups"; + let arguments = (ins VMI_VRegTypeConstraint:$source, + VMI_MaskTypeConstraint:$mask, + I64Attr:$num_groups); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source `,` $mask attr-dict `:` type($source) `,` type($mask) `->` type($result)"; +} + +def VMIGroupReduceMaxIOp : VMI_Op<"group_reduce_maxi"> { + let summary = "VMI masked integer maximum reduction within fixed logical groups"; + let arguments = (ins 
VMI_VRegTypeConstraint:$source, + VMI_MaskTypeConstraint:$mask, + I64Attr:$num_groups); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source `,` $mask attr-dict `:` type($source) `,` type($mask) `->` type($result)"; +} + +def VMIGroupBroadcastOp : VMI_Op<"group_broadcast"> { + let summary = "VMI broadcast group-slot values back to each logical group"; + let arguments = (ins VMI_VRegTypeConstraint:$source, + I64Attr:$num_groups); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; +} + +class VMIHistogramOp + : VMI_Op { + let summary = summaryText; + let arguments = (ins VMI_VRegTypeConstraint:$acc, + VMI_VRegTypeConstraint:$source, + VMI_MaskTypeConstraint:$mask); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$acc `,` $source `,` $mask attr-dict `:` type($acc) `,` type($source) `,` type($mask) `->` type($result)"; +} + +def VMIDhistOp : VMIHistogramOp<"dhist", + "VMI full 256-bin distribution histogram over unsigned 8-bit source lanes">; + +def VMIChistOp : VMIHistogramOp<"chist", + "VMI full 256-bin cumulative histogram over unsigned 8-bit source lanes">; + +def VMIExtFOp : VMI_Op<"extf", [Pure]> { + let summary = "VMI floating-point elementwise extension"; + let arguments = (ins VMI_VRegTypeConstraint:$source); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; +} + +def VMITruncFOp : VMI_Op<"truncf", [Pure]> { + let summary = "VMI floating-point elementwise truncation"; + let arguments = (ins VMI_VRegTypeConstraint:$source, + OptionalAttr:$rounding); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; +} + +def VMIFPToSIOp : 
VMI_Op<"fptosi", [Pure]> { + let summary = "VMI floating-point to signed integer elementwise conversion"; + let arguments = (ins VMI_VRegTypeConstraint:$source); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; +} + +def VMISIToFPOp : VMI_Op<"sitofp", [Pure]> { + let summary = "VMI signed integer to floating-point elementwise conversion"; + let arguments = (ins VMI_VRegTypeConstraint:$source); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; +} + +def VMIExtSIOp : VMI_Op<"extsi", [Pure]> { + let summary = "VMI signed integer elementwise extension"; + let arguments = (ins VMI_VRegTypeConstraint:$source); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; +} + +def VMIExtUIOp : VMI_Op<"extui", [Pure]> { + let summary = "VMI unsigned integer elementwise extension"; + let arguments = (ins VMI_VRegTypeConstraint:$source); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; +} + +def VMITruncIOp : VMI_Op<"trunci", [Pure]> { + let summary = "VMI saturating integer elementwise truncation"; + let arguments = (ins VMI_VRegTypeConstraint:$source); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; +} + +def VMIBitcastOp : VMI_Op<"bitcast", [Pure]> { + let summary = "VMI bitwise vector reinterpretation"; + let arguments = (ins VMI_VRegTypeConstraint:$source); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; +} + +def VMILoadOp : 
VMI_Op<"load", [DeclareOpInterfaceMethods]> { + let summary = "VMI logical vector load"; + let arguments = (ins PtrOrMemRef:$source, Index:$offset); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source `[` $offset `]` attr-dict `:` type($source) `->` type($result)"; +} + +def VMIDeinterleaveLoadOp : VMI_Op<"deinterleave_load", [DeclareOpInterfaceMethods]> { + let summary = "VMI two-way logical deinterleave load"; + let arguments = (ins PtrOrMemRef:$source, Index:$offset); + let results = (outs VMI_VRegTypeConstraint:$low, VMI_VRegTypeConstraint:$high); + let hasVerifier = 1; + let assemblyFormat = "$source `[` $offset `]` attr-dict `:` type($source) `->` type($low) `,` type($high)"; +} + +def VMIGroupLoadOp : VMI_Op<"group_load", [DeclareOpInterfaceMethods]> { + let summary = "VMI logical grouped vector load with a row stride between groups"; + let arguments = (ins PtrOrMemRef:$source, Index:$offset, Index:$row_stride, + I64Attr:$num_groups); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source `[` $offset `]` `,` $row_stride attr-dict `:` type($source) `->` type($result)"; +} + +def VMIGroupSlotLoadOp : VMI_Op<"group_slot_load", [DeclareOpInterfaceMethods]> { + let summary = "VMI load one scalar value per logical group into group slots"; + let arguments = (ins PtrOrMemRef:$source, Index:$offset, Index:$source_group_stride, + I64Attr:$num_groups); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source `[` $offset `]` `,` $source_group_stride attr-dict `:` type($source) `->` type($result)"; +} + +def VMIGroupBroadcastLoadOp : VMI_Op<"group_broadcast_load", [DeclareOpInterfaceMethods]> { + let summary = "VMI load one scalar value per logical group and broadcast it to group lanes"; + let arguments = (ins PtrOrMemRef:$source, Index:$offset, Index:$source_group_stride, + I64Attr:$num_groups); + let 
results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source `[` $offset `]` `,` $source_group_stride attr-dict `:` type($source) `->` type($result)"; +} + +def VMIStrideLoadOp : VMI_Op<"stride_load", [DeclareOpInterfaceMethods]> { + let summary = "VMI block-strided vector load"; + let arguments = (ins PtrOrMemRef:$source, Index:$offset, + I16:$block_stride, I16:$repeat_stride, + VMI_MaskTypeConstraint:$mask); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source `[` $offset `]` `,` $block_stride `,` $repeat_stride `,` $mask attr-dict `:` type($source) `,` type($block_stride) `,` type($repeat_stride) `,` type($mask) `->` type($result)"; +} + +def VMIMaskedLoadOp : VMI_Op<"masked_load", [DeclareOpInterfaceMethods]> { + let summary = "VMI logical masked vector load with passthrough lanes"; + let arguments = (ins PtrOrMemRef:$source, Index:$offset, + VMI_MaskTypeConstraint:$mask, + VMI_VRegTypeConstraint:$passthru); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source `[` $offset `]` `,` $mask `,` $passthru attr-dict `:` type($source) `,` type($mask) `,` type($passthru) `->` type($result)"; +} + +def VMIGatherOp : VMI_Op<"gather", [DeclareOpInterfaceMethods]> { + let summary = "VMI logical masked indexed gather with passthrough lanes"; + let arguments = (ins PtrOrMemRef:$source, + VMI_VRegTypeConstraint:$indices, + VMI_MaskTypeConstraint:$mask, + VMI_VRegTypeConstraint:$passthru); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source `[` $indices `]` `,` $mask `,` $passthru attr-dict `:` type($source) `,` type($indices) `,` type($mask) `,` type($passthru) `->` type($result)"; +} + +def VMIExpandLoadOp : VMI_Op<"expand_load", [DeclareOpInterfaceMethods]> { + let summary = "VMI load a dense active-lane stream into masked logical lanes"; + let arguments = (ins 
PtrOrMemRef:$source, Index:$offset, + VMI_MaskTypeConstraint:$mask, + VMI_VRegTypeConstraint:$passthru); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source `[` $offset `]` `,` $mask `,` $passthru attr-dict `:` type($source) `,` type($mask) `,` type($passthru) `->` type($result)"; +} + +def VMIStoreOp : VMI_Op<"store", [DeclareOpInterfaceMethods]> { + let summary = "VMI logical vector store"; + let arguments = (ins VMI_VRegTypeConstraint:$value, PtrOrMemRef:$destination, Index:$offset); + let results = (outs); + let hasVerifier = 1; + let assemblyFormat = "$value `,` $destination `[` $offset `]` attr-dict `:` type($value) `,` type($destination)"; +} + +def VMIInterleaveStoreOp : VMI_Op<"interleave_store", [DeclareOpInterfaceMethods]> { + let summary = "VMI two-way logical interleave store"; + let arguments = (ins VMI_VRegTypeConstraint:$low, VMI_VRegTypeConstraint:$high, + PtrOrMemRef:$destination, Index:$offset); + let results = (outs); + let hasVerifier = 1; + let assemblyFormat = "$low `,` $high `,` $destination `[` $offset `]` attr-dict `:` type($low) `,` type($high) `,` type($destination)"; +} + +def VMIGroupStoreOp : VMI_Op<"group_store", [DeclareOpInterfaceMethods]> { + let summary = "VMI logical grouped vector store with a row stride between groups"; + let arguments = (ins VMI_VRegTypeConstraint:$value, PtrOrMemRef:$destination, + Index:$offset, Index:$row_stride, I64Attr:$num_groups); + let results = (outs); + let hasVerifier = 1; + let assemblyFormat = "$value `,` $destination `[` $offset `]` `,` $row_stride attr-dict `:` type($value) `,` type($destination)"; +} + +def VMIMaskedStoreOp : VMI_Op<"masked_store", [DeclareOpInterfaceMethods]> { + let summary = "VMI logical masked vector store"; + let arguments = (ins VMI_VRegTypeConstraint:$value, PtrOrMemRef:$destination, + Index:$offset, VMI_MaskTypeConstraint:$mask); + let results = (outs); + let hasVerifier = 1; + let assemblyFormat = "$value `,` 
$destination `[` $offset `]` `,` $mask attr-dict `:` type($value) `,` type($destination) `,` type($mask)"; +} + +def VMIStrideStoreOp : VMI_Op<"stride_store", [DeclareOpInterfaceMethods]> { + let summary = "VMI block-strided vector store"; + let arguments = (ins VMI_VRegTypeConstraint:$value, PtrOrMemRef:$destination, + Index:$offset, I16:$block_stride, I16:$repeat_stride, + VMI_MaskTypeConstraint:$mask); + let results = (outs); + let hasVerifier = 1; + let assemblyFormat = "$value `,` $destination `[` $offset `]` `,` $block_stride `,` $repeat_stride `,` $mask attr-dict `:` type($value) `,` type($destination) `,` type($block_stride) `,` type($repeat_stride) `,` type($mask)"; +} + +def VMIScatterOp : VMI_Op<"scatter", [DeclareOpInterfaceMethods]> { + let summary = "VMI logical masked indexed scatter"; + let arguments = (ins VMI_VRegTypeConstraint:$value, + PtrOrMemRef:$destination, + VMI_VRegTypeConstraint:$indices, + VMI_MaskTypeConstraint:$mask); + let results = (outs); + let hasVerifier = 1; + let assemblyFormat = "$value `,` $destination `[` $indices `]` `,` $mask attr-dict `:` type($value) `,` type($destination) `,` type($indices) `,` type($mask)"; +} + +def VMIShuffleOp : VMI_Op<"shuffle", [Pure]> { + let summary = "VMI static lane shuffle"; + let arguments = (ins VMI_VRegTypeConstraint:$source, DenseI64ArrayAttr:$indices); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source `[` $indices `]` attr-dict `:` type($source) `->` type($result)"; +} + +def VMIChannelSplitOp : VMI_Op<"channel_split"> { + let summary = "VMI split interleaved logical channels"; + let arguments = (ins VMI_VRegTypeConstraint:$source); + let results = (outs Variadic:$results); + let hasVerifier = 1; +} + +def VMIChannelMergeOp : VMI_Op<"channel_merge"> { + let summary = "VMI merge logical channels by interleaving"; + let arguments = (ins Variadic:$inputs); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; 
+} + +def VMIEnsureLayoutOp : VMI_Op<"ensure_layout", [Pure]> { + let summary = "Internal VMI data layout materialization helper"; + let arguments = (ins VMI_VRegTypeConstraint:$source); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; +} + +def VMIEnsureMaskLayoutOp : VMI_Op<"ensure_mask_layout", [Pure]> { + let summary = "Internal VMI mask layout materialization helper"; + let arguments = (ins VMI_MaskTypeConstraint:$source); + let results = (outs VMI_MaskTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; +} + +def VMIEnsureMaskGranularityOp : VMI_Op<"ensure_mask_granularity", [Pure]> { + let summary = "Internal VMI mask granularity materialization helper"; + let arguments = (ins VMI_MaskTypeConstraint:$source); + let results = (outs VMI_MaskTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; +} + +def VMIUnpackOp : VMI_Op<"unpack"> { + let summary = "Internal VMI value projection to physical parts"; + let arguments = (ins VMI_ValueTypeConstraint:$source); + let results = (outs Variadic:$parts); + let hasVerifier = 1; +} + +def VMIPackOp : VMI_Op<"pack"> { + let summary = "Internal physical parts materialized as one VMI value"; + let arguments = (ins Variadic:$parts); + let results = (outs VMI_ValueTypeConstraint:$result); + let hasVerifier = 1; +} + +#endif // MLIR_DIALECT_PTO_IR_VMIOPS diff --git a/include/PTO/IR/VMITypeDefs.td b/include/PTO/IR/VMITypeDefs.td new file mode 100644 index 0000000000..4ec6bb5009 --- /dev/null +++ b/include/PTO/IR/VMITypeDefs.td @@ -0,0 +1,67 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +//===- VMITypeDefs.td - PTO VMI type definitions -----------*- tablegen -*-===// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_PTO_IR_VMITYPEDEFS +#define MLIR_DIALECT_PTO_IR_VMITYPEDEFS + +include "PTO/IR/PTODialect.td" +include "PTO/IR/PTOAttrs.td" + +def VMIVRegType : TypeDef { + let mnemonic = "vmi.vreg"; + let summary = "A VMI logical vector register value"; + + let parameters = (ins + "int64_t":$elementCount, + "Type":$elementType, + "mlir::Attribute":$layout + ); + + let hasCustomAssemblyFormat = 1; + let genVerifyDecl = 1; + + let extraClassDeclaration = [{ + bool hasLayout() const { return static_cast(getLayout()); } + VMILayoutAttr getLayoutAttr() const { + return ::llvm::dyn_cast_or_null(getLayout()); + } + }]; +} + +def VMIMaskType : TypeDef { + let mnemonic = "vmi.mask"; + let summary = "A VMI logical predicate mask value"; + + let parameters = (ins + "int64_t":$elementCount, + StringRefParameter<"mask granularity view">:$granularity, + "mlir::Attribute":$layout + ); + + let hasCustomAssemblyFormat = 1; + let genVerifyDecl = 1; + + let extraClassDeclaration = [{ + static bool isSupportedGranularity(::llvm::StringRef granularity); + static bool isConcreteGranularity(::llvm::StringRef granularity); + + bool hasLayout() const { return static_cast(getLayout()); } + bool isPred() const { return getGranularity() 
== "pred"; } + bool isB8() const { return getGranularity() == "b8"; } + bool isB16() const { return getGranularity() == "b16"; } + bool isB32() const { return getGranularity() == "b32"; } + VMILayoutAttr getLayoutAttr() const { + return ::llvm::dyn_cast_or_null(getLayout()); + } + }]; +} + +#endif // MLIR_DIALECT_PTO_IR_VMITYPEDEFS diff --git a/include/PTO/IR/VMIUtils.h b/include/PTO/IR/VMIUtils.h new file mode 100644 index 0000000000..e55e558034 --- /dev/null +++ b/include/PTO/IR/VMIUtils.h @@ -0,0 +1,53 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +//===- VMIUtils.h - PTO VMI shared helpers ----------------------*- C++ -*-===// +//===----------------------------------------------------------------------===// + +#ifndef PTO_IR_VMIUTILS_H +#define PTO_IR_VMIUTILS_H + +#include "PTO/IR/PTO.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" + +namespace mlir::pto { + +inline constexpr StringLiteral kVMIDiagUnsupported = "VMI-UNSUPPORTED"; +inline constexpr StringLiteral kVMIDiagLayoutContract = + "VMI-LAYOUT-CONTRACT"; +inline constexpr StringLiteral kVMIDiagPassInvariant = "VMI-PASS-INVARIANT"; +inline constexpr StringLiteral kVMIDiagResidualOp = "VMI-RESIDUAL-OP"; + +inline constexpr StringLiteral kVMIDiagUnsupportedPrefix = + "VMI-UNSUPPORTED: "; +inline constexpr StringLiteral kVMIDiagLayoutContractPrefix = + "VMI-LAYOUT-CONTRACT: "; +inline constexpr StringLiteral kVMIDiagPassInvariantPrefix = + "VMI-PASS-INVARIANT: "; +inline constexpr StringLiteral kVMIDiagResidualOpPrefix = "VMI-RESIDUAL-OP: "; + +struct VMIPhysicalLane { + int64_t part = 0; + int64_t chunk = 0; + int64_t lane = 0; +}; + +FailureOr getDataLanesPerPart(Type elementType); +FailureOr getMaskLanesPerPart(StringRef granularity); +FailureOr getVMIPhysicalArity(Type type); +FailureOr mapLogicalLaneToPhysical(Type type, + int64_t logicalLane); +FailureOr mapPhysicalLaneToLogical(Type type, int64_t part, + int64_t chunk, int64_t lane); +FailureOr isPaddingLane(Type type, int64_t part, int64_t chunk, + int64_t lane); + +} // namespace mlir::pto + +#endif // PTO_IR_VMIUTILS_H diff --git a/include/PTO/Transforms/Passes.h b/include/PTO/Transforms/Passes.h index 59ad36c932..8859d766e0 100644 --- a/include/PTO/Transforms/Passes.h +++ b/include/PTO/Transforms/Passes.h @@ -94,12 +94,27 @@ std::unique_ptr createPTOFusionLoadStoreElisionPass(); std::unique_ptr createPTOFlattenFusionRegionPass(); std::unique_ptr createVPTOPtrNormalizePass(); std::unique_ptr createVPTOPtrCastCleanupPass(); +std::unique_ptr 
createVPTONormalizeEquivalentVcvtPass(); LogicalResult validateVPTOAuthoringIR(ModuleOp module, llvm::raw_ostream *diagOS = nullptr); LogicalResult validateVPTOEmissionIR(ModuleOp module, llvm::raw_ostream *diagOS = nullptr); std::unique_ptr createPTOValidateVPTOIRPass(); std::unique_ptr createPTOValidateVPTOEmissionIRPass(); +LogicalResult validateVMIProducerBoundaryIR(ModuleOp module, + llvm::raw_ostream *diagOS = nullptr); +LogicalResult validateVMILayoutAssignedIR(ModuleOp module, + llvm::raw_ostream *diagOS = nullptr, + bool verifyHelperSupport = true); +std::unique_ptr createPTOValidateVMIIRPass(); +std::unique_ptr createPTOValidateVMILayoutIRPass(); +std::unique_ptr createVMIPreAssignmentCombinePass(); +std::unique_ptr createVMILayoutAssignmentPass(); +std::unique_ptr createVMILayoutFoldPass(); +std::unique_ptr createVMILayoutRematerializePass(); +std::unique_ptr createVMILayoutSinkMaterializationPass(); +std::unique_ptr createVMILegalizeArithSelectPass(); +std::unique_ptr createVMIToVPTOPass(); std::unique_ptr createExpandTileOpPass(); std::unique_ptr createExpandTileOpPass(const ExpandTileOpOptions &options); std::unique_ptr createFoldTileBufIntrinsicsPass(); diff --git a/include/PTO/Transforms/Passes.td b/include/PTO/Transforms/Passes.td index a897034d15..cf17b9b234 100644 --- a/include/PTO/Transforms/Passes.td +++ b/include/PTO/Transforms/Passes.td @@ -629,6 +629,171 @@ def PTOValidateVPTOIR : Pass<"pto-validate-vpto-ir", "ModuleOp"> { "mlir::scf::SCFDialect"]; } +def PTOValidateVMIIR : Pass<"pto-validate-vmi-ir", "ModuleOp"> { + let summary = "Validate VMI producer-boundary semantic IR"; + let description = [{ + Checks that VMI producer-boundary IR uses only surface VMI data/mask types, + native pto.vmi semantic ops, and structural control-flow/function ops. This + pass runs before layout assignment, so layout-assigned VMI types, VMI helper + ops, and physical VPTO register types are rejected. 
+ }]; + let constructor = "mlir::pto::createPTOValidateVMIIRPass()"; + let dependentDialects = ["mlir::cf::ControlFlowDialect", + "mlir::func::FuncDialect", + "mlir::pto::PTODialect", + "mlir::memref::MemRefDialect", + "mlir::scf::SCFDialect"]; +} + +def PTOValidateVMILayoutIR + : Pass<"pto-validate-vmi-layout-ir", "ModuleOp"> { + let summary = "Validate layout-assigned VMI IR"; + let description = [{ + Checks the post-layout-assignment VMI stage: every VMI data value must have + a concrete VMI layout, every VMI mask must have concrete b8/b16/b32 + granularity and layout, physical VPTO register values must not appear yet, + and VMI typed values must stay inside VMI semantic/helper or structural ops. + vmi-to-vpto chooses deterministic lowerings from the current op's attrs, + operand/result types, layouts, and operand values. Non-local choices must + be represented as explicit attrs, helper ops, or diagnostics before this + stage. Later VMI layout optimization passes may replace helpers with + cloned/rematerialized producers, but the layout gate must not depend on + hidden producer/user context. + }]; + let constructor = "mlir::pto::createPTOValidateVMILayoutIRPass()"; + let dependentDialects = ["mlir::cf::ControlFlowDialect", + "mlir::func::FuncDialect", + "mlir::pto::PTODialect", + "mlir::memref::MemRefDialect", + "mlir::scf::SCFDialect"]; +} + +def VMIPreAssignmentCombine + : Pass<"vmi-pre-assignment-combine", "ModuleOp"> { + let summary = "Combine VMI operations before layout assignment"; + let description = [{ + Performs VMI-level structural combines before VMI layout assignment. This + keeps layout assignment focused on choosing and materializing layouts while + still exposing direct semantic operations to the later layout and lowering + passes. + + The pass currently rewrites the semantic pattern + `group_broadcast(group_slot_load(...))` into the equivalent + `group_broadcast_load` operation. 
+ }]; + let constructor = "mlir::pto::createVMIPreAssignmentCombinePass()"; + let dependentDialects = ["mlir::func::FuncDialect", + "mlir::pto::PTODialect"]; +} + +def VMILayoutAssignment : Pass<"vmi-layout-assignment", "ModuleOp"> { + let summary = "Assign concrete VMI layouts and mask granularities"; + let description = [{ + Solves VMI layout constraints and materializes the chosen layout and mask + granularity into VMI types. This pass is the boundary between surface VMI + semantic IR and layout-assigned VMI IR. + }]; + let constructor = "mlir::pto::createVMILayoutAssignmentPass()"; + let dependentDialects = ["mlir::cf::ControlFlowDialect", + "mlir::func::FuncDialect", + "mlir::pto::PTODialect", + "mlir::memref::MemRefDialect", + "mlir::scf::SCFDialect"]; +} + +def VMILayoutFold : Pass<"vmi-layout-fold", "ModuleOp"> { + let summary = "Fold VMI layout materializations"; + let description = [{ + Optimizes legal layout-assigned VMI IR by folding selected ensure_layout + helpers into layout-aware producers or consumers while preserving the same + logical effect. The pass does not choose layouts by inspecting arbitrary + producer/user context for vmi-to-vpto; it only rewrites explicit helper IR + into equivalent local forms. + }]; + let constructor = "mlir::pto::createVMILayoutFoldPass()"; + let dependentDialects = ["mlir::cf::ControlFlowDialect", + "mlir::func::FuncDialect", + "mlir::pto::PTODialect", + "mlir::memref::MemRefDialect", + "mlir::scf::SCFDialect"]; +} + +def VMILayoutRematerialize : Pass<"vmi-layout-rematerialize", "ModuleOp"> { + let summary = "Rematerialize cheap VMI producers at layout helpers"; + let description = [{ + Optimizes legal layout-assigned VMI IR by replacing selected ensure_layout, + ensure_mask_layout, and ensure_mask_granularity helpers with cloned + producers that directly create the requested result type. 
The pass covers + pure construction ops, selected layout-transparent data ops, and dense + widening ext relation rematerialization. Memory, control-flow, and mask-tail + proofs remain explicit in the IR. + }]; + let constructor = "mlir::pto::createVMILayoutRematerializePass()"; + let dependentDialects = ["mlir::cf::ControlFlowDialect", + "mlir::func::FuncDialect", + "mlir::pto::PTODialect", + "mlir::memref::MemRefDialect", + "mlir::scf::SCFDialect"]; +} + +def VMILayoutSinkMaterialization + : Pass<"vmi-layout-sink-materialization", "ModuleOp"> { + let summary = "Sink VMI layout materialization through transfer ops"; + let description = [{ + Optimizes legal layout-assigned VMI IR by moving matching operand + ensure_layout helpers across pure layout-transparent elementwise operations. + The rewritten IR keeps the layout conversion explicit as a result + ensure_layout, so vmi-to-vpto still lowers from local op information only. + }]; + let constructor = "mlir::pto::createVMILayoutSinkMaterializationPass()"; + let dependentDialects = ["mlir::cf::ControlFlowDialect", + "mlir::func::FuncDialect", + "mlir::pto::PTODialect", + "mlir::memref::MemRefDialect", + "mlir::scf::SCFDialect"]; +} + +def VMILegalizeArithSelect : Pass<"vmi-legalize-arith-select", "ModuleOp"> { + let summary = "Legalize canonical arith.select over VMI values"; + let description = [{ + Rewrites scalar-condition arith.select operations that produce VMI values + back to scf.if. MLIR canonicalization may fold simple scf.if regions into + arith.select, but VMI values must not cross non-VMI semantic ops before + vmi-to-vpto. This pass restores an explicit structural control-flow form + that the VMI converter already handles. 
+ }]; + let constructor = "mlir::pto::createVMILegalizeArithSelectPass()"; + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::cf::ControlFlowDialect", + "mlir::func::FuncDialect", + "mlir::pto::PTODialect", + "mlir::memref::MemRefDialect", + "mlir::scf::SCFDialect"]; +} + +def VMIToVPTO : Pass<"vmi-to-vpto", "ModuleOp"> { + let summary = "Convert layout-assigned VMI IR to physical VPTO IR"; + let description = [{ + Converts layout-assigned VMI aggregate data/mask types to ordered physical + VPTO register and mask value lists using MLIR OneToNTypeConversion. This + pass is responsible for VMI 1:N type conversion, structural control-flow + and function/call signature conversion, and VMI semantic op physicalization. + }]; + let constructor = "mlir::pto::createVMIToVPTOPass()"; + let options = [ + Option<"enableStableGatherMaskedLoad", + "enable-stable-gather-masked-load", "bool", + /*default=*/"false", + "Reserve the stable VGATHER-based lowering path for VMI masked " + "loads; currently emits a TODO diagnostic when used."> + ]; + let dependentDialects = ["mlir::cf::ControlFlowDialect", + "mlir::func::FuncDialect", + "mlir::pto::PTODialect", + "mlir::memref::MemRefDialect", + "mlir::scf::SCFDialect"]; +} + def PTOValidateVPTOEmissionIR : Pass<"pto-validate-vpto-emission-ir", "ModuleOp"> { let summary = @@ -777,4 +942,19 @@ def VPTOPtrCastCleanup "mlir::memref::MemRefDialect"]; } +def VPTONormalizeEquivalentVcvt + : Pass<"vpto-normalize-equivalent-vcvt", "ModuleOp"> { + let summary = "Normalize equivalent VPTO vcvt part selections"; + let description = [{ + Rewrites `pto.vcvt` operations whose `EVEN` and `ODD` part selections are + provably equivalent into the canonical `EVEN` form. The pass currently + recognizes all-true masked narrow-to-wide conversions from VPTO values with + pair-wise equivalent input lanes, such as scalar/vector broadcasts and + selected broadcast load distributions. A following CSE pass can then merge + duplicate conversions. 
+ }]; + let constructor = "mlir::pto::createVPTONormalizeEquivalentVcvtPass()"; + let dependentDialects = ["mlir::pto::PTODialect"]; +} + #endif // MLIR_DIALECT_PTO_PASSES diff --git a/include/PTO/Transforms/VMILayoutSupport.h b/include/PTO/Transforms/VMILayoutSupport.h new file mode 100644 index 0000000000..b03d12cae4 --- /dev/null +++ b/include/PTO/Transforms/VMILayoutSupport.h @@ -0,0 +1,369 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. 
+ +//===- VMILayoutSupport.h - VMI layout support queries ------*- C++ -*-===// +//===----------------------------------------------------------------------===// + +#ifndef PTO_TRANSFORMS_VMILAYOUTSUPPORT_H +#define PTO_TRANSFORMS_VMILAYOUTSUPPORT_H + +#include "PTO/IR/PTO.h" +#include "mlir/Support/LLVM.h" + +#include + +namespace mlir::pto { + +class VMITargetCapabilityRegistry; + +enum class VMIContiguousStoreSupportKind { + ContiguousVsts, + LaneStride2PackedVsts, + LaneStride4PackedVsts, + Deinterleaved2Vstsx2, + DeinterleavedMaterializeThenVsts, +}; + +struct VMIContiguousStoreSupport { + VMIContiguousStoreSupportKind kind = + VMIContiguousStoreSupportKind::ContiguousVsts; +}; + +enum class VMIContiguousLoadSupportKind { + ContiguousVlds, + LaneStride2UnpackedVlds, + LaneStride4UnpackedVlds, +}; + +struct VMIContiguousLoadSupport { + VMIContiguousLoadSupportKind kind = + VMIContiguousLoadSupportKind::ContiguousVlds; +}; + +enum class VMILayoutMaterializationSupportKind { + Identity, + ContiguousToDeinterleaved, + DeinterleavedToContiguous, + DeinterleavedToDeinterleavedViaContiguous, + ContiguousToLaneStrideViaUnpack, + LaneStrideToContiguousViaPack, +}; + +struct VMILayoutMaterializationSupport { + VMILayoutMaterializationSupportKind kind = + VMILayoutMaterializationSupportKind::Identity; +}; + +enum class VMIMaskGranularityMaterializationSupportKind { + Identity, + PredicateCast, +}; + +struct VMIMaskGranularityMaterializationSupport { + VMIMaskGranularityMaterializationSupportKind kind = + VMIMaskGranularityMaterializationSupportKind::Identity; +}; + +enum class VMICastLayoutKind { + Widen2x, + Widen4x, + Narrow2x, + Narrow4x, +}; + +struct VMICastLayoutFact { + VMICastLayoutKind kind = VMICastLayoutKind::Widen2x; + VMILayoutAttr sourceLayout; + VMILayoutAttr resultLayout; + int64_t sourceBits = 0; + int64_t resultBits = 0; + int64_t factor = 1; +}; + +enum class VMIGroupSlotLoadSupportKind { + Slots8UnitStrideVsldb, + Slots1AlignedLane0Vsldb, +}; + 
+struct VMIGroupSlotLoadSupport { + VMIGroupSlotLoadSupportKind kind = + VMIGroupSlotLoadSupportKind::Slots8UnitStrideVsldb; +}; + +enum class VMIGroupLoadSupportKind { + S16Block8Vsldb, + S32Block8Vsldb, +}; + +struct VMIGroupLoadSupport { + VMIGroupLoadSupportKind kind = VMIGroupLoadSupportKind::S16Block8Vsldb; +}; + +enum class VMIGroupSlotsStoreSupportKind { + Slots8UnitStrideVsts, + Slots1PointVsts, + Slots1PackedUnitStrideVsts, +}; + +struct VMIGroupSlotsStoreSupport { + VMIGroupSlotsStoreSupportKind kind = + VMIGroupSlotsStoreSupportKind::Slots8UnitStrideVsts; +}; + +enum class VMIGroupReduceLayoutKind { + OneVLane, + TwoVLane, + FourVLane, + RowLocal, +}; + +struct VMIGroupReduceLayoutFact { + VMIGroupReduceLayoutKind kind = VMIGroupReduceLayoutKind::OneVLane; + VMILayoutAttr sourceLayout; + VMILayoutAttr maskLayout; + VMILayoutAttr resultLayout; + int64_t groupSize = 0; + int64_t lanesPerPart = 0; + int64_t vlaneElems = 0; +}; + +enum class VMIGroupReduceAddFSupportKind { + OneVLaneVcgadd, + TwoVLaneDeinterleaved2VcgaddVadd, + FourVLaneDeinterleaved4VcgaddTree, + ContiguousVcaddRows, +}; + +struct VMIGroupReduceAddFSupport { + VMIGroupReduceAddFSupportKind kind = + VMIGroupReduceAddFSupportKind::OneVLaneVcgadd; +}; + +enum class VMIGroupBroadcastSupportKind { + GroupSlotsVselr, +}; + +struct VMIGroupBroadcastSupport { + VMIGroupBroadcastSupportKind kind = + VMIGroupBroadcastSupportKind::GroupSlotsVselr; +}; + +enum class VMIGroupBroadcastLoadSupportKind { + E2BVlds, + SlotLoadThenBroadcast, +}; + +struct VMIGroupBroadcastLoadSupport { + VMIGroupBroadcastLoadSupportKind kind = + VMIGroupBroadcastLoadSupportKind::E2BVlds; +}; + +enum class VMITruncFSupportKind { + Deinterleaved2F32ToContiguousF16, + Deinterleaved4F32ToContiguousF8, + ContiguousF32ToLaneStrideF16, + ContiguousF32ToLaneStrideF8, + GroupSlots1F32ToF16, +}; + +struct VMITruncFSupport { + VMITruncFSupportKind kind = + VMITruncFSupportKind::Deinterleaved2F32ToContiguousF16; +}; + +enum class 
VMIExtFSupportKind { + ContiguousF16ToDeinterleaved2F32, + ContiguousF8ToDeinterleaved4F32, +}; + +struct VMIExtFSupport { + VMIExtFSupportKind kind = + VMIExtFSupportKind::ContiguousF16ToDeinterleaved2F32; +}; + +enum class VMITruncISupportKind { + Deinterleaved2I32ToContiguousI16, + Deinterleaved4I32ToContiguousI8, + ContiguousI32ToLaneStrideI16, + ContiguousI32ToLaneStrideI8, + GroupSlots1I32ToNarrow, +}; + +struct VMITruncISupport { + VMITruncISupportKind kind = + VMITruncISupportKind::Deinterleaved2I32ToContiguousI16; +}; + +enum class VMIExtISupportKind { + ContiguousI16ToDeinterleaved2I32, + ContiguousI8ToDeinterleaved4I32, + GroupSlotsI16ToI32, + GroupSlotsI8ToI32, +}; + +struct VMIExtISupport { + VMIExtISupportKind kind = + VMIExtISupportKind::ContiguousI16ToDeinterleaved2I32; +}; + +enum class VMIBitcastSupportKind { + PerPartVbitcast, +}; + +struct VMIBitcastSupport { + VMIBitcastSupportKind kind = VMIBitcastSupportKind::PerPartVbitcast; +}; + +enum class VMIHistogramSupportKind { + Full256BinDhist, +}; + +struct VMIHistogramSupport { + VMIHistogramSupportKind kind = VMIHistogramSupportKind::Full256BinDhist; +}; + +class VMILayoutSupport { +public: + FailureOr + getContiguousStoreSupport(VMIVRegType valueType, + std::string *reason = nullptr) const; + + FailureOr + getContiguousLoadSupport(VMIVRegType resultType, + std::string *reason = nullptr) const; + + LogicalResult + canFoldContiguousStoreMaterialization(VMIVRegType sourceType, + VMIVRegType resultType, + std::string *reason = nullptr) const; + LogicalResult canFoldContiguousMaskedStoreMaterialization( + VMIVRegType sourceType, VMIMaskType maskSourceType, + VMIVRegType resultType, VMIMaskType maskResultType, + std::string *reason = nullptr) const; + + FailureOr + getDataLayoutMaterializationSupport(VMIVRegType sourceType, + VMIVRegType resultType, + std::string *reason = nullptr) const; + + LogicalResult canMaterializeDataLayout(VMIVRegType sourceType, + VMIVRegType resultType, + std::string *reason 
= nullptr) const; + + FailureOr + getMaskLayoutMaterializationSupport(VMIMaskType sourceType, + VMIMaskType resultType, + std::string *reason = nullptr) const; + + LogicalResult canMaterializeMaskLayout(VMIMaskType sourceType, + VMIMaskType resultType, + std::string *reason = nullptr) const; + + FailureOr + getMaskGranularityMaterializationSupport(VMIMaskType sourceType, + VMIMaskType resultType, + std::string *reason = nullptr) const; + + LogicalResult + canMaterializeMaskGranularity(VMIMaskType sourceType, VMIMaskType resultType, + std::string *reason = nullptr) const; + + FailureOr + getPreferredCastLayoutFact(VMIVRegType sourceType, VMIVRegType resultType, + std::string *reason = nullptr) const; + + FailureOr + getWidenSourceLayoutForResultLayout(VMIVRegType sourceType, + VMIVRegType resultType, + VMILayoutAttr requestedResultLayout, + std::string *reason = nullptr) const; + + FailureOr + getGroupSlotLoadSupport(const VMITargetCapabilityRegistry &capabilities, + VMIGroupSlotLoadOp op, + std::string *reason = nullptr) const; + + FailureOr + getGroupLoadSupport(const VMITargetCapabilityRegistry &capabilities, + VMIGroupLoadOp op, std::string *reason = nullptr) const; + + FailureOr + getGroupSlotsStoreSupport(const VMITargetCapabilityRegistry &capabilities, + VMIGroupStoreOp op, + std::string *reason = nullptr) const; + + FailureOr + getPreferredGroupReduceLayoutFact(VMIVRegType sourceType, int64_t numGroups, + std::string *reason = nullptr) const; + + FailureOr + getGroupReduceAddFSupport(const VMITargetCapabilityRegistry &capabilities, + VMIGroupReduceAddFOp op, + std::string *reason = nullptr) const; + + FailureOr + getGroupReduceMaxFSupport(const VMITargetCapabilityRegistry &capabilities, + VMIGroupReduceMaxFOp op, + std::string *reason = nullptr) const; + + FailureOr + getGroupReduceAddISupport(const VMITargetCapabilityRegistry &capabilities, + VMIGroupReduceAddIOp op, + std::string *reason = nullptr) const; + + FailureOr + getGroupReduceMaxISupport(const 
VMITargetCapabilityRegistry &capabilities, + VMIGroupReduceMaxIOp op, + std::string *reason = nullptr) const; + + FailureOr + getGroupBroadcastSupport(const VMITargetCapabilityRegistry &capabilities, + VMIGroupBroadcastOp op, + std::string *reason = nullptr) const; + + FailureOr + getGroupBroadcastSupport(const VMITargetCapabilityRegistry &capabilities, + VMIVRegType sourceType, VMIVRegType resultType, + int64_t numGroups, + std::string *reason = nullptr) const; + + FailureOr + getGroupBroadcastLoadSupport(const VMITargetCapabilityRegistry &capabilities, + VMIGroupBroadcastLoadOp op, + std::string *reason = nullptr) const; + + FailureOr + getTruncFSupport(VMITruncFOp op, std::string *reason = nullptr) const; + + FailureOr getExtFSupport(VMIExtFOp op, + std::string *reason = nullptr) const; + + FailureOr + getExtSISupport(VMIExtSIOp op, std::string *reason = nullptr) const; + + FailureOr + getExtUISupport(VMIExtUIOp op, std::string *reason = nullptr) const; + + FailureOr + getTruncISupport(VMITruncIOp op, std::string *reason = nullptr) const; + + FailureOr + getBitcastSupport(VMIBitcastOp op, std::string *reason = nullptr) const; + + FailureOr + getDhistSupport(VMIDhistOp op, std::string *reason = nullptr) const; + + FailureOr + getChistSupport(VMIChistOp op, std::string *reason = nullptr) const; +}; + +} // namespace mlir::pto + +#endif // PTO_TRANSFORMS_VMILAYOUTSUPPORT_H diff --git a/include/PTO/Transforms/VMITargetCapabilities.h b/include/PTO/Transforms/VMITargetCapabilities.h new file mode 100644 index 0000000000..bb64f196e8 --- /dev/null +++ b/include/PTO/Transforms/VMITargetCapabilities.h @@ -0,0 +1,348 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. 
THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +//===- VMITargetCapabilities.h - VMI target capability registry -*- C++ -*-===// +//===----------------------------------------------------------------------===// + +#ifndef PTO_TRANSFORMS_VMITARGETCAPABILITIES_H +#define PTO_TRANSFORMS_VMITARGETCAPABILITIES_H + +#include "PTO/IR/PTO.h" +#include "PTO/IR/PTOTypeUtils.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/Twine.h" + +#include + +namespace mlir::pto { + +enum class VMICapabilityStatus { + supported, + unsupported_missing_capability, + unsupported_disabled_by_option, + unsupported_resource, +}; + +enum class VMIElementPurpose { + PredicateMask, + F16F32, + F16BF16F32, + SignlessOrSignedI8I16I32, + AnyI8I16I32, + VMula, + VRelu, +}; + +enum class VMIReductionKind { + AddI, + AddF, + GroupAddI, + GroupMaxI, + GroupAddF, + GroupMaxF, + MaxF, + MinF, +}; + +enum class VMIFallbackResourceKind { + ScratchMemory, + GuardedControlFlow, +}; + +struct VMICapabilityResult { + VMICapabilityStatus status = VMICapabilityStatus::supported; + std::string reason; + + static VMICapabilityResult supported() { return {}; } + + static VMICapabilityResult missingCapability(const Twine &reason) { + VMICapabilityResult result; + result.status = VMICapabilityStatus::unsupported_missing_capability; + result.reason = reason.str(); + return result; + } + + bool isSupported() const { return status == VMICapabilityStatus::supported; } + + LogicalResult toLogicalResult(std::string *outReason = nullptr) const { + if (isSupported()) + return success(); + if (outReason) + *outReason = reason; + return failure(); + } +}; + +class VMITargetCapabilityRegistry { +public: + VMICapabilityResult 
supportsElementType(Type type, + VMIElementPurpose purpose) const { + switch (purpose) { + case VMIElementPurpose::PredicateMask: { + unsigned elementBits = pto::getPTOStorageElemBitWidth(type); + if (elementBits == 8 || elementBits == 16 || elementBits == 32) + return VMICapabilityResult::supported(); + return VMICapabilityResult::missingCapability( + "requires an 8/16/32-bit element type so VPTO b8/b16/b32 " + "predicate masks can be materialized"); + } + case VMIElementPurpose::F16F32: + if (type.isF16() || type.isF32()) + return VMICapabilityResult::supported(); + return VMICapabilityResult::missingCapability( + "requires f16/f32 element type for direct VPTO lowering"); + case VMIElementPurpose::F16BF16F32: + if (type.isF16() || type.isBF16() || type.isF32()) + return VMICapabilityResult::supported(); + return VMICapabilityResult::missingCapability( + "requires f16/bf16/f32 element type for direct VPTO lowering"); + case VMIElementPurpose::SignlessOrSignedI8I16I32: + if (isSignlessOrSignedI8I16I32(type)) + return VMICapabilityResult::supported(); + return VMICapabilityResult::missingCapability( + "requires signless/signed i8/i16/i32 element type for direct VPTO " + "lowering"); + case VMIElementPurpose::AnyI8I16I32: + if (isAnyI8I16I32(type)) + return VMICapabilityResult::supported(); + return VMICapabilityResult::missingCapability( + "requires signless/signed/unsigned i8/i16/i32 element type for " + "direct VPTO lowering"); + case VMIElementPurpose::VMula: + if (type.isF16() || type.isBF16() || type.isF32()) + return VMICapabilityResult::supported(); + return VMICapabilityResult::missingCapability( + "requires f16, bf16, or f32 element type for pto.vmula"); + case VMIElementPurpose::VRelu: + if (type.isF16() || type.isF32()) + return VMICapabilityResult::supported(); + return VMICapabilityResult::missingCapability( + "pto.vrelu direct lowering supports only f16/f32 VMI " + "floating-point element types"); + } + llvm_unreachable("unhandled VMI element 
purpose"); + } + + VMICapabilityResult supportsDirectMemory(Type type, StringRef role) const { + switch (classifyDirectMemoryRole(type)) { + case DirectMemoryRole::UB: + case DirectMemoryRole::Unknown: + return VMICapabilityResult::supported(); + case DirectMemoryRole::GM: + return VMICapabilityResult::missingCapability( + Twine(role) + + " is GM-backed, but current direct VMI-to-VPTO memory lowering " + "emits pto.vlds/pto.vsts and requires UB-backed memory"); + case DirectMemoryRole::Other: + return VMICapabilityResult::missingCapability( + Twine(role) + + " is not UB-backed memory supported by pto.vlds/pto.vsts"); + } + llvm_unreachable("unhandled direct memory role"); + } + + VMICapabilityResult supportsUBPointerMemory(Type type, StringRef role, + StringRef physicalOp, + StringRef ubReason) const { + auto ptrType = dyn_cast(type); + if (!ptrType) + return VMICapabilityResult::missingCapability( + Twine("requires a !pto.ptr ") + role + " because " + physicalOp + + " is pointer-only"); + if (ptrType.getMemorySpace().getAddressSpace() != AddressSpace::VEC) + return VMICapabilityResult::missingCapability( + Twine("requires a UB ") + role + " because " + ubReason); + return VMICapabilityResult::supported(); + } + + VMICapabilityResult supportsChannelCount(StringRef opName, + int64_t channels) const { + if (channels == 2 || channels == 4) + return VMICapabilityResult::supported(); + return VMICapabilityResult::missingCapability( + Twine(opName) + " supports only 2 or 4 channels"); + } + + VMICapabilityResult supportsLayoutConversion(VMILayoutAttr sourceLayout, + VMILayoutAttr resultLayout, + Type elementType) const { + (void)elementType; + if (!sourceLayout || !resultLayout) + return VMICapabilityResult::missingCapability( + "requires assigned source/result layouts"); + if (sourceLayout == resultLayout) + return VMICapabilityResult::supported(); + if (sourceLayout.isContiguous() && sourceLayout.getLaneStride() == 1 && + resultLayout.isDeinterleaved() && 
resultLayout.getLaneStride() == 1 && + (resultLayout.getFactor() == 2 || resultLayout.getFactor() == 4)) + return VMICapabilityResult::supported(); + if (sourceLayout.isDeinterleaved() && sourceLayout.getLaneStride() == 1 && + resultLayout.isContiguous() && resultLayout.getLaneStride() == 1 && + (sourceLayout.getFactor() == 2 || sourceLayout.getFactor() == 4)) + return VMICapabilityResult::supported(); + if (sourceLayout.isDeinterleaved() && resultLayout.isDeinterleaved() && + sourceLayout.getLaneStride() == 1 && + resultLayout.getLaneStride() == 1 && + (sourceLayout.getFactor() == 2 || sourceLayout.getFactor() == 4) && + (resultLayout.getFactor() == 2 || resultLayout.getFactor() == 4)) + return VMICapabilityResult::supported(); + return VMICapabilityResult::missingCapability( + "unsupported source/result layout pair"); + } + + VMICapabilityResult + supportsMaskGranularityConversion(StringRef sourceGranularity, + StringRef resultGranularity) const { + if (!VMIMaskType::isConcreteGranularity(sourceGranularity) || + !VMIMaskType::isConcreteGranularity(resultGranularity)) + return VMICapabilityResult::missingCapability( + "requires concrete b8/b16/b32 source and result granularities"); + return VMICapabilityResult::supported(); + } + + VMICapabilityResult supportsTrueMaskedLoad(Type sourceType, Type resultType, + Type maskType) const { + (void)sourceType; + (void)resultType; + (void)maskType; + return VMICapabilityResult::missingCapability( + "target true masked/non-faulting load is unavailable because the " + "current VPTO pto.vlds surface has no mask operand"); + } + + VMICapabilityResult + supportsFallbackResource(VMIFallbackResourceKind kind) const { + switch (kind) { + case VMIFallbackResourceKind::ScratchMemory: + return VMICapabilityResult::missingCapability( + "scratch memory fallback resource allocation is not implemented"); + case VMIFallbackResourceKind::GuardedControlFlow: + return VMICapabilityResult::missingCapability( + "guarded memory fallback 
control-flow lowering is not implemented"); + } + llvm_unreachable("unhandled VMI fallback resource kind"); + } + + VMICapabilityResult supportsReductionElementType(VMIReductionKind kind, + Type elementType) const { + switch (kind) { + case VMIReductionKind::AddI: + if (pto::getPTOStorageElemBitWidth(elementType) == 32 && + isa(elementType)) + return VMICapabilityResult::supported(); + return VMICapabilityResult::missingCapability( + "currently supports only 32-bit integer elements because narrow " + "vcadd widens its result"); + case VMIReductionKind::AddF: + if (elementType.isF16() || elementType.isF32()) + return VMICapabilityResult::supported(); + return VMICapabilityResult::missingCapability( + "currently supports only f16/f32 elements for floating-point " + "reduction"); + case VMIReductionKind::GroupAddI: + case VMIReductionKind::GroupMaxI: { + auto intType = dyn_cast(elementType); + if (intType && intType.getWidth() == 32) + return VMICapabilityResult::supported(); + return VMICapabilityResult::missingCapability( + "grouped integer reduction supports only i32 accumulator " + "elements because narrow integer reductions widen their result; " + "cast i8/i16 storage before grouped reduction"); + } + case VMIReductionKind::GroupAddF: + case VMIReductionKind::GroupMaxF: + if (elementType.isF16() || elementType.isF32()) + return VMICapabilityResult::supported(); + return VMICapabilityResult::missingCapability( + "grouped floating-point reduction supports f16/f32 accumulator " + "elements"); + case VMIReductionKind::MaxF: + case VMIReductionKind::MinF: + if (elementType.isF16() || elementType.isF32()) + return VMICapabilityResult::supported(); + return VMICapabilityResult::missingCapability( + "currently supports only f16/f32 elements because pto.vcmax/" + "pto.vcmin support only those floating-point element types"); + } + llvm_unreachable("unhandled VMI reduction kind"); + } + +private: + enum class DirectMemoryRole { Unknown, UB, GM, Other }; + + DirectMemoryRole 
classifyDirectMemoryRole(Type type) const { + if (auto ptrType = dyn_cast(type)) { + switch (ptrType.getMemorySpace().getAddressSpace()) { + case AddressSpace::GM: + case AddressSpace::Zero: + return DirectMemoryRole::GM; + case AddressSpace::VEC: + return DirectMemoryRole::UB; + default: + return DirectMemoryRole::Other; + } + } + + auto memrefType = dyn_cast(type); + if (!memrefType) + return DirectMemoryRole::Other; + + Attribute memorySpace = memrefType.getMemorySpace(); + if (!memorySpace) + return DirectMemoryRole::Unknown; + + if (auto addressSpace = dyn_cast(memorySpace)) { + switch (addressSpace.getAddressSpace()) { + case AddressSpace::GM: + case AddressSpace::Zero: + return DirectMemoryRole::GM; + case AddressSpace::VEC: + return DirectMemoryRole::UB; + default: + return DirectMemoryRole::Other; + } + } + + if (auto integerSpace = dyn_cast(memorySpace)) { + switch (integerSpace.getInt()) { + case static_cast(AddressSpace::GM): + case static_cast(AddressSpace::Zero): + return DirectMemoryRole::GM; + case static_cast(AddressSpace::VEC): + return DirectMemoryRole::UB; + default: + return DirectMemoryRole::Other; + } + } + + return DirectMemoryRole::Other; + } + + static bool isSignlessOrSignedI8I16I32(Type type) { + auto intType = dyn_cast(type); + if (!intType || intType.isUnsigned()) + return false; + return intType.getWidth() == 8 || intType.getWidth() == 16 || + intType.getWidth() == 32; + } + + static bool isAnyI8I16I32(Type type) { + auto intType = dyn_cast(type); + if (!intType) + return false; + return intType.getWidth() == 8 || intType.getWidth() == 16 || + intType.getWidth() == 32; + } +}; + +} // namespace mlir::pto + +#endif // PTO_TRANSFORMS_VMITARGETCAPABILITIES_H diff --git a/lib/PTO/IR/CMakeLists.txt b/lib/PTO/IR/CMakeLists.txt index 74b9e0bd68..4f8d995796 100644 --- a/lib/PTO/IR/CMakeLists.txt +++ b/lib/PTO/IR/CMakeLists.txt @@ -15,6 +15,7 @@ add_mlir_dialect_library(PTOIR PTO.cpp VPTO.cpp + VMI.cpp PTOAttrs.cpp PTOSyncUtils.cpp 
PTOTypeDefs.cpp diff --git a/lib/PTO/IR/PTO.cpp b/lib/PTO/IR/PTO.cpp index d8db178278..8fe898a76f 100644 --- a/lib/PTO/IR/PTO.cpp +++ b/lib/PTO/IR/PTO.cpp @@ -29,6 +29,7 @@ #include "mlir/IR/Types.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Support/LLVM.h" +#include "mlir/Transforms/InliningUtils.h" #include "mlir/Parser/Parser.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" @@ -118,6 +119,27 @@ static bool isKnownZeroOrUnitExtent(int64_t value); static bool isByteIntegerType(Type ty); static LogicalResult verifyTileBufCommon(Operation *op, Type ty, StringRef name, bool allowLowPrecision = false); + +namespace { +struct PTOInlinerInterface : public DialectInlinerInterface { + using DialectInlinerInterface::DialectInlinerInterface; + + bool isLegalToInline(Operation *call, Operation *callable, + bool wouldBeCloned) const final { + return true; + } + + bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, + IRMapping &valueMapping) const final { + return true; + } + + bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned, + IRMapping &valueMapping) const final { + return true; + } +}; +} // namespace static LogicalResult verifyTileBufSameElemType(Operation *op, Type lhs, Type rhs, StringRef lhsName, StringRef rhsName); @@ -2597,6 +2619,8 @@ void PTODialect::initialize() { #define GET_ATTRDEF_LIST #include "PTO/IR/PTOAttrs.cpp.inc" >(); + + addInterfaces(); } diff --git a/lib/PTO/IR/VMI.cpp b/lib/PTO/IR/VMI.cpp new file mode 100644 index 0000000000..a0d4d3dabd --- /dev/null +++ b/lib/PTO/IR/VMI.cpp @@ -0,0 +1,2183 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. 
THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +//===- VMI.cpp - PTO VMI type and attribute support -----------------------===// +//===----------------------------------------------------------------------===// + +#include "PTO/IR/PTO.h" +#include "PTO/IR/PTOTypeUtils.h" +#include "PTO/IR/VMIUtils.h" + +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/Types.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/raw_ostream.h" +#include + +using namespace mlir; +using namespace mlir::pto; + +namespace { + +static std::string formatVMIVRegType(int64_t elementCount, Type elementType, + Attribute layout) { + std::string result; + llvm::raw_string_ostream os(result); + os << "!pto.vmi.vreg<" << elementCount << "x" << elementType; + if (layout) + os << ", " << layout; + os << ">"; + return result; +} + +static std::string formatVMIMaskType(int64_t elementCount, + StringRef granularity, Attribute layout) { + std::string result; + llvm::raw_string_ostream os(result); + os << "!pto.vmi.mask<" << elementCount << "x" << granularity; + if (layout) + os << ", " << layout; + os << ">"; + return result; +} + +static bool isSupportedVMIElementType(Type type) { + return isa(type) || + pto::isPTOLowPrecisionType(type); +} + +static bool isVMIFloatLikeType(Type type) { + return isa(type) || pto::isPTOLowPrecisionType(type); +} + +static bool isVMIIntegerLikeType(Type type) { + return isa(type); +} + +static bool isVMISignedOrSignlessIntegerType(Type type) { + auto integerType = dyn_cast(type); + return integerType && !integerType.isUnsigned(); +} + +static bool isVMIUnsignedIntegerType(Type type) { + auto integerType = dyn_cast(type); + return integerType && integerType.isUnsigned(); +} + +static bool 
isVMIIotaElementType(Type type) { + if (auto intType = dyn_cast(type)) + return intType.getWidth() == 8 || intType.getWidth() == 16 || + intType.getWidth() == 32; + return type.isF16() || type.isF32(); +} + +static bool isCompatibleScalarForSemanticType(Type semanticType, + Type scalarType) { + if (semanticType == scalarType) + return true; + + auto semanticInt = dyn_cast(semanticType); + auto scalarInt = dyn_cast(scalarType); + if (!semanticInt || !scalarInt || + semanticInt.getWidth() != scalarInt.getWidth()) + return false; + + if (semanticInt.isSigned()) + return scalarInt.isSigned() || scalarInt.isSignless(); + if (semanticInt.isUnsigned()) + return scalarInt.isUnsigned() || scalarInt.isSignless(); + return scalarInt.isSignless(); +} + +static unsigned getVMIElementBitWidth(Type type) { + if (isa(type)) + return 64; + return pto::getPTOStorageElemBitWidth(type); +} + +static std::optional getVMIIntegerOrFloatBitWidth(Type type) { + if (auto intType = dyn_cast(type)) + return intType.getWidth(); + if (auto floatType = dyn_cast(type)) + return floatType.getWidth(); + return std::nullopt; +} + +static int64_t divideCeilNonNegative(int64_t value, int64_t divisor) { + return value == 0 ? 
0 : (value + divisor - 1) / divisor; +} + +static LogicalResult parseOptionalVMILayout(AsmParser &parser, + Attribute &layout) { + if (failed(parser.parseOptionalComma())) + return success(); + + if (failed(parser.parseAttribute(layout))) + return failure(); + if (!mlir::isa(layout)) + return parser.emitError(parser.getCurrentLocation(), + "expected #pto.vmi.layout attribute"); + return success(); +} + +static FailureOr getVMIElementCount(Type type) { + if (auto vregType = dyn_cast(type)) + return vregType.getElementCount(); + if (auto maskType = dyn_cast(type)) + return maskType.getElementCount(); + return failure(); +} + +static FailureOr getAssignedVMILayout(Type type) { + Attribute layout; + if (auto vregType = dyn_cast(type)) + layout = vregType.getLayout(); + else if (auto maskType = dyn_cast(type)) + layout = maskType.getLayout(); + else + return failure(); + + auto layoutAttr = dyn_cast_or_null(layout); + if (!layoutAttr) + return failure(); + return layoutAttr; +} + +static FailureOr getLayoutFactor(Type type) { + FailureOr layout = getAssignedVMILayout(type); + if (failed(layout)) + return failure(); + return (*layout).isDeinterleaved() ? (*layout).getFactor() : 1; +} + +static FailureOr getLayoutBlockElems(Type type) { + FailureOr layout = getAssignedVMILayout(type); + if (failed(layout)) + return failure(); + return (*layout).isDeinterleaved() ? 
(*layout).getBlockElems() : 1; +} + +static FailureOr getVMIPhysicalElementType(VMIVRegType type) { + Type elementType = type.getElementType(); + VMILayoutAttr layout = type.getLayoutAttr(); + if (!layout || !layout.hasGroupSlotLaneStride()) + return elementType; + + auto integerType = dyn_cast(elementType); + if (!integerType || !integerType.isUnsigned()) + return failure(); + unsigned elementBits = pto::getPTOStorageElemBitWidth(elementType); + int64_t laneStride = layout.getLaneStride(); + if (elementBits == 0 || laneStride <= 1) + return failure(); + int64_t physicalBits = static_cast(elementBits) * laneStride; + if (physicalBits != 16 && physicalBits != 32) + return failure(); + return IntegerType::get(type.getContext(), physicalBits); +} + +static FailureOr getPhysicalLanesPerPart(Type type) { + if (auto vregType = dyn_cast(type)) { + FailureOr physicalElementType = getVMIPhysicalElementType(vregType); + if (failed(physicalElementType)) + return failure(); + return getDataLanesPerPart(*physicalElementType); + } + if (auto maskType = dyn_cast(type)) + return getMaskLanesPerPart(maskType.getGranularity()); + return failure(); +} + +static FailureOr getDenseLaneStride(Type type) { + FailureOr layout = getAssignedVMILayout(type); + if (failed(layout)) + return failure(); + return (*layout).isDense() ? 
(*layout).getLaneStride() : 1; +} + +static int64_t getMaskGranularityBitWidth(StringRef granularity) { + if (granularity == "b8") + return 8; + if (granularity == "b16") + return 16; + if (granularity == "b32") + return 32; + return 0; +} + +static bool isLayoutAssigned(VMIVRegType type) { + return static_cast(type.getLayoutAttr()); +} + +static bool isLayoutAssigned(VMIMaskType type) { + return static_cast(type.getLayoutAttr()); +} + +static LogicalResult +verifyAllSameVRegShapeAndLayout(Operation *op, ArrayRef types, + bool requireSameElement) { + if (types.empty()) + return success(); + + VMIVRegType first = types.front(); + bool anyLayout = llvm::any_of( + types, [](VMIVRegType type) { return isLayoutAssigned(type); }); + + for (VMIVRegType type : types) { + if (type.getElementCount() != first.getElementCount()) + return op->emitOpError( + "requires all VMI data values to have the same logical lane count"); + if (requireSameElement && type.getElementType() != first.getElementType()) + return op->emitOpError( + "requires all VMI data values to have the same element type"); + if (anyLayout && !isLayoutAssigned(type)) + return op->emitOpError( + "requires either all or no VMI data values to carry layout"); + if (anyLayout && type.getLayout() != first.getLayout()) + return op->emitOpError("requires all layout-assigned VMI data values to " + "have the same layout"); + } + return success(); +} + +static LogicalResult verifyElementwiseVRegOp(Operation *op, VMIVRegType lhs, + VMIVRegType rhs, + VMIVRegType result) { + return verifyAllSameVRegShapeAndLayout(op, {lhs, rhs, result}, + /*requireSameElement=*/true); +} + +static LogicalResult verifyFloatUnaryVRegOp(Operation *op, VMIVRegType source, + VMIVRegType result) { + if (!isVMIFloatLikeType(source.getElementType())) + return op->emitOpError("requires floating-point-like VMI element type"); + return verifyAllSameVRegShapeAndLayout(op, {source, result}, + /*requireSameElement=*/true); +} + +static LogicalResult 
verifyFloatTernaryVRegOp(Operation *op, VMIVRegType lhs, + VMIVRegType rhs, VMIVRegType acc, + VMIVRegType result) { + if (!isVMIFloatLikeType(lhs.getElementType())) + return op->emitOpError("requires floating-point-like VMI element type"); + return verifyAllSameVRegShapeAndLayout(op, {lhs, rhs, acc, result}, + /*requireSameElement=*/true); +} + +static LogicalResult +verifyAllSameMaskShapeLayoutAndGranularity(Operation *op, + ArrayRef types) { + if (types.empty()) + return success(); + + VMIMaskType first = types.front(); + bool anyLayout = llvm::any_of( + types, [](VMIMaskType type) { return isLayoutAssigned(type); }); + + for (VMIMaskType type : types) { + if (type.getElementCount() != first.getElementCount()) + return op->emitOpError( + "requires all VMI mask values to have the same logical lane count"); + if (type.getGranularity() != first.getGranularity()) + return op->emitOpError( + "requires all VMI mask values to have the same granularity"); + if (anyLayout && !isLayoutAssigned(type)) + return op->emitOpError( + "requires either all or no VMI mask values to carry layout"); + if (anyLayout && type.getLayout() != first.getLayout()) + return op->emitOpError( + "requires all layout-assigned VMI mask values to have the same " + "layout"); + } + return success(); +} + +static LogicalResult verifyMaskMatchesData(Operation *op, VMIMaskType maskType, + VMIVRegType dataType) { + if (maskType.getElementCount() != dataType.getElementCount()) + return op->emitOpError( + "requires mask logical lane count to match data lane count"); + + if (isLayoutAssigned(maskType) || isLayoutAssigned(dataType)) { + if (!isLayoutAssigned(maskType) || !isLayoutAssigned(dataType)) + return op->emitOpError("requires either both mask and data to carry " + "layout or neither to carry layout"); + if (maskType.getLayout() != dataType.getLayout()) + return op->emitOpError("requires mask layout to match data layout"); + } + + if (maskType.isPred()) + return success(); + + unsigned 
elementBitWidth = getVMIElementBitWidth(dataType.getElementType()); + int64_t maskBitWidth = getMaskGranularityBitWidth(maskType.getGranularity()); + if (elementBitWidth != 0 && maskBitWidth != 0 && + elementBitWidth != static_cast(maskBitWidth)) + return op->emitOpError( + "requires mask granularity to match data element width"); + + return success(); +} + +static Type getMemoryElementType(Type type) { + if (auto ptrType = dyn_cast(type)) + return ptrType.getElementType(); + if (auto memrefType = dyn_cast(type)) + return memrefType.getElementType(); + return {}; +} + +static LogicalResult verifyMemoryElementMatches(Operation *op, Type memoryType, + VMIVRegType dataType, + StringRef role) { + Type memoryElementType = getMemoryElementType(memoryType); + if (!memoryElementType) + return success(); + if (memoryElementType != dataType.getElementType()) + return op->emitOpError() << "requires memory " << role + << " element type to match VMI data element type"; + return success(); +} + +static LogicalResult verifyContiguousIfLayoutAssigned(Operation *op, + VMIVRegType type, + StringRef role) { + VMILayoutAttr layout = type.getLayoutAttr(); + if (layout && !layout.isContiguous()) + return op->emitOpError() + << "requires layout-assigned " << role + << " to use #pto.vmi.layout"; + return success(); +} + +static bool isPackedByteGroupStore(Type memoryType, VMIVRegType dataType) { + Type memoryElementType = getMemoryElementType(memoryType); + if (!memoryElementType) + return false; + auto memoryIntegerType = dyn_cast(memoryElementType); + auto dataIntegerType = dyn_cast(dataType.getElementType()); + return memoryIntegerType && dataIntegerType && + memoryIntegerType.getWidth() == 8 && dataIntegerType.getWidth() == 32; +} + +static LogicalResult verifyNumGroups(Operation *op, VMIVRegType type, + int64_t numGroups) { + if (numGroups <= 0) + return op->emitOpError("requires num_groups to be positive"); + if (type.getElementCount() % numGroups != 0) + return op->emitOpError() + 
<< "requires num_groups to evenly divide VMI logical lane count " + << type.getElementCount(); + return success(); +} + +static LogicalResult verifyPhysicalParts(Operation *op, Type vmiType, + TypeRange physicalTypes) { + FailureOr expectedArity = getVMIPhysicalArity(vmiType); + if (failed(expectedArity)) + return op->emitOpError( + "requires a layout-assigned VMI type with computable physical arity"); + if (static_cast(physicalTypes.size()) != *expectedArity) + return op->emitOpError() << "requires " << *expectedArity + << " physical parts, got " << physicalTypes.size(); + + if (auto vregType = dyn_cast(vmiType)) { + FailureOr lanesPerPart = + getPhysicalLanesPerPart(vregType); + FailureOr physicalElementType = getVMIPhysicalElementType(vregType); + if (failed(lanesPerPart) || failed(physicalElementType)) + return op->emitOpError( + "requires data element type with known physical lane count"); + for (Type physicalType : physicalTypes) { + auto partType = dyn_cast(physicalType); + if (!partType) + return op->emitOpError("requires physical data parts to be !pto.vreg"); + if (partType.getElementCount() != *lanesPerPart || + partType.getElementType() != *physicalElementType) + return op->emitOpError( + "requires physical data part type to match VMI lane-map helper"); + } + return success(); + } + + auto maskType = dyn_cast(vmiType); + if (!maskType) + return op->emitOpError("requires VMI data or mask type"); + if (maskType.isPred()) + return op->emitOpError( + "requires layout-assigned mask with concrete granularity"); + + for (Type physicalType : physicalTypes) { + auto partType = dyn_cast(physicalType); + if (!partType) + return op->emitOpError("requires physical mask parts to be !pto.mask"); + if (partType.getGranularity() != maskType.getGranularity()) + return op->emitOpError( + "requires physical mask part granularity to match VMI mask"); + } + return success(); +} + +static std::optional +mapDenseLogicalLaneToPartIndex(int64_t elementCount, int64_t factor, + 
int64_t blockElems, int64_t logicalLane, + int64_t &part) { + if (logicalLane < 0 || logicalLane >= elementCount || factor <= 0 || + blockElems <= 0) + return std::nullopt; + int64_t block = logicalLane / blockElems; + int64_t inBlockLane = logicalLane % blockElems; + part = block % factor; + int64_t partBlock = block / factor; + return partBlock * blockElems + inBlockLane; +} + +static std::optional +mapDensePartIndexToLogicalLane(int64_t elementCount, int64_t factor, + int64_t blockElems, int64_t part, + int64_t indexInPart) { + if (part < 0 || part >= factor || indexInPart < 0 || factor <= 0 || + blockElems <= 0) + return std::nullopt; + int64_t partBlock = indexInPart / blockElems; + int64_t inBlockLane = indexInPart % blockElems; + int64_t logicalBlock = partBlock * factor + part; + int64_t logicalLane = logicalBlock * blockElems + inBlockLane; + if (logicalLane >= elementCount) + return std::nullopt; + return logicalLane; +} + +static int64_t getDenseLogicalLanesInPart(int64_t elementCount, int64_t factor, + int64_t blockElems, int64_t part) { + int64_t maxIndex = -1; + for (int64_t lane = 0; lane < elementCount; ++lane) { + int64_t lanePart = 0; + std::optional index = mapDenseLogicalLaneToPartIndex( + elementCount, factor, blockElems, lane, lanePart); + if (index && lanePart == part) + maxIndex = std::max(maxIndex, *index); + } + return maxIndex + 1; +} + +} // namespace + +VMILayoutAttr VMILayoutAttr::getContiguous(MLIRContext *context, + int64_t laneStride) { + return VMILayoutAttr::get(context, "contiguous", 1, 1, 0, laneStride); +} + +VMILayoutAttr VMILayoutAttr::getDeinterleaved(MLIRContext *context, + int64_t factor, + int64_t blockElems, + int64_t laneStride) { + return VMILayoutAttr::get(context, "deinterleaved", factor, blockElems, 0, + laneStride); +} + +VMILayoutAttr VMILayoutAttr::getGroupSlots(MLIRContext *context, + int64_t numGroups, int64_t slots, + int64_t laneStride) { + return VMILayoutAttr::get(context, "num_groups", numGroups, 1, slots, 
+ laneStride); +} + +Attribute VMILayoutAttr::parse(AsmParser &parser, Type) { + SMLoc loc = parser.getCurrentLocation(); + StringRef kind; + int64_t factor = 1; + int64_t blockElems = 1; + int64_t slots = 0; + int64_t laneStride = 1; + + if (failed(parser.parseLess()) || failed(parser.parseKeyword(&kind))) + return {}; + + if (kind == "contiguous") { + factor = 1; + while (succeeded(parser.parseOptionalComma())) { + StringRef field; + if (failed(parser.parseKeyword(&field)) || failed(parser.parseEqual()) || + field != "lane_stride" || failed(parser.parseInteger(laneStride))) { + parser.emitError(parser.getCurrentLocation(), + "expected 'lane_stride = '"); + return {}; + } + } + } else if (kind == "deinterleaved") { + if (failed(parser.parseEqual()) || failed(parser.parseInteger(factor))) + return {}; + while (succeeded(parser.parseOptionalComma())) { + StringRef field; + if (failed(parser.parseKeyword(&field)) || failed(parser.parseEqual())) + return {}; + if (field == "block_elems") { + if (failed(parser.parseInteger(blockElems))) + return {}; + } else if (field == "lane_stride") { + if (failed(parser.parseInteger(laneStride))) + return {}; + } else { + parser.emitError(parser.getCurrentLocation(), + "expected 'block_elems = ' or " + "'lane_stride = '"); + return {}; + } + } + } else if (kind == "num_groups") { + if (failed(parser.parseEqual()) || failed(parser.parseInteger(factor))) + return {}; + while (succeeded(parser.parseOptionalComma())) { + StringRef field; + if (failed(parser.parseKeyword(&field)) || failed(parser.parseEqual())) + return {}; + if (field == "slots") { + if (failed(parser.parseInteger(slots))) + return {}; + } else if (field == "lane_stride") { + if (failed(parser.parseInteger(laneStride))) + return {}; + } else { + parser.emitError(parser.getCurrentLocation(), + "expected 'slots = ' or " + "'lane_stride = '"); + return {}; + } + } + } else { + parser.emitError(parser.getCurrentLocation(), + "expected VMI layout kind 'contiguous' or " + 
"'deinterleaved' or 'num_groups'"); + return {}; + } + + if (failed(parser.parseGreater())) + return {}; + + return parser.getChecked(loc, parser.getContext(), kind, + factor, blockElems, slots, + laneStride); +} + +void VMILayoutAttr::print(AsmPrinter &printer) const { + printer << "<" << getKind(); + if (isContiguous()) { + if (getLaneStride() != 1) + printer << ", lane_stride = " << getLaneStride(); + } else if (isDeinterleaved()) { + printer << " = " << getFactor(); + if (getBlockElems() != 1) + printer << ", block_elems = " << getBlockElems(); + if (getLaneStride() != 1) + printer << ", lane_stride = " << getLaneStride(); + } else if (isGroupSlots()) { + printer << " = " << getFactor(); + if (getSlots() != 0) + printer << ", slots = " << getSlots(); + if (getLaneStride() != 1) + printer << ", lane_stride = " << getLaneStride(); + } + printer << ">"; +} + +LogicalResult +VMILayoutAttr::verify(function_ref emitError, + StringRef kind, int64_t factor, int64_t blockElems, + int64_t slots, int64_t laneStride) { + if (laneStride <= 0) + return emitError() << "#pto.vmi.layout<" << kind + << "> requires lane_stride to be positive"; + + if (kind == "contiguous") { + if (factor != 1 || blockElems != 1 || slots != 0) + return emitError() + << "#pto.vmi.layout requires factor, block_elems, " + "and slots to be their defaults"; + return success(); + } + + if (kind == "deinterleaved") { + if (factor != 2 && factor != 4) + return emitError() << "#pto.vmi.layout expected factor to be 2 or 4"; + if (blockElems <= 0) + return emitError() << "#pto.vmi.layout requires block_elems to be positive"; + if (slots != 0) + return emitError() << "#pto.vmi.layout requires slots to be omitted"; + return success(); + } + + if (kind == "num_groups") { + if (factor <= 0) + return emitError() << "#pto.vmi.layout requires num_groups to be positive"; + if (blockElems != 1) + return emitError() << "#pto.vmi.layout requires block_elems to be omitted"; + if (slots < 0) + return emitError() << 
"#pto.vmi.layout requires slots to be omitted or positive"; + return success(); + } + + return emitError() << "expected VMI layout kind to be 'contiguous' or " + "'deinterleaved' or 'num_groups'"; +} + +Type VMIVRegType::parse(AsmParser &parser) { + SmallVector shape; + Type elementType; + Attribute layout; + SMLoc loc = parser.getCurrentLocation(); + + if (failed(parser.parseLess()) || + failed(parser.parseDimensionList(shape, /*allowDynamic=*/false, + /*withTrailingX=*/true)) || + shape.size() != 1 || failed(parser.parseType(elementType)) || + failed(parseOptionalVMILayout(parser, layout)) || + failed(parser.parseGreater())) + return {}; + + return parser.getChecked(loc, parser.getContext(), shape.front(), + elementType, layout); +} + +void VMIVRegType::print(AsmPrinter &printer) const { + printer << "<" << getElementCount() << "x"; + printer.printType(getElementType()); + if (getLayout()) + printer << ", " << getLayout(); + printer << ">"; +} + +LogicalResult VMIVRegType::verify(function_ref emitError, + int64_t elementCount, Type elementType, + Attribute layout) { + if (elementCount <= 0) + return emitError() << "'" + << formatVMIVRegType(elementCount, elementType, layout) + << "' expected a positive element count"; + + if (!isSupportedVMIElementType(elementType)) + return emitError() << "'" + << formatVMIVRegType(elementCount, elementType, layout) + << "' expected an integer, index, floating-point, or " + "PTO low-precision element type"; + if (pto::isPTOFloat4PackedType(elementType)) + return emitError() + << "'" << formatVMIVRegType(elementCount, elementType, layout) + << "' uses a packed FP4 physical pair type as a VMI logical " + "element type; packed FP4 input/output is not a supported VMI " + "surface because the logical FP4 lane count and physical packed " + "byte count are ambiguous"; + + if (layout && !mlir::isa(layout)) + return emitError() << "'" + << formatVMIVRegType(elementCount, elementType, layout) + << "' expected layout to be 
#pto.vmi.layout"; + if (auto layoutAttr = llvm::dyn_cast_or_null(layout)) { + if (layoutAttr.isGroupSlots() && + elementCount != layoutAttr.getNumGroups()) + return emitError() << "'" + << formatVMIVRegType(elementCount, elementType, layout) + << "' expected num_groups layout to describe exactly " + "one logical result lane per group"; + } + + return success(); +} + +bool VMIMaskType::isSupportedGranularity(StringRef granularity) { + return granularity == "pred" || isConcreteGranularity(granularity); +} + +bool VMIMaskType::isConcreteGranularity(StringRef granularity) { + return granularity == "b8" || granularity == "b16" || granularity == "b32"; +} + +Type VMIMaskType::parse(AsmParser &parser) { + SmallVector shape; + StringRef granularity; + Attribute layout; + SMLoc loc = parser.getCurrentLocation(); + + if (failed(parser.parseLess()) || + failed(parser.parseDimensionList(shape, /*allowDynamic=*/false, + /*withTrailingX=*/true)) || + shape.size() != 1 || failed(parser.parseKeyword(&granularity)) || + failed(parseOptionalVMILayout(parser, layout)) || + failed(parser.parseGreater())) + return {}; + + return parser.getChecked(loc, parser.getContext(), shape.front(), + granularity, layout); +} + +void VMIMaskType::print(AsmPrinter &printer) const { + printer << "<" << getElementCount() << "x" << getGranularity(); + if (getLayout()) + printer << ", " << getLayout(); + printer << ">"; +} + +LogicalResult VMIMaskType::verify(function_ref emitError, + int64_t elementCount, StringRef granularity, + Attribute layout) { + if (elementCount <= 0) + return emitError() << "'" + << formatVMIMaskType(elementCount, granularity, layout) + << "' expected a positive element count"; + + if (!isSupportedGranularity(granularity)) + return emitError() << "'" + << formatVMIMaskType(elementCount, granularity, layout) + << "' expected granularity to be one of pred, b8, b16, " + "b32"; + + if (layout && !mlir::isa(layout)) + return emitError() << "'" + << formatVMIMaskType(elementCount, 
granularity, layout) + << "' expected layout to be #pto.vmi.layout"; + if (auto layoutAttr = llvm::dyn_cast_or_null(layout)) { + if (layoutAttr.isGroupSlots()) + return emitError() << "'" + << formatVMIMaskType(elementCount, granularity, layout) + << "' mask type must not carry num_groups layout"; + } + + if (granularity == "pred" && layout) + return emitError() << "'" + << formatVMIMaskType(elementCount, granularity, layout) + << "' pred mask must not carry layout"; + + if (granularity != "pred" && !layout) + return emitError() << "'" + << formatVMIMaskType(elementCount, granularity, layout) + << "' concrete mask granularity requires layout"; + + return success(); +} + +LogicalResult VMIConstantOp::verify() { + auto resultType = cast(getResult().getType()); + auto denseAttr = dyn_cast(getValue()); + if (!denseAttr) + return emitOpError("requires dense elements constant attribute"); + if (denseAttr.getElementType() != resultType.getElementType()) + return emitOpError( + "requires dense constant element type to match result element type"); + if (denseAttr.getNumElements() != resultType.getElementCount()) + return emitOpError("requires dense constant element count to match result " + "logical lane count"); + return success(); +} + +LogicalResult VMIBroadcastOp::verify() { + auto resultType = cast(getResult().getType()); + Type valueType = getValue().getType(); + if (valueType == resultType.getElementType()) + return success(); + if (auto vregType = dyn_cast(valueType)) { + if (vregType.getElementCount() != 1) + return emitOpError("requires VMI vector input to have one logical lane"); + if (vregType.getElementType() != resultType.getElementType()) + return emitOpError("requires VMI vector input element type to match " + "result element type"); + return success(); + } + return emitOpError("requires scalar or VMI vector input element type to " + "match result element type"); +} + +LogicalResult VMIIotaOp::verify() { + auto resultType = cast(getResult().getType()); + 
Type elementType = resultType.getElementType(); + if (!isVMIIotaElementType(elementType)) + return emitOpError("requires result element type to be integer 8/16/32 " + "or f16/f32"); + if (!isCompatibleScalarForSemanticType(elementType, getBase().getType())) + return emitOpError("requires base type to match result element type"); + + if (std::optional order = getOrder()) { + if (*order != "ASC" && *order != "DESC") + return emitOpError("requires order to be ASC or DESC"); + } + return success(); +} + +LogicalResult VMICreateMaskOp::verify() { + auto resultType = cast(getResult().getType()); + if (!resultType.isPred() && !isLayoutAssigned(resultType)) + return emitOpError("requires concrete mask result to carry layout"); + return success(); +} + +LogicalResult VMICreateGroupMaskOp::verify() { + auto resultType = cast(getResult().getType()); + int64_t numGroups = getNumGroupsAttr().getInt(); + int64_t groupSize = getGroupSizeAttr().getInt(); + if (numGroups <= 0) + return emitOpError("requires positive num_groups"); + if (groupSize <= 0) + return emitOpError("requires positive group_size"); + if (resultType.getElementCount() != numGroups * groupSize) + return emitOpError("requires result lane count to equal num_groups * " + "group_size"); + if (!resultType.isPred() && !isLayoutAssigned(resultType)) + return emitOpError("requires concrete mask result to carry layout"); + return success(); +} + +LogicalResult VMIConstantMaskOp::verify() { + auto resultType = cast(getResult().getType()); + auto denseAttr = dyn_cast(getValue()); + if (!denseAttr) + return emitOpError("requires dense elements mask constant attribute"); + if (!denseAttr.getElementType().isInteger(1)) + return emitOpError("requires dense mask constant element type to be i1"); + if (denseAttr.getNumElements() != resultType.getElementCount()) + return emitOpError("requires dense mask constant element count to match " + "result logical lane count"); + return success(); +} + +LogicalResult VMIMaskAndOp::verify() 
{ + auto lhsType = cast(getLhs().getType()); + auto rhsType = cast(getRhs().getType()); + auto resultType = cast(getResult().getType()); + return verifyAllSameMaskShapeLayoutAndGranularity( + getOperation(), {lhsType, rhsType, resultType}); +} + +LogicalResult VMIMaskOrOp::verify() { + auto lhsType = cast(getLhs().getType()); + auto rhsType = cast(getRhs().getType()); + auto resultType = cast(getResult().getType()); + return verifyAllSameMaskShapeLayoutAndGranularity( + getOperation(), {lhsType, rhsType, resultType}); +} + +LogicalResult VMIMaskXOrOp::verify() { + auto lhsType = cast(getLhs().getType()); + auto rhsType = cast(getRhs().getType()); + auto resultType = cast(getResult().getType()); + return verifyAllSameMaskShapeLayoutAndGranularity( + getOperation(), {lhsType, rhsType, resultType}); +} + +LogicalResult VMIMaskNotOp::verify() { + auto sourceType = cast(getSource().getType()); + auto resultType = cast(getResult().getType()); + return verifyAllSameMaskShapeLayoutAndGranularity(getOperation(), + {sourceType, resultType}); +} + +LogicalResult VMIAddFOp::verify() { + auto lhsType = cast(getLhs().getType()); + auto rhsType = cast(getRhs().getType()); + auto resultType = cast(getResult().getType()); + if (!isVMIFloatLikeType(lhsType.getElementType())) + return emitOpError("requires floating-point-like VMI element type"); + return verifyElementwiseVRegOp(getOperation(), lhsType, rhsType, resultType); +} + +LogicalResult VMIAddIOp::verify() { + auto lhsType = cast(getLhs().getType()); + auto rhsType = cast(getRhs().getType()); + auto resultType = cast(getResult().getType()); + if (!isVMIIntegerLikeType(lhsType.getElementType())) + return emitOpError("requires integer-like VMI element type"); + return verifyElementwiseVRegOp(getOperation(), lhsType, rhsType, resultType); +} + +LogicalResult VMISubFOp::verify() { + auto lhsType = cast(getLhs().getType()); + auto rhsType = cast(getRhs().getType()); + auto resultType = cast(getResult().getType()); + if 
(!isVMIFloatLikeType(lhsType.getElementType())) + return emitOpError("requires floating-point-like VMI element type"); + return verifyElementwiseVRegOp(getOperation(), lhsType, rhsType, resultType); +} + +LogicalResult VMISubIOp::verify() { + auto lhsType = cast(getLhs().getType()); + auto rhsType = cast(getRhs().getType()); + auto resultType = cast(getResult().getType()); + if (!isVMIIntegerLikeType(lhsType.getElementType())) + return emitOpError("requires integer-like VMI element type"); + return verifyElementwiseVRegOp(getOperation(), lhsType, rhsType, resultType); +} + +LogicalResult VMIMulFOp::verify() { + auto lhsType = cast(getLhs().getType()); + auto rhsType = cast(getRhs().getType()); + auto resultType = cast(getResult().getType()); + if (!isVMIFloatLikeType(lhsType.getElementType())) + return emitOpError("requires floating-point-like VMI element type"); + return verifyElementwiseVRegOp(getOperation(), lhsType, rhsType, resultType); +} + +LogicalResult VMIMulIOp::verify() { + auto lhsType = cast(getLhs().getType()); + auto rhsType = cast(getRhs().getType()); + auto resultType = cast(getResult().getType()); + if (!isVMIIntegerLikeType(lhsType.getElementType())) + return emitOpError("requires integer-like VMI element type"); + return verifyElementwiseVRegOp(getOperation(), lhsType, rhsType, resultType); +} + +LogicalResult VMIFmaOp::verify() { + auto lhsType = cast(getLhs().getType()); + auto rhsType = cast(getRhs().getType()); + auto accType = cast(getAcc().getType()); + auto resultType = cast(getResult().getType()); + return verifyFloatTernaryVRegOp(getOperation(), lhsType, rhsType, accType, + resultType); +} + +LogicalResult VMIDivFOp::verify() { + auto lhsType = cast(getLhs().getType()); + auto rhsType = cast(getRhs().getType()); + auto resultType = cast(getResult().getType()); + if (!isVMIFloatLikeType(lhsType.getElementType())) + return emitOpError("requires floating-point-like VMI element type"); + return verifyElementwiseVRegOp(getOperation(), 
lhsType, rhsType, resultType); +} + +LogicalResult VMIMinFOp::verify() { + auto lhsType = cast(getLhs().getType()); + auto rhsType = cast(getRhs().getType()); + auto resultType = cast(getResult().getType()); + if (!isVMIFloatLikeType(lhsType.getElementType())) + return emitOpError("requires floating-point-like VMI element type"); + return verifyElementwiseVRegOp(getOperation(), lhsType, rhsType, resultType); +} + +LogicalResult VMIMaxFOp::verify() { + auto lhsType = cast(getLhs().getType()); + auto rhsType = cast(getRhs().getType()); + auto resultType = cast(getResult().getType()); + if (!isVMIFloatLikeType(lhsType.getElementType())) + return emitOpError("requires floating-point-like VMI element type"); + return verifyElementwiseVRegOp(getOperation(), lhsType, rhsType, resultType); +} + +LogicalResult VMINegFOp::verify() { + auto sourceType = cast(getSource().getType()); + auto resultType = cast(getResult().getType()); + return verifyFloatUnaryVRegOp(getOperation(), sourceType, resultType); +} + +LogicalResult VMIAbsFOp::verify() { + auto sourceType = cast(getSource().getType()); + auto resultType = cast(getResult().getType()); + return verifyFloatUnaryVRegOp(getOperation(), sourceType, resultType); +} + +LogicalResult VMIAbsIOp::verify() { + auto sourceType = cast(getSource().getType()); + auto resultType = cast(getResult().getType()); + if (!isVMIIntegerLikeType(sourceType.getElementType())) + return emitOpError("requires integer-like VMI element type"); + return verifyAllSameVRegShapeAndLayout(getOperation(), + {sourceType, resultType}, + /*requireSameElement=*/true); +} + +LogicalResult VMISqrtOp::verify() { + auto sourceType = cast(getSource().getType()); + auto resultType = cast(getResult().getType()); + return verifyFloatUnaryVRegOp(getOperation(), sourceType, resultType); +} + +LogicalResult VMIExpOp::verify() { + auto sourceType = cast(getSource().getType()); + auto resultType = cast(getResult().getType()); + return verifyFloatUnaryVRegOp(getOperation(), 
sourceType, resultType); +} + +LogicalResult VMILnOp::verify() { + auto sourceType = cast(getSource().getType()); + auto resultType = cast(getResult().getType()); + return verifyFloatUnaryVRegOp(getOperation(), sourceType, resultType); +} + +LogicalResult VMIReluOp::verify() { + auto sourceType = cast(getSource().getType()); + auto resultType = cast(getResult().getType()); + return verifyFloatUnaryVRegOp(getOperation(), sourceType, resultType); +} + +LogicalResult VMIAndIOp::verify() { + auto lhsType = cast(getLhs().getType()); + auto rhsType = cast(getRhs().getType()); + auto resultType = cast(getResult().getType()); + if (!isVMIIntegerLikeType(lhsType.getElementType())) + return emitOpError("requires integer-like VMI element type"); + return verifyElementwiseVRegOp(getOperation(), lhsType, rhsType, resultType); +} + +LogicalResult VMIOrIOp::verify() { + auto lhsType = cast(getLhs().getType()); + auto rhsType = cast(getRhs().getType()); + auto resultType = cast(getResult().getType()); + if (!isVMIIntegerLikeType(lhsType.getElementType())) + return emitOpError("requires integer-like VMI element type"); + return verifyElementwiseVRegOp(getOperation(), lhsType, rhsType, resultType); +} + +LogicalResult VMIXOrIOp::verify() { + auto lhsType = cast(getLhs().getType()); + auto rhsType = cast(getRhs().getType()); + auto resultType = cast(getResult().getType()); + if (!isVMIIntegerLikeType(lhsType.getElementType())) + return emitOpError("requires integer-like VMI element type"); + return verifyElementwiseVRegOp(getOperation(), lhsType, rhsType, resultType); +} + +LogicalResult VMIShLIOp::verify() { + auto lhsType = cast(getLhs().getType()); + auto rhsType = cast(getRhs().getType()); + auto resultType = cast(getResult().getType()); + if (!isVMIIntegerLikeType(lhsType.getElementType())) + return emitOpError("requires integer-like VMI element type"); + return verifyElementwiseVRegOp(getOperation(), lhsType, rhsType, resultType); +} + +LogicalResult VMIShRUIOp::verify() { + 
auto lhsType = cast(getLhs().getType()); + auto rhsType = cast(getRhs().getType()); + auto resultType = cast(getResult().getType()); + auto integerType = dyn_cast(lhsType.getElementType()); + if (!integerType || integerType.isSigned()) + return emitOpError( + "requires signless or unsigned integer VMI element type"); + return verifyElementwiseVRegOp(getOperation(), lhsType, rhsType, resultType); +} + +LogicalResult VMINotOp::verify() { + auto sourceType = cast(getSource().getType()); + auto resultType = cast(getResult().getType()); + if (!isVMIIntegerLikeType(sourceType.getElementType())) + return emitOpError("requires integer-like VMI element type"); + return verifyAllSameVRegShapeAndLayout(getOperation(), + {sourceType, resultType}, + /*requireSameElement=*/true); +} + +LogicalResult VMICmpFOp::verify() { + auto lhsType = cast(getLhs().getType()); + auto rhsType = cast(getRhs().getType()); + auto resultType = cast(getResult().getType()); + if (!isVMIFloatLikeType(lhsType.getElementType())) + return emitOpError("requires floating-point-like VMI element type"); + if (failed(verifyAllSameVRegShapeAndLayout(getOperation(), {lhsType, rhsType}, + /*requireSameElement=*/true))) + return failure(); + return verifyMaskMatchesData(getOperation(), resultType, lhsType); +} + +LogicalResult VMICmpIOp::verify() { + auto lhsType = cast(getLhs().getType()); + auto rhsType = cast(getRhs().getType()); + auto resultType = cast(getResult().getType()); + if (!isVMIIntegerLikeType(lhsType.getElementType())) + return emitOpError("requires integer-like VMI element type"); + if (failed(verifyAllSameVRegShapeAndLayout(getOperation(), {lhsType, rhsType}, + /*requireSameElement=*/true))) + return failure(); + return verifyMaskMatchesData(getOperation(), resultType, lhsType); +} + +LogicalResult VMISelectOp::verify() { + auto maskType = cast(getMask().getType()); + auto trueType = cast(getTrueValue().getType()); + auto falseType = cast(getFalseValue().getType()); + auto resultType = 
cast(getResult().getType()); + if (failed(verifyAllSameVRegShapeAndLayout(getOperation(), + {trueType, falseType, resultType}, + /*requireSameElement=*/true))) + return failure(); + return verifyMaskMatchesData(getOperation(), maskType, resultType); +} + +LogicalResult VMIActivePrefixIndexOp::verify() { + auto maskType = cast(getMask().getType()); + auto resultType = cast(getResult().getType()); + auto resultIntType = dyn_cast(resultType.getElementType()); + if (!resultIntType || !resultIntType.isSignless()) + return emitOpError("requires signless integer result element type"); + unsigned resultWidth = resultIntType.getWidth(); + if (resultWidth != 8 && resultWidth != 16 && resultWidth != 32) + return emitOpError("requires i8, i16, or i32 result element type"); + return verifyMaskMatchesData(getOperation(), maskType, resultType); +} + +LogicalResult VMICompressOp::verify() { + auto sourceType = cast(getSource().getType()); + auto maskType = cast(getMask().getType()); + auto resultType = cast(getResult().getType()); + if (failed(verifyAllSameVRegShapeAndLayout(getOperation(), + {sourceType, resultType}, + /*requireSameElement=*/true))) + return failure(); + return verifyMaskMatchesData(getOperation(), maskType, sourceType); +} + +LogicalResult VMICompressStoreOp::verify() { + auto valueType = cast(getValue().getType()); + auto maskType = cast(getMask().getType()); + if (failed(verifyMemoryElementMatches(getOperation(), + getDestination().getType(), valueType, + "destination"))) + return failure(); + return verifyMaskMatchesData(getOperation(), maskType, valueType); +} + +void VMICompressStoreOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +LogicalResult VMIReduceAddIOp::verify() { + auto sourceType = cast(getSource().getType()); + auto initType = cast(getInit().getType()); + auto maskType = cast(getMask().getType()); + auto resultType = cast(getResult().getType()); + if 
(!isVMIIntegerLikeType(sourceType.getElementType())) + return emitOpError("requires integer-like VMI source element type"); + if (sourceType.getElementType() != initType.getElementType() || + sourceType.getElementType() != resultType.getElementType()) + return emitOpError( + "requires source, init, and result element types to match"); + if (initType.getElementCount() != 1 || resultType.getElementCount() != 1) + return emitOpError("requires init and result to be 1-lane VMI vectors"); + if (failed(verifyAllSameVRegShapeAndLayout(getOperation(), + {initType, resultType}, + /*requireSameElement=*/true))) + return failure(); + return verifyMaskMatchesData(getOperation(), maskType, sourceType); +} + +LogicalResult VMIReduceAddFOp::verify() { + auto sourceType = cast(getSource().getType()); + auto initType = cast(getInit().getType()); + auto maskType = cast(getMask().getType()); + auto resultType = cast(getResult().getType()); + if (!getOperation()->hasAttr("reassoc")) + return emitOpError( + "requires reassoc attr because VPTO vcadd performs pair-wise " + "floating-point reduction"); + if (!isVMIFloatLikeType(sourceType.getElementType())) + return emitOpError("requires floating-point-like VMI source element type"); + if (sourceType.getElementType() != initType.getElementType() || + sourceType.getElementType() != resultType.getElementType()) + return emitOpError( + "requires source, init, and result element types to match"); + if (initType.getElementCount() != 1 || resultType.getElementCount() != 1) + return emitOpError("requires init and result to be 1-lane VMI vectors"); + if (failed(verifyAllSameVRegShapeAndLayout(getOperation(), + {initType, resultType}, + /*requireSameElement=*/true))) + return failure(); + return verifyMaskMatchesData(getOperation(), maskType, sourceType); +} + +template LogicalResult verifyReduceMinMaxFOp(OpTy op) { + auto sourceType = cast(op.getSource().getType()); + auto initType = cast(op.getInit().getType()); + auto maskType = 
cast(op.getMask().getType()); + auto resultType = cast(op.getResult().getType()); + if (!isVMIFloatLikeType(sourceType.getElementType())) + return op.emitOpError( + "requires floating-point-like VMI source element type"); + if (sourceType.getElementType() != initType.getElementType() || + sourceType.getElementType() != resultType.getElementType()) + return op.emitOpError( + "requires source, init, and result element types to match"); + if (initType.getElementCount() != 1 || resultType.getElementCount() != 1) + return op.emitOpError("requires init and result to be 1-lane VMI vectors"); + if (failed(verifyAllSameVRegShapeAndLayout(op.getOperation(), + {initType, resultType}, + /*requireSameElement=*/true))) + return failure(); + return verifyMaskMatchesData(op.getOperation(), maskType, sourceType); +} + +LogicalResult VMIReduceMaxFOp::verify() { return verifyReduceMinMaxFOp(*this); } + +LogicalResult VMIReduceMinFOp::verify() { return verifyReduceMinMaxFOp(*this); } + +template +static LogicalResult verifyGroupReduceFloatOp(OpTy op, bool requiresReassoc) { + auto sourceType = cast(op.getSource().getType()); + auto maskType = cast(op.getMask().getType()); + auto resultType = cast(op.getResult().getType()); + if (requiresReassoc && !op->hasAttr("reassoc")) + return op.emitOpError( + "requires reassoc attr because grouped lowering uses pair-wise " + "floating-point reductions"); + if (!isVMIFloatLikeType(sourceType.getElementType())) + return op.emitOpError( + "requires floating-point-like VMI source element type"); + if (resultType.getElementCount() != op.getNumGroupsAttr().getInt()) + return op.emitOpError( + "requires result logical lane count to match num_groups"); + if (sourceType.getElementType() != resultType.getElementType()) + return op.emitOpError("requires source and result element types to match"); + if (auto sourceLayout = sourceType.getLayoutAttr()) { + bool supportedSourceLayout = + sourceLayout.isContiguous() || + (sourceLayout.isDeinterleaved() && 
sourceLayout.getFactor() == 2 && + (sourceLayout.getBlockElems() == 1 || + sourceLayout.getBlockElems() == 8)) || + (sourceLayout.isDeinterleaved() && sourceLayout.getFactor() == 4 && + (sourceLayout.getBlockElems() == 1 || + sourceLayout.getBlockElems() == 8)); + if (!supportedSourceLayout) + return op.emitOpError( + "requires layout-assigned source to use contiguous layout or " + "deinterleaved=2/4 layout with block_elems=1 or block_elems=8"); + } + if (auto resultLayout = resultType.getLayoutAttr()) { + if (!resultLayout.isGroupSlots() || + resultLayout.getNumGroups() != op.getNumGroupsAttr().getInt()) + return op.emitOpError() << "requires layout-assigned result to use " + "#pto.vmi.layout"; + } + if (failed(verifyMaskMatchesData(op.getOperation(), maskType, sourceType))) + return failure(); + return verifyNumGroups(op.getOperation(), sourceType, + op.getNumGroupsAttr().getInt()); +} + +LogicalResult VMIGroupReduceAddFOp::verify() { + return verifyGroupReduceFloatOp(*this, /*requiresReassoc=*/true); +} + +LogicalResult VMIGroupReduceMaxFOp::verify() { + return verifyGroupReduceFloatOp(*this, /*requiresReassoc=*/false); +} + +template +static LogicalResult verifyGroupReduceIntegerOp(OpTy op) { + auto sourceType = cast(op.getSource().getType()); + auto maskType = cast(op.getMask().getType()); + auto resultType = cast(op.getResult().getType()); + if (!isVMIIntegerLikeType(sourceType.getElementType())) + return op.emitOpError("requires integer-like VMI source element type"); + auto intType = dyn_cast(sourceType.getElementType()); + if (!intType || intType.getWidth() != 32) + return op.emitOpError( + "requires i32 accumulator element type; cast i8/i16 storage to i32 " + "before grouped reduction because integer reduction widens narrow " + "inputs"); + if (resultType.getElementCount() != op.getNumGroupsAttr().getInt()) + return op.emitOpError( + "requires result logical lane count to match num_groups"); + if (sourceType.getElementType() != 
resultType.getElementType()) + return op.emitOpError("requires source and result element types to match"); + if (auto sourceLayout = sourceType.getLayoutAttr()) { + bool supportedSourceLayout = + sourceLayout.isContiguous() || + (sourceLayout.isDeinterleaved() && sourceLayout.getFactor() == 2 && + (sourceLayout.getBlockElems() == 1 || + sourceLayout.getBlockElems() == 8)) || + (sourceLayout.isDeinterleaved() && sourceLayout.getFactor() == 4 && + (sourceLayout.getBlockElems() == 1 || + sourceLayout.getBlockElems() == 8)); + if (!supportedSourceLayout) + return op.emitOpError( + "requires layout-assigned source to use contiguous layout or " + "deinterleaved=2/4 layout with block_elems=1 or block_elems=8"); + } + if (auto resultLayout = resultType.getLayoutAttr()) { + if (!resultLayout.isGroupSlots() || + resultLayout.getNumGroups() != op.getNumGroupsAttr().getInt()) + return op.emitOpError() << "requires layout-assigned result to use " + "#pto.vmi.layout"; + } + if (failed(verifyMaskMatchesData(op.getOperation(), maskType, sourceType))) + return failure(); + return verifyNumGroups(op.getOperation(), sourceType, + op.getNumGroupsAttr().getInt()); +} + +LogicalResult VMIGroupReduceAddIOp::verify() { + return verifyGroupReduceIntegerOp(*this); +} + +LogicalResult VMIGroupReduceMaxIOp::verify() { + return verifyGroupReduceIntegerOp(*this); +} + +LogicalResult VMIGroupBroadcastOp::verify() { + auto sourceType = cast(getSource().getType()); + auto resultType = cast(getResult().getType()); + int64_t numGroups = getNumGroupsAttr().getInt(); + if (sourceType.getElementCount() != numGroups) + return emitOpError( + "requires source logical lane count to match num_groups"); + if (resultType.getElementCount() % numGroups != 0) + return emitOpError( + "requires num_groups to evenly divide result logical lane count"); + if (sourceType.getElementType() != resultType.getElementType()) + return emitOpError("requires source and result element types to match"); + if (auto sourceLayout = 
sourceType.getLayoutAttr()) { + if (!sourceLayout.isGroupSlots() || + sourceLayout.getNumGroups() != numGroups) + return emitOpError() << "requires layout-assigned source to use " + "#pto.vmi.layout"; + } + if (auto resultLayout = resultType.getLayoutAttr()) { + if (resultLayout.isGroupSlots()) + return emitOpError( + "requires layout-assigned result to use a dense VMI layout"); + } + return verifyNumGroups(getOperation(), resultType, numGroups); +} + +template static LogicalResult verifyVMIHistogramOp(OpTy op) { + auto accType = cast(op.getAcc().getType()); + auto sourceType = cast(op.getSource().getType()); + auto maskType = cast(op.getMask().getType()); + auto resultType = cast(op.getResult().getType()); + + auto accElemType = dyn_cast(accType.getElementType()); + auto sourceElemType = dyn_cast(sourceType.getElementType()); + if (!accElemType || !accElemType.isUnsigned() || + accElemType.getWidth() != 16 || accType.getElementCount() != 256) + return op.emitOpError("requires acc type to be " + "!pto.vmi.vreg<256xui16>"); + if (resultType != accType) + return op.emitOpError("requires result type to match acc type"); + if (!sourceElemType || !sourceElemType.isUnsigned() || + sourceElemType.getWidth() != 8) + return op.emitOpError("requires source type to be " + "!pto.vmi.vreg"); + if (maskType.getElementCount() != sourceType.getElementCount()) + return op.emitOpError("requires mask logical lane count to match source"); + + if (auto accLayout = accType.getLayoutAttr()) { + if (!accLayout.isContiguous()) + return op.emitOpError("requires layout-assigned acc to use contiguous " + "layout"); + } + if (auto sourceLayout = sourceType.getLayoutAttr()) { + if (!sourceLayout.isContiguous()) + return op.emitOpError("requires layout-assigned source to use contiguous " + "layout"); + } + if (auto resultLayout = resultType.getLayoutAttr()) { + if (!resultLayout.isContiguous()) + return op.emitOpError("requires layout-assigned result to use " + "contiguous layout"); + } + if 
(auto maskLayout = maskType.getLayoutAttr()) { + if (!maskLayout.isContiguous()) + return op.emitOpError("requires layout-assigned mask to use contiguous " + "layout"); + if (maskType.getGranularity() != "b8") + return op.emitOpError("requires layout-assigned mask granularity b8"); + } + return success(); +} + +LogicalResult VMIDhistOp::verify() { return verifyVMIHistogramOp(*this); } + +LogicalResult VMIChistOp::verify() { return verifyVMIHistogramOp(*this); } + +LogicalResult VMIExtFOp::verify() { + auto sourceType = cast(getSource().getType()); + auto resultType = cast(getResult().getType()); + if (sourceType.getElementCount() != resultType.getElementCount()) + return emitOpError( + "requires source and result logical lane counts to match"); + if (!isVMIFloatLikeType(sourceType.getElementType()) || + !isVMIFloatLikeType(resultType.getElementType())) + return emitOpError( + "requires floating-point-like source and result element types"); + if (getVMIElementBitWidth(sourceType.getElementType()) >= + getVMIElementBitWidth(resultType.getElementType())) + return emitOpError( + "requires result element type to be wider than source element type"); + return success(); +} + +LogicalResult VMITruncFOp::verify() { + auto sourceType = cast(getSource().getType()); + auto resultType = cast(getResult().getType()); + if (sourceType.getElementCount() != resultType.getElementCount()) + return emitOpError( + "requires source and result logical lane counts to match"); + if (!isVMIFloatLikeType(sourceType.getElementType()) || + !isVMIFloatLikeType(resultType.getElementType())) + return emitOpError( + "requires floating-point-like source and result element types"); + if (getVMIElementBitWidth(sourceType.getElementType()) <= + getVMIElementBitWidth(resultType.getElementType())) + return emitOpError( + "requires result element type to be narrower than source element type"); + if (auto roundingAttr = (*this)->getAttrOfType("rounding")) { + StringRef rounding = roundingAttr.getValue(); + 
if (rounding != "A" && rounding != "H") + return emitOpError("rounding attr must be A or H"); + if (!sourceType.getElementType().isF32() || + !pto::isPTOHiFloat8Type(resultType.getElementType())) + return emitOpError( + "rounding attr is currently only supported for f32 to !pto.hif8 " + "truncf"); + } + return success(); +} + +LogicalResult VMIFPToSIOp::verify() { + auto sourceType = cast(getSource().getType()); + auto resultType = cast(getResult().getType()); + if (sourceType.getElementCount() != resultType.getElementCount()) + return emitOpError( + "requires source and result logical lane counts to match"); + if (!isVMIFloatLikeType(sourceType.getElementType())) + return emitOpError("requires floating-point-like source element type"); + if (!isVMISignedOrSignlessIntegerType(resultType.getElementType())) + return emitOpError("requires signed or signless integer result element " + "type"); + if (getVMIElementBitWidth(resultType.getElementType()) != 32) + return emitOpError("requires 32-bit integer result element type"); + return success(); +} + +LogicalResult VMISIToFPOp::verify() { + auto sourceType = cast(getSource().getType()); + auto resultType = cast(getResult().getType()); + if (sourceType.getElementCount() != resultType.getElementCount()) + return emitOpError( + "requires source and result logical lane counts to match"); + if (!isVMISignedOrSignlessIntegerType(sourceType.getElementType())) + return emitOpError( + "requires signed or signless integer source element type"); + if (!isVMIFloatLikeType(resultType.getElementType())) + return emitOpError("requires floating-point-like result element type"); + if (getVMIElementBitWidth(sourceType.getElementType()) != 32) + return emitOpError("requires 32-bit integer source element type"); + if (!resultType.getElementType().isF32()) + return emitOpError("requires f32 result element type"); + return success(); +} + +LogicalResult VMIExtSIOp::verify() { + auto sourceType = cast(getSource().getType()); + auto resultType 
= cast(getResult().getType()); + if (sourceType.getElementCount() != resultType.getElementCount()) + return emitOpError( + "requires source and result logical lane counts to match"); + if (!isVMISignedOrSignlessIntegerType(sourceType.getElementType()) || + !isVMISignedOrSignlessIntegerType(resultType.getElementType())) + return emitOpError( + "requires signed or signless integer source and result element types"); + if (getVMIElementBitWidth(sourceType.getElementType()) >= + getVMIElementBitWidth(resultType.getElementType())) + return emitOpError( + "requires result element type to be wider than source element type"); + return success(); +} + +LogicalResult VMIExtUIOp::verify() { + auto sourceType = cast(getSource().getType()); + auto resultType = cast(getResult().getType()); + if (sourceType.getElementCount() != resultType.getElementCount()) + return emitOpError( + "requires source and result logical lane counts to match"); + if (!isVMIUnsignedIntegerType(sourceType.getElementType()) || + !isVMIUnsignedIntegerType(resultType.getElementType())) + return emitOpError( + "requires unsigned integer source and result element types"); + if (getVMIElementBitWidth(sourceType.getElementType()) >= + getVMIElementBitWidth(resultType.getElementType())) + return emitOpError( + "requires result element type to be wider than source element type"); + return success(); +} + +LogicalResult VMITruncIOp::verify() { + auto sourceType = cast(getSource().getType()); + auto resultType = cast(getResult().getType()); + if (sourceType.getElementCount() != resultType.getElementCount()) + return emitOpError( + "requires source and result logical lane counts to match"); + if (!isVMIIntegerLikeType(sourceType.getElementType()) || + !isVMIIntegerLikeType(resultType.getElementType())) + return emitOpError("requires integer source and result element types"); + if (getVMIElementBitWidth(sourceType.getElementType()) <= + getVMIElementBitWidth(resultType.getElementType())) + return emitOpError( + 
"requires result element type to be narrower than source element type"); + return success(); +} + +LogicalResult VMIBitcastOp::verify() { + auto sourceType = cast(getSource().getType()); + auto resultType = cast(getResult().getType()); + std::optional sourceBits = + getVMIIntegerOrFloatBitWidth(sourceType.getElementType()); + std::optional resultBits = + getVMIIntegerOrFloatBitWidth(resultType.getElementType()); + if (!sourceBits || !resultBits) + return emitOpError( + "requires integer or floating-point source and result element types"); + if (sourceType.getElementCount() * static_cast(*sourceBits) != + resultType.getElementCount() * static_cast(*resultBits)) + return emitOpError( + "requires source and result to carry the same total number of bits"); + + if (isLayoutAssigned(sourceType) || isLayoutAssigned(resultType)) { + if (!isLayoutAssigned(sourceType) || !isLayoutAssigned(resultType)) + return emitOpError( + "requires either both source and result to carry layout or neither " + "to carry layout"); + if (sourceType.getLayout() != resultType.getLayout()) + return emitOpError("requires source and result layouts to match"); + } + + return success(); +} + +LogicalResult VMILoadOp::verify() { + return verifyMemoryElementMatches(getOperation(), getSource().getType(), + cast(getResult().getType()), + "source"); +} + +void VMILoadOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); +} + +LogicalResult VMIDeinterleaveLoadOp::verify() { + auto lowType = cast(getLow().getType()); + auto highType = cast(getHigh().getType()); + if (failed(verifyAllSameVRegShapeAndLayout(getOperation(), + {lowType, highType}, + /*requireSameElement=*/true))) + return failure(); + if (failed(verifyMemoryElementMatches(getOperation(), getSource().getType(), + lowType, "source"))) + return failure(); + if (failed(verifyContiguousIfLayoutAssigned(getOperation(), lowType, + "low result")) || + 
failed(verifyContiguousIfLayoutAssigned(getOperation(), highType, + "high result"))) + return failure(); + return success(); +} + +void VMIDeinterleaveLoadOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); +} + +LogicalResult VMIGroupLoadOp::verify() { + auto resultType = cast(getResult().getType()); + if (failed(verifyMemoryElementMatches(getOperation(), getSource().getType(), + resultType, "source"))) + return failure(); + return verifyNumGroups(getOperation(), resultType, + getNumGroupsAttr().getInt()); +} + +void VMIGroupLoadOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); +} + +LogicalResult VMIGroupSlotLoadOp::verify() { + auto resultType = cast(getResult().getType()); + int64_t numGroups = getNumGroupsAttr().getInt(); + if (resultType.getElementCount() != numGroups) + return emitOpError( + "requires result logical lane count to match num_groups"); + if (failed(verifyMemoryElementMatches(getOperation(), getSource().getType(), + resultType, "source"))) + return failure(); + if (auto resultLayout = resultType.getLayoutAttr()) { + if (!resultLayout.isGroupSlots() || + resultLayout.getNumGroups() != numGroups) + return emitOpError() << "requires layout-assigned result to use " + "#pto.vmi.layout"; + } + return verifyNumGroups(getOperation(), resultType, numGroups); +} + +void VMIGroupSlotLoadOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); +} + +LogicalResult VMIGroupBroadcastLoadOp::verify() { + auto resultType = cast(getResult().getType()); + int64_t numGroups = getNumGroupsAttr().getInt(); + if (numGroups <= 0) + return emitOpError("requires num_groups to be positive"); + if (resultType.getElementCount() % numGroups != 0) + return emitOpError( + "requires num_groups to evenly divide result logical lane count"); + if 
(failed(verifyMemoryElementMatches(getOperation(), getSource().getType(), + resultType, "source"))) + return failure(); + if (auto resultLayout = resultType.getLayoutAttr()) { + if (resultLayout.isGroupSlots()) + return emitOpError( + "requires layout-assigned result to use a dense VMI layout"); + } + return verifyNumGroups(getOperation(), resultType, numGroups); +} + +void VMIGroupBroadcastLoadOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); +} + +LogicalResult VMIMaskedLoadOp::verify() { + auto maskType = cast(getMask().getType()); + auto passthruType = cast(getPassthru().getType()); + auto resultType = cast(getResult().getType()); + if (failed(verifyMemoryElementMatches(getOperation(), getSource().getType(), + resultType, "source"))) + return failure(); + if (failed(verifyAllSameVRegShapeAndLayout(getOperation(), + {passthruType, resultType}, + /*requireSameElement=*/true))) + return failure(); + return verifyMaskMatchesData(getOperation(), maskType, resultType); +} + +void VMIMaskedLoadOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); +} + +LogicalResult VMIGatherOp::verify() { + auto indicesType = cast(getIndices().getType()); + auto maskType = cast(getMask().getType()); + auto passthruType = cast(getPassthru().getType()); + auto resultType = cast(getResult().getType()); + if (failed(verifyMemoryElementMatches(getOperation(), getSource().getType(), + resultType, "source"))) + return failure(); + + auto indexElementType = dyn_cast(indicesType.getElementType()); + if (!indexElementType || indexElementType.isSigned() || + (indexElementType.getWidth() != 16 && indexElementType.getWidth() != 32)) + return emitOpError( + "requires signless or unsigned 16-bit or 32-bit integer indices"); + + if (failed(verifyAllSameVRegShapeAndLayout( + getOperation(), {indicesType, passthruType, resultType}, + 
/*requireSameElement=*/false))) + return failure(); + if (failed(verifyAllSameVRegShapeAndLayout(getOperation(), + {passthruType, resultType}, + /*requireSameElement=*/true))) + return failure(); + + auto resultIntegerType = dyn_cast(resultType.getElementType()); + if (indexElementType.getWidth() == 16 && + (!resultIntegerType || !resultIntegerType.isUnsigned() || + resultIntegerType.getWidth() != 16)) + return emitOpError( + "requires ui16 result and passthru element type when using ui16 " + "indices"); + return verifyMaskMatchesData(getOperation(), maskType, resultType); +} + +void VMIGatherOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); +} + +LogicalResult VMIExpandLoadOp::verify() { + auto maskType = cast(getMask().getType()); + auto passthruType = cast(getPassthru().getType()); + auto resultType = cast(getResult().getType()); + if (failed(verifyMemoryElementMatches(getOperation(), getSource().getType(), + resultType, "source"))) + return failure(); + if (failed(verifyAllSameVRegShapeAndLayout(getOperation(), + {passthruType, resultType}, + /*requireSameElement=*/true))) + return failure(); + return verifyMaskMatchesData(getOperation(), maskType, resultType); +} + +void VMIExpandLoadOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); +} + +LogicalResult VMIStoreOp::verify() { + return verifyMemoryElementMatches(getOperation(), getDestination().getType(), + cast(getValue().getType()), + "destination"); +} + +void VMIStoreOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +LogicalResult VMIInterleaveStoreOp::verify() { + auto lowType = cast(getLow().getType()); + auto highType = cast(getHigh().getType()); + if (failed(verifyAllSameVRegShapeAndLayout(getOperation(), + {lowType, highType}, + /*requireSameElement=*/true))) + return failure(); + 
if (failed(verifyMemoryElementMatches(getOperation(), + getDestination().getType(), lowType, + "destination"))) + return failure(); + if (failed(verifyContiguousIfLayoutAssigned(getOperation(), lowType, + "low input")) || + failed(verifyContiguousIfLayoutAssigned(getOperation(), highType, + "high input"))) + return failure(); + return success(); +} + +void VMIInterleaveStoreOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +LogicalResult VMIGroupStoreOp::verify() { + auto valueType = cast(getValue().getType()); + if (!isPackedByteGroupStore(getDestination().getType(), valueType) && + failed(verifyMemoryElementMatches(getOperation(), + getDestination().getType(), valueType, + "destination"))) + return failure(); + return verifyNumGroups(getOperation(), valueType, + getNumGroupsAttr().getInt()); +} + +void VMIGroupStoreOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +LogicalResult VMIStrideLoadOp::verify() { + auto resultType = cast(getResult().getType()); + auto maskType = cast(getMask().getType()); + if (failed(verifyMemoryElementMatches(getOperation(), getSource().getType(), + resultType, "source"))) + return failure(); + return verifyMaskMatchesData(getOperation(), maskType, resultType); +} + +void VMIStrideLoadOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); +} + +LogicalResult VMIMaskedStoreOp::verify() { + auto valueType = cast(getValue().getType()); + auto maskType = cast(getMask().getType()); + if (failed(verifyMemoryElementMatches(getOperation(), + getDestination().getType(), valueType, + "destination"))) + return failure(); + return verifyMaskMatchesData(getOperation(), maskType, valueType); +} + +void VMIMaskedStoreOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Write::get(), 
&getDestinationMutable()); +} + +LogicalResult VMIStrideStoreOp::verify() { + auto valueType = cast(getValue().getType()); + auto maskType = cast(getMask().getType()); + if (failed(verifyMemoryElementMatches(getOperation(), + getDestination().getType(), valueType, + "destination"))) + return failure(); + return verifyMaskMatchesData(getOperation(), maskType, valueType); +} + +void VMIStrideStoreOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +LogicalResult VMIScatterOp::verify() { + auto valueType = cast(getValue().getType()); + auto indicesType = cast(getIndices().getType()); + auto maskType = cast(getMask().getType()); + if (failed(verifyMemoryElementMatches(getOperation(), + getDestination().getType(), valueType, + "destination"))) + return failure(); + + auto indexElementType = dyn_cast(indicesType.getElementType()); + if (!indexElementType || indexElementType.getWidth() != 32 || + indexElementType.isSigned()) + return emitOpError("requires signless or unsigned 32-bit integer indices"); + + if (failed(verifyAllSameVRegShapeAndLayout(getOperation(), + {valueType, indicesType}, + /*requireSameElement=*/false))) + return failure(); + return verifyMaskMatchesData(getOperation(), maskType, valueType); +} + +void VMIScatterOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +LogicalResult VMIShuffleOp::verify() { + auto sourceType = cast(getSource().getType()); + auto resultType = cast(getResult().getType()); + if (sourceType.getElementType() != resultType.getElementType()) + return emitOpError( + "requires result element type to match source element type"); + if (static_cast(getIndices().size()) != resultType.getElementCount()) + return emitOpError( + "requires shuffle index count to match result logical lane count"); + for (int64_t index : getIndices()) { + if (index < 0 || index >= 
sourceType.getElementCount()) + return emitOpError("requires every shuffle index to select an existing " + "source logical lane"); + } + if (isLayoutAssigned(sourceType) || isLayoutAssigned(resultType)) { + if (!isLayoutAssigned(sourceType) || !isLayoutAssigned(resultType)) + return emitOpError("requires either both source and result to carry " + "layout or neither to carry layout"); + } + return success(); +} + +LogicalResult VMIChannelSplitOp::verify() { + auto sourceType = cast(getSource().getType()); + if (getResults().size() < 2) + return emitOpError("requires at least two channel results"); + auto firstResultType = cast(getResults().front().getType()); + if (sourceType.getElementCount() != + static_cast(getResults().size()) * + firstResultType.getElementCount()) + return emitOpError("requires source lane count to equal result count times " + "per-channel lane count"); + for (Value result : getResults()) { + auto resultType = cast(result.getType()); + if (resultType.getElementCount() != firstResultType.getElementCount() || + resultType.getElementType() != sourceType.getElementType()) + return emitOpError("requires every channel result to have equal lane " + "count and source element type"); + } + bool anyLayout = isLayoutAssigned(sourceType); + for (Value result : getResults()) + anyLayout |= isLayoutAssigned(cast(result.getType())); + if (anyLayout) { + if (!isLayoutAssigned(sourceType)) + return emitOpError("requires layout-assigned channel_split source when " + "any channel result has layout"); + for (Value result : getResults()) { + auto resultType = cast(result.getType()); + if (!isLayoutAssigned(resultType)) + return emitOpError("requires every channel_split result to carry " + "layout when source has layout"); + if (!cast(resultType.getLayout()).isContiguous()) + return emitOpError( + "requires layout-assigned channel_split results to be contiguous"); + } + int64_t channels = getResults().size(); + if (channels == 2 || channels == 4) { + auto 
sourceLayout = cast(sourceType.getLayout()); + auto expectedLayout = + VMILayoutAttr::getDeinterleaved(getContext(), channels); + if (!sourceLayout.isContiguous() && sourceLayout != expectedLayout) + return emitOpError("requires layout-assigned channel_split source to " + "be contiguous or deinterleaved by result count"); + } + } + return success(); +} + +LogicalResult VMIChannelMergeOp::verify() { + if (getInputs().size() < 2) + return emitOpError("requires at least two channel inputs"); + auto firstInputType = cast(getInputs().front().getType()); + auto resultType = cast(getResult().getType()); + for (Value input : getInputs()) { + auto inputType = cast(input.getType()); + if (inputType.getElementCount() != firstInputType.getElementCount() || + inputType.getElementType() != firstInputType.getElementType()) + return emitOpError("requires all channel inputs to have the same lane " + "count and element type"); + } + if (resultType.getElementCount() != static_cast(getInputs().size()) * + firstInputType.getElementCount() || + resultType.getElementType() != firstInputType.getElementType()) + return emitOpError( + "requires result lane count and element type to match merged channels"); + bool anyLayout = isLayoutAssigned(resultType); + for (Value input : getInputs()) + anyLayout |= isLayoutAssigned(cast(input.getType())); + if (anyLayout) { + if (!isLayoutAssigned(resultType)) + return emitOpError("requires layout-assigned channel_merge result when " + "any channel input has layout"); + for (Value input : getInputs()) { + auto inputType = cast(input.getType()); + if (!isLayoutAssigned(inputType)) + return emitOpError("requires every channel_merge input to carry layout " + "when result has layout"); + if (!cast(inputType.getLayout()).isContiguous()) + return emitOpError( + "requires layout-assigned channel_merge inputs to be contiguous"); + } + int64_t channels = getInputs().size(); + if (channels == 2 || channels == 4) { + auto resultLayout = 
cast(resultType.getLayout()); + auto expectedLayout = + VMILayoutAttr::getDeinterleaved(getContext(), channels); + if (!resultLayout.isContiguous() && resultLayout != expectedLayout) + return emitOpError("requires layout-assigned channel_merge result to " + "be contiguous or deinterleaved by input count"); + } + } + return success(); +} + +LogicalResult VMIEnsureLayoutOp::verify() { + auto sourceType = cast(getSource().getType()); + auto resultType = cast(getResult().getType()); + if (sourceType.getElementCount() != resultType.getElementCount() || + sourceType.getElementType() != resultType.getElementType()) + return emitOpError("requires source and result to preserve VMI data shape " + "and element type"); + if (!isLayoutAssigned(sourceType) || !isLayoutAssigned(resultType)) + return emitOpError("requires source and result to be layout-assigned"); + return success(); +} + +LogicalResult VMIEnsureMaskLayoutOp::verify() { + auto sourceType = cast(getSource().getType()); + auto resultType = cast(getResult().getType()); + if (sourceType.getElementCount() != resultType.getElementCount() || + sourceType.getGranularity() != resultType.getGranularity()) + return emitOpError("requires source and result to preserve VMI mask shape " + "and granularity"); + if (!isLayoutAssigned(sourceType) || !isLayoutAssigned(resultType)) + return emitOpError("requires source and result to be layout-assigned"); + return success(); +} + +LogicalResult VMIEnsureMaskGranularityOp::verify() { + auto sourceType = cast(getSource().getType()); + auto resultType = cast(getResult().getType()); + if (sourceType.getElementCount() != resultType.getElementCount()) + return emitOpError( + "requires source and result to preserve VMI mask lane count"); + if (!isLayoutAssigned(sourceType) || !isLayoutAssigned(resultType)) + return emitOpError("requires source and result to be layout-assigned"); + if (sourceType.getLayout() != resultType.getLayout()) + return emitOpError("requires source and result mask 
layouts to match"); + if (sourceType.isPred() || resultType.isPred()) + return emitOpError( + "requires concrete source and result mask granularities"); + return success(); +} + +LogicalResult VMIUnpackOp::verify() { + return verifyPhysicalParts(getOperation(), getSource().getType(), + getParts().getTypes()); +} + +LogicalResult VMIPackOp::verify() { + return verifyPhysicalParts(getOperation(), getResult().getType(), + getParts().getTypes()); +} + +FailureOr mlir::pto::getDataLanesPerPart(Type elementType) { + unsigned elementBitWidth = pto::getPTOStorageElemBitWidth(elementType); + if (elementBitWidth == 0) + return failure(); + constexpr int64_t kPhysicalVRegBits = 256 * 8; + if (kPhysicalVRegBits % elementBitWidth != 0) + return failure(); + return kPhysicalVRegBits / elementBitWidth; +} + +FailureOr mlir::pto::getMaskLanesPerPart(StringRef granularity) { + if (granularity == "b8") + return 256; + if (granularity == "b16") + return 128; + if (granularity == "b32") + return 64; + return failure(); +} + +FailureOr mlir::pto::getVMIPhysicalArity(Type type) { + FailureOr elementCount = getVMIElementCount(type); + FailureOr lanesPerPart = getPhysicalLanesPerPart(type); + FailureOr layout = getAssignedVMILayout(type); + if (failed(elementCount) || failed(lanesPerPart) || failed(layout)) + return failure(); + + if ((*layout).isGroupSlots() && (*layout).getSlots() > 0) + return divideCeilNonNegative((*layout).getNumGroups(), + (*layout).getSlots()); + + int64_t factor = (*layout).isDeinterleaved() ? (*layout).getFactor() : 1; + int64_t blockElems = + (*layout).isDeinterleaved() ? (*layout).getBlockElems() : 1; + int64_t laneStride = (*layout).isDense() ? (*layout).getLaneStride() : 1; + int64_t arity = 0; + for (int64_t part = 0; part < factor; ++part) { + int64_t lanesInPart = + getDenseLogicalLanesInPart(*elementCount, factor, blockElems, part); + int64_t requiredPhysicalLanes = + lanesInPart == 0 ? 
0 : (lanesInPart - 1) * laneStride + 1; + arity += divideCeilNonNegative(requiredPhysicalLanes, *lanesPerPart); + } + return arity; +} + +FailureOr +mlir::pto::mapLogicalLaneToPhysical(Type type, int64_t logicalLane) { + FailureOr elementCount = getVMIElementCount(type); + FailureOr factor = getLayoutFactor(type); + FailureOr blockElems = getLayoutBlockElems(type); + FailureOr laneStride = getDenseLaneStride(type); + FailureOr lanesPerPart = getPhysicalLanesPerPart(type); + if (failed(elementCount) || failed(factor) || failed(blockElems) || + failed(laneStride) || failed(lanesPerPart)) + return failure(); + if (logicalLane < 0 || logicalLane >= *elementCount) + return failure(); + + FailureOr layout = getAssignedVMILayout(type); + if (succeeded(layout) && (*layout).isGroupSlots() && + (*layout).getSlots() > 0) { + int64_t slots = (*layout).getSlots(); + int64_t lane = logicalLane % slots; + if (lane >= *lanesPerPart) + return failure(); + return VMIPhysicalLane{/*part=*/0, logicalLane / slots, lane}; + } + + int64_t part = 0; + std::optional indexInPart = mapDenseLogicalLaneToPartIndex( + *elementCount, *factor, *blockElems, logicalLane, part); + if (!indexInPart) + return failure(); + int64_t physicalIndex = *indexInPart * *laneStride; + return VMIPhysicalLane{part, physicalIndex / *lanesPerPart, + physicalIndex % *lanesPerPart}; +} + +FailureOr mlir::pto::mapPhysicalLaneToLogical(Type type, int64_t part, + int64_t chunk, + int64_t lane) { + FailureOr elementCount = getVMIElementCount(type); + FailureOr factor = getLayoutFactor(type); + FailureOr blockElems = getLayoutBlockElems(type); + FailureOr laneStride = getDenseLaneStride(type); + FailureOr lanesPerPart = getPhysicalLanesPerPart(type); + if (failed(elementCount) || failed(factor) || failed(blockElems) || + failed(laneStride) || failed(lanesPerPart)) + return failure(); + if (part < 0 || part >= *factor || chunk < 0 || lane < 0 || + lane >= *lanesPerPart) + return failure(); + + FailureOr layout = 
getAssignedVMILayout(type); + if (succeeded(layout) && (*layout).isGroupSlots() && + (*layout).getSlots() > 0) { + int64_t slots = (*layout).getSlots(); + if (part != 0 || lane >= slots) + return failure(); + int64_t logicalLane = chunk * slots + lane; + if (logicalLane >= *elementCount) + return failure(); + return logicalLane; + } + + int64_t physicalIndexInPart = chunk * *lanesPerPart + lane; + if (physicalIndexInPart % *laneStride != 0) + return failure(); + int64_t indexInPart = physicalIndexInPart / *laneStride; + std::optional logicalLane = mapDensePartIndexToLogicalLane( + *elementCount, *factor, *blockElems, part, indexInPart); + if (!logicalLane) + return failure(); + return *logicalLane; +} + +FailureOr mlir::pto::isPaddingLane(Type type, int64_t part, int64_t chunk, + int64_t lane) { + FailureOr elementCount = getVMIElementCount(type); + FailureOr factor = getLayoutFactor(type); + FailureOr blockElems = getLayoutBlockElems(type); + FailureOr laneStride = getDenseLaneStride(type); + FailureOr lanesPerPart = getPhysicalLanesPerPart(type); + if (failed(elementCount) || failed(factor) || failed(blockElems) || + failed(laneStride) || failed(lanesPerPart)) + return failure(); + if (part < 0 || part >= *factor || chunk < 0 || lane < 0 || + lane >= *lanesPerPart) + return failure(); + + FailureOr layout = getAssignedVMILayout(type); + if (succeeded(layout) && (*layout).isGroupSlots() && + (*layout).getSlots() > 0) { + int64_t slots = (*layout).getSlots(); + if (part != 0) + return true; + if (lane >= slots) + return true; + return chunk * slots + lane >= *elementCount; + } + + int64_t lanesInPart = + getDenseLogicalLanesInPart(*elementCount, *factor, *blockElems, part); + int64_t physicalIndexInPart = chunk * *lanesPerPart + lane; + if (physicalIndexInPart % *laneStride != 0) + return true; + int64_t indexInPart = physicalIndexInPart / *laneStride; + return indexInPart >= lanesInPart; +} diff --git a/lib/PTO/IR/VPTO.cpp b/lib/PTO/IR/VPTO.cpp index 
4d9ae3401d..bcb7a61e21 100644 --- a/lib/PTO/IR/VPTO.cpp +++ b/lib/PTO/IR/VPTO.cpp @@ -1728,6 +1728,10 @@ getVstsMaskGranularityOverride(StringRef dist, Type elementType) { return StringRef("b16"); if (dist == "PK_B32") return StringRef("b32"); + if (dist == "PK_B64") + return StringRef("b32"); + if (dist == "PK4_B32") + return StringRef("b32"); return std::nullopt; } diff --git a/lib/PTO/Transforms/CMakeLists.txt b/lib/PTO/Transforms/CMakeLists.txt index e372c3d711..706e0c3c35 100644 --- a/lib/PTO/Transforms/CMakeLists.txt +++ b/lib/PTO/Transforms/CMakeLists.txt @@ -30,10 +30,20 @@ add_mlir_dialect_library(PTOTransforms VPTOLLVMEmitterHelper.cpp VPTOPtrNormalize.cpp VPTOPtrCastCleanup.cpp + VPTONormalizeEquivalentVcvt.cpp VPTOExpandWrapperOps.cpp PTOVPTOPtrBoundary.cpp VPTOBufferMaterialization.cpp PTOValidateVPTOIR.cpp + PTOValidateVMIIR.cpp + VMIPreAssignmentCombine.cpp + VMILegalizeArithSelect.cpp + VMILayoutAssignment.cpp + VMILayoutFold.cpp + VMILayoutSupport.cpp + VMILayoutRematerialize.cpp + VMILayoutSinkMaterialization.cpp + VMIToVPTO.cpp PTOInferVPTOVecScope.cpp InsertSync/PTOInsertSync.cpp diff --git a/lib/PTO/Transforms/PTOValidateVMIIR.cpp b/lib/PTO/Transforms/PTOValidateVMIIR.cpp new file mode 100644 index 0000000000..16a6b24393 --- /dev/null +++ b/lib/PTO/Transforms/PTOValidateVMIIR.cpp @@ -0,0 +1,773 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. 
+ +//===- PTOValidateVMIIR.cpp - VMI boundary verifier ----------------------===// +//===----------------------------------------------------------------------===// + +#include "PTO/IR/PTO.h" +#include "PTO/IR/VMIUtils.h" +#include "PTO/Transforms/Passes.h" +#include "PTO/Transforms/VMILayoutSupport.h" +#include "PTO/Transforms/VMITargetCapabilities.h" + +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/Pass/Pass.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringSet.h" +#include "llvm/Support/raw_ostream.h" + +namespace mlir { +namespace pto { +#define GEN_PASS_DEF_PTOVALIDATEVMIIR +#define GEN_PASS_DEF_PTOVALIDATEVMILAYOUTIR +#include "PTO/Transforms/Passes.h.inc" +} // namespace pto +} // namespace mlir + +using namespace mlir; +using namespace mlir::pto; + +namespace { + +bool isVMIType(Type type) { return isa(type); } + +bool isPhysicalVPTOType(Type type) { + return isa(type); +} + +bool containsVMIOrPhysicalType(Type type) { + if (isVMIType(type) || isPhysicalVPTOType(type)) + return true; + + if (auto functionType = dyn_cast(type)) { + return llvm::any_of( + functionType.getInputs(), + [](Type input) { return containsVMIOrPhysicalType(input); }) || + llvm::any_of(functionType.getResults(), [](Type result) { + return containsVMIOrPhysicalType(result); + }); + } + + if (auto shapedType = dyn_cast(type)) + return containsVMIOrPhysicalType(shapedType.getElementType()); + + return false; +} + +bool containsVMIOrPhysicalType(Attribute attr) { + if (!attr) + return false; + + if (auto typeAttr = dyn_cast(attr)) + if (containsVMIOrPhysicalType(typeAttr.getValue())) + return true; + + if (auto typedAttr = dyn_cast(attr)) + if (containsVMIOrPhysicalType(typedAttr.getType())) + return true; + + if (auto arrayAttr = dyn_cast(attr)) + return 
llvm::any_of(arrayAttr, [](Attribute element) { + return containsVMIOrPhysicalType(element); + }); + + if (auto dictAttr = dyn_cast(attr)) + return llvm::any_of(dictAttr, [](NamedAttribute namedAttr) { + return containsVMIOrPhysicalType(namedAttr.getValue()); + }); + + return false; +} + +bool isSurfaceVMIType(Type type) { + if (auto vregType = dyn_cast(type)) + return !vregType.getLayout(); + if (auto maskType = dyn_cast(type)) + return maskType.isPred() && !maskType.getLayout(); + return false; +} + +bool isLayoutAssignedVMIType(Type type) { + if (auto vregType = dyn_cast(type)) + return static_cast(vregType.getLayoutAttr()); + if (auto maskType = dyn_cast(type)) + return maskType.getLayoutAttr() && + VMIMaskType::isConcreteGranularity(maskType.getGranularity()); + return false; +} + +bool isVMIHelperOp(Operation *op) { + StringRef name = op->getName().getStringRef(); + return name == "pto.vmi.ensure_layout" || + name == "pto.vmi.ensure_mask_layout" || + name == "pto.vmi.ensure_mask_granularity" || name == "pto.vmi.pack" || + name == "pto.vmi.unpack"; +} + +bool isVMILayoutHelperOp(Operation *op) { + StringRef name = op->getName().getStringRef(); + return name == "pto.vmi.ensure_layout" || + name == "pto.vmi.ensure_mask_layout" || + name == "pto.vmi.ensure_mask_granularity"; +} + +bool isVMISemanticOp(Operation *op) { + StringRef name = op->getName().getStringRef(); + return name.starts_with("pto.vmi.") && !isVMIHelperOp(op); +} + +bool isStructuralOp(Operation *op) { + StringRef name = op->getName().getStringRef(); + return name == "builtin.module" || name.starts_with("func.") || + name.starts_with("scf.") || name.starts_with("cf."); +} + +bool hasVMIOrPhysicalType(Operation *op) { + auto hasInterestingType = [](Type type) { + return isVMIType(type) || isPhysicalVPTOType(type); + }; + if (llvm::any_of(op->getOperandTypes(), hasInterestingType) || + llvm::any_of(op->getResultTypes(), hasInterestingType)) + return true; + for (Region ®ion : op->getRegions()) { + 
for (Block &block : region) { + if (llvm::any_of(block.getArgumentTypes(), hasInterestingType)) + return true; + } + } + return false; +} + +void mirrorDiagnostic(llvm::raw_ostream *diagOS, Twine message) { + if (diagOS) + *diagOS << message << "\n"; +} + +LogicalResult emitInvariant(Operation *op, llvm::raw_ostream *diagOS, + Twine message) { + InFlightDiagnostic diag = op->emitError() + << kVMIDiagPassInvariantPrefix << message; + (void)diag; + mirrorDiagnostic(diagOS, Twine(kVMIDiagPassInvariantPrefix) + message); + return failure(); +} + +LogicalResult emitLayoutContract(Operation *op, llvm::raw_ostream *diagOS, + Twine message) { + InFlightDiagnostic diag = op->emitError() + << kVMIDiagLayoutContractPrefix << message; + (void)diag; + mirrorDiagnostic(diagOS, Twine(kVMIDiagLayoutContractPrefix) + message); + return failure(); +} + +LogicalResult emitLayoutSupportContract(Operation *op, + llvm::raw_ostream *diagOS, + Twine message, StringRef reason) { + std::string text; + llvm::raw_string_ostream os(text); + os << message << ": " << reason; + + bool printedAny = false; + auto printValueType = [&](StringRef kind, int64_t index, Type type) { + if (!isVMIType(type)) + return; + if (!printedAny) { + os << "; VMI types:"; + printedAny = true; + } + os << " " << kind << "#" << index << "=" << type; + }; + + for (auto [index, operand] : llvm::enumerate(op->getOperands())) + printValueType("operand", static_cast(index), operand.getType()); + for (auto [index, result] : llvm::enumerate(op->getResults())) + printValueType("result", static_cast(index), result.getType()); + + os.flush(); + return emitLayoutContract(op, diagOS, text); +} + +LogicalResult +emitHelperMaterializationContract(Operation *helper, Type sourceType, + Type resultType, StringRef helperName, + StringRef reason, llvm::raw_ostream *diagOS) { + auto emitFallback = [&]() { + return emitLayoutContract( + helper, diagOS, + Twine(helperName) + + " has no registered materialization support: " + reason); + }; 
+ + if (helper->getNumResults() != 1 || !helper->getResult(0).hasOneUse()) + return emitFallback(); + + OpOperand &use = *helper->getResult(0).use_begin(); + Operation *requester = use.getOwner(); + std::string message; + llvm::raw_string_ostream os(message); + os << requester->getName() << " operand #" << use.getOperandNumber() + << " has type " << sourceType << " but requires " << resultType << "; " + << helperName << " has no registered materialization support: " << reason; + os.flush(); + + InFlightDiagnostic diag = requester->emitError() + << kVMIDiagLayoutContractPrefix << message; + diag.attachNote(helper->getLoc()) + << "failed helper conversion " << sourceType << " -> " << resultType + << " (" << reason << ")"; + mirrorDiagnostic(diagOS, Twine(kVMIDiagLayoutContractPrefix) + message); + return failure(); +} + +LogicalResult verifyBoundaryType(Operation *owner, Type type, + llvm::raw_ostream *diagOS) { + if (isPhysicalVPTOType(type)) + return emitInvariant( + owner, diagOS, + "physical VPTO register type appears before VMI-to-VPTO"); + + if (isVMIType(type) && !isSurfaceVMIType(type)) + return emitInvariant( + owner, diagOS, + "VMI producer boundary requires surface !pto.vmi.vreg or " + "!pto.vmi.mask type"); + + return success(); +} + +LogicalResult verifyBoundaryTypeTree(Operation *owner, Type type, + llvm::raw_ostream *diagOS) { + if (failed(verifyBoundaryType(owner, type, diagOS))) + return failure(); + + if (auto functionType = dyn_cast(type)) { + for (Type input : functionType.getInputs()) + if (failed(verifyBoundaryTypeTree(owner, input, diagOS))) + return failure(); + for (Type result : functionType.getResults()) + if (failed(verifyBoundaryTypeTree(owner, result, diagOS))) + return failure(); + } + + if (auto shapedType = dyn_cast(type)) + return verifyBoundaryTypeTree(owner, shapedType.getElementType(), diagOS); + + return success(); +} + +LogicalResult verifyLayoutAssignedType(Operation *owner, Type type, + llvm::raw_ostream *diagOS) { + if 
(isPhysicalVPTOType(type)) + return emitInvariant( + owner, diagOS, + "physical VPTO register type appears before VMI-to-VPTO"); + + if (isVMIType(type) && !isLayoutAssignedVMIType(type)) + return emitInvariant( + owner, diagOS, + "layout-assigned VMI IR requires !pto.vmi.vreg with layout and " + "!pto.vmi.mask with b8/b16/b32 granularity plus layout"); + + return success(); +} + +LogicalResult verifyLayoutAssignedTypeTree(Operation *owner, Type type, + llvm::raw_ostream *diagOS) { + if (failed(verifyLayoutAssignedType(owner, type, diagOS))) + return failure(); + + if (auto functionType = dyn_cast(type)) { + for (Type input : functionType.getInputs()) + if (failed(verifyLayoutAssignedTypeTree(owner, input, diagOS))) + return failure(); + for (Type result : functionType.getResults()) + if (failed(verifyLayoutAssignedTypeTree(owner, result, diagOS))) + return failure(); + } + + if (auto shapedType = dyn_cast(type)) + return verifyLayoutAssignedTypeTree(owner, shapedType.getElementType(), + diagOS); + + return success(); +} + +template +LogicalResult verifyAttributeTypes(Operation *owner, Attribute attr, + llvm::raw_ostream *diagOS, + TypeVerifier verifyType) { + if (!attr) + return success(); + + if (auto typeAttr = dyn_cast(attr)) + if (failed(verifyType(owner, typeAttr.getValue(), diagOS))) + return failure(); + + if (auto typedAttr = dyn_cast(attr)) + if (failed(verifyType(owner, typedAttr.getType(), diagOS))) + return failure(); + + if (auto arrayAttr = dyn_cast(attr)) { + for (Attribute element : arrayAttr) + if (failed(verifyAttributeTypes(owner, element, diagOS, verifyType))) + return failure(); + } + + if (auto dictAttr = dyn_cast(attr)) { + for (NamedAttribute namedAttr : dictAttr) + if (failed(verifyAttributeTypes(owner, namedAttr.getValue(), diagOS, + verifyType))) + return failure(); + } + + return success(); +} + +bool isFunctionTypeAttr(Operation *op, NamedAttribute attr) { + return isa(op) && attr.getName() == "function_type"; +} + +LogicalResult 
verifyNoHiddenVMIAttributeType(Operation *op, NamedAttribute attr, + llvm::raw_ostream *diagOS) { + if (isFunctionTypeAttr(op, attr)) + return success(); + if (containsVMIOrPhysicalType(attr.getValue())) + return emitInvariant( + op, diagOS, + "VMI or physical VPTO type appears in a non-signature attribute"); + return success(); +} + +LogicalResult verifyOperationTypes(Operation *op, llvm::raw_ostream *diagOS) { + if (auto funcOp = dyn_cast(op)) { + FunctionType functionType = funcOp.getFunctionType(); + for (Type type : functionType.getInputs()) + if (failed(verifyBoundaryTypeTree(op, type, diagOS))) + return failure(); + for (Type type : functionType.getResults()) + if (failed(verifyBoundaryTypeTree(op, type, diagOS))) + return failure(); + } + + for (Type type : op->getOperandTypes()) + if (failed(verifyBoundaryTypeTree(op, type, diagOS))) + return failure(); + for (Type type : op->getResultTypes()) + if (failed(verifyBoundaryTypeTree(op, type, diagOS))) + return failure(); + for (Region ®ion : op->getRegions()) { + for (Block &block : region) { + for (Type type : block.getArgumentTypes()) { + if (failed(verifyBoundaryTypeTree(op, type, diagOS))) + return failure(); + } + } + } + for (NamedAttribute attr : op->getAttrs()) { + if (failed(verifyNoHiddenVMIAttributeType(op, attr, diagOS))) + return failure(); + if (failed(verifyAttributeTypes(op, attr.getValue(), diagOS, + verifyBoundaryTypeTree))) + return failure(); + } + return success(); +} + +LogicalResult verifyLayoutAssignedOperationTypes(Operation *op, + llvm::raw_ostream *diagOS) { + if (auto funcOp = dyn_cast(op)) { + FunctionType functionType = funcOp.getFunctionType(); + for (Type type : functionType.getInputs()) + if (failed(verifyLayoutAssignedTypeTree(op, type, diagOS))) + return failure(); + for (Type type : functionType.getResults()) + if (failed(verifyLayoutAssignedTypeTree(op, type, diagOS))) + return failure(); + } + + for (Type type : op->getOperandTypes()) + if 
(failed(verifyLayoutAssignedTypeTree(op, type, diagOS))) + return failure(); + for (Type type : op->getResultTypes()) + if (failed(verifyLayoutAssignedTypeTree(op, type, diagOS))) + return failure(); + for (Region ®ion : op->getRegions()) { + for (Block &block : region) { + for (Type type : block.getArgumentTypes()) { + if (failed(verifyLayoutAssignedTypeTree(op, type, diagOS))) + return failure(); + } + } + } + for (NamedAttribute attr : op->getAttrs()) { + if (failed(verifyNoHiddenVMIAttributeType(op, attr, diagOS))) + return failure(); + if (failed(verifyAttributeTypes(op, attr.getValue(), diagOS, + verifyLayoutAssignedTypeTree))) + return failure(); + } + return success(); +} + +LogicalResult verifyLayoutHelperSupport(Operation *op, + llvm::raw_ostream *diagOS); + +LogicalResult verifyLayoutSemanticSupport(Operation *op, + llvm::raw_ostream *diagOS); + +LogicalResult verifyOperationBoundary(Operation *op, + llvm::raw_ostream *diagOS) { + if (failed(verifyOperationTypes(op, diagOS))) + return failure(); + + if (!hasVMIOrPhysicalType(op)) + return success(); + + if (isVMIHelperOp(op)) + return emitInvariant( + op, diagOS, + "VMI helper op appears before layout assignment or VMI-to-VPTO"); + + if (isVMISemanticOp(op) || isStructuralOp(op)) + return success(); + + return emitInvariant(op, diagOS, + "VMI typed value is used by a non-VMI semantic op"); +} + +LogicalResult verifyLayoutAssignedOperation(Operation *op, + llvm::raw_ostream *diagOS, + bool verifyHelperSupports = true) { + if (failed(verifyLayoutAssignedOperationTypes(op, diagOS))) + return failure(); + + if (!hasVMIOrPhysicalType(op)) + return success(); + + if (isVMIHelperOp(op)) { + if (isVMILayoutHelperOp(op)) + return verifyHelperSupports ? 
verifyLayoutHelperSupport(op, diagOS) + : success(); + return emitInvariant( + op, diagOS, + "VMI pack/unpack helper appears before VMI-to-VPTO physicalization"); + } + + if (isVMISemanticOp(op)) + return verifyLayoutSemanticSupport(op, diagOS); + if (isStructuralOp(op)) + return success(); + + return emitInvariant(op, diagOS, + "VMI typed value is used by a non-VMI semantic op"); +} + +LogicalResult verifyLayoutHelperSupport(Operation *op, + llvm::raw_ostream *diagOS) { + VMILayoutSupport supports; + + if (auto ensure = dyn_cast(op)) { + auto sourceType = cast(ensure.getSource().getType()); + auto resultType = cast(ensure.getResult().getType()); + std::string reason; + if (failed( + supports.canMaterializeDataLayout(sourceType, resultType, &reason))) + return emitHelperMaterializationContract( + op, sourceType, resultType, "pto.vmi.ensure_layout", reason, diagOS); + return success(); + } + + if (auto ensure = dyn_cast(op)) { + auto sourceType = cast(ensure.getSource().getType()); + auto resultType = cast(ensure.getResult().getType()); + std::string reason; + if (failed( + supports.canMaterializeMaskLayout(sourceType, resultType, &reason))) + return emitHelperMaterializationContract(op, sourceType, resultType, + "pto.vmi.ensure_mask_layout", + reason, diagOS); + return success(); + } + + if (auto ensure = dyn_cast(op)) { + auto sourceType = cast(ensure.getSource().getType()); + auto resultType = cast(ensure.getResult().getType()); + std::string reason; + if (failed(supports.canMaterializeMaskGranularity(sourceType, resultType, + &reason))) + return emitLayoutContract( + op, diagOS, + Twine("pto.vmi.ensure_mask_granularity has no registered " + "materialization support: ") + + reason); + return success(); + } + + return success(); +} + +LogicalResult verifyLayoutSemanticSupport(Operation *op, + llvm::raw_ostream *diagOS) { + VMILayoutSupport supports; + VMITargetCapabilityRegistry capabilities; + + if (auto store = dyn_cast(op)) { + auto valueType = 
cast(store.getValue().getType()); + VMILayoutAttr layout = valueType.getLayoutAttr(); + if (!layout || layout.isContiguous()) + return success(); + + std::string reason; + if (failed(supports.getContiguousStoreSupport(valueType, &reason))) + return emitLayoutSupportContract( + op, diagOS, + "pto.vmi.store has no registered contiguous-memory layout support", + reason); + return success(); + } + + if (auto load = dyn_cast(op)) { + auto resultType = cast(load.getResult().getType()); + VMILayoutAttr layout = resultType.getLayoutAttr(); + if (!layout || !layout.isDeinterleaved() || layout.getBlockElems() != 8 || + !resultType.getElementType().isF32()) + return success(); + + std::string reason; + if (failed(supports.getGroupLoadSupport(capabilities, load, &reason))) + return emitLayoutSupportContract( + op, diagOS, + "pto.vmi.group_load has no registered block8 layout support", reason); + return success(); + } + + if (auto load = dyn_cast(op)) { + std::string reason; + if (failed(supports.getGroupSlotLoadSupport(capabilities, load, &reason))) + return emitLayoutSupportContract( + op, diagOS, + "pto.vmi.group_slot_load has no registered layout support", reason); + return success(); + } + + if (auto load = dyn_cast(op)) { + std::string reason; + if (failed( + supports.getGroupBroadcastLoadSupport(capabilities, load, &reason))) + return emitLayoutSupportContract( + op, diagOS, + "pto.vmi.group_broadcast_load has no registered layout support", + reason); + return success(); + } + + if (auto store = dyn_cast(op)) { + auto valueType = cast(store.getValue().getType()); + VMILayoutAttr layout = valueType.getLayoutAttr(); + if (!layout || !layout.isGroupSlots()) + return success(); + + std::string reason; + if (failed( + supports.getGroupSlotsStoreSupport(capabilities, store, &reason))) + return emitLayoutSupportContract( + op, diagOS, + "pto.vmi.group_store has no registered group_slots layout support", + reason); + return success(); + } + + if (auto reduce = dyn_cast(op)) { + 
auto resultType = cast(reduce.getResult().getType()); + VMILayoutAttr layout = resultType.getLayoutAttr(); + if (!layout || !layout.isGroupSlots()) + return success(); + + std::string reason; + if (failed( + supports.getGroupReduceAddFSupport(capabilities, reduce, &reason))) + return emitLayoutSupportContract( + op, diagOS, + "pto.vmi.group_reduce_addf has no registered group_slots layout " + "support", + reason); + return success(); + } + + if (auto reduce = dyn_cast(op)) { + auto resultType = cast(reduce.getResult().getType()); + VMILayoutAttr layout = resultType.getLayoutAttr(); + if (!layout || !layout.isGroupSlots()) + return success(); + + std::string reason; + if (failed( + supports.getGroupReduceMaxFSupport(capabilities, reduce, &reason))) + return emitLayoutSupportContract( + op, diagOS, + "pto.vmi.group_reduce_maxf has no registered group_slots layout " + "support", + reason); + return success(); + } + + if (auto reduce = dyn_cast(op)) { + auto resultType = cast(reduce.getResult().getType()); + VMILayoutAttr layout = resultType.getLayoutAttr(); + if (!layout || !layout.isGroupSlots()) + return success(); + + std::string reason; + if (failed( + supports.getGroupReduceAddISupport(capabilities, reduce, &reason))) + return emitLayoutSupportContract( + op, diagOS, + "pto.vmi.group_reduce_addi has no registered group_slots layout " + "support", + reason); + return success(); + } + + if (auto reduce = dyn_cast(op)) { + auto resultType = cast(reduce.getResult().getType()); + VMILayoutAttr layout = resultType.getLayoutAttr(); + if (!layout || !layout.isGroupSlots()) + return success(); + + std::string reason; + if (failed( + supports.getGroupReduceMaxISupport(capabilities, reduce, &reason))) + return emitLayoutSupportContract( + op, diagOS, + "pto.vmi.group_reduce_maxi has no registered group_slots layout " + "support", + reason); + return success(); + } + + if (auto broadcast = dyn_cast(op)) { + auto sourceType = cast(broadcast.getSource().getType()); + 
VMILayoutAttr layout = sourceType.getLayoutAttr(); + if (!layout || !layout.isGroupSlots() || layout.getSlots() <= 0) + return success(); + + std::string reason; + if (failed(supports.getGroupBroadcastSupport(capabilities, broadcast, + &reason))) + return emitLayoutSupportContract( + op, diagOS, + "pto.vmi.group_broadcast has no registered layout support", reason); + return success(); + } + + if (auto hist = dyn_cast(op)) { + std::string reason; + if (failed(supports.getDhistSupport(hist, &reason))) + return emitLayoutSupportContract( + op, diagOS, "pto.vmi.dhist has no registered histogram support", + reason); + return success(); + } + + if (auto hist = dyn_cast(op)) { + std::string reason; + if (failed(supports.getChistSupport(hist, &reason))) + return emitLayoutSupportContract( + op, diagOS, "pto.vmi.chist has no registered histogram support", + reason); + return success(); + } + + if (auto truncf = dyn_cast(op)) { + std::string reason; + if (failed(supports.getTruncFSupport(truncf, &reason))) + return emitLayoutSupportContract( + op, diagOS, "pto.vmi.truncf has no registered layout support", + reason); + return success(); + } + + if (auto extf = dyn_cast(op)) { + std::string reason; + if (failed(supports.getExtFSupport(extf, &reason))) + return emitLayoutSupportContract( + op, diagOS, "pto.vmi.extf has no registered layout support", reason); + return success(); + } + + if (auto bitcast = dyn_cast(op)) { + std::string reason; + if (failed(supports.getBitcastSupport(bitcast, &reason))) + return emitLayoutSupportContract( + op, diagOS, "pto.vmi.bitcast has no registered layout support", + reason); + return success(); + } + + return success(); +} + +struct PTOValidateVMIIRPass + : public mlir::pto::impl::PTOValidateVMIIRBase { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PTOValidateVMIIRPass) + + void runOnOperation() override { + if (failed(validateVMIProducerBoundaryIR(getOperation(), &llvm::errs()))) + signalPassFailure(); + } +}; + +struct 
PTOValidateVMILayoutIRPass + : public mlir::pto::impl::PTOValidateVMILayoutIRBase< + PTOValidateVMILayoutIRPass> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PTOValidateVMILayoutIRPass) + + void runOnOperation() override { + if (failed(validateVMILayoutAssignedIR(getOperation(), &llvm::errs()))) + signalPassFailure(); + } +}; + +} // namespace + +LogicalResult +mlir::pto::validateVMIProducerBoundaryIR(ModuleOp module, + llvm::raw_ostream *diagOS) { + WalkResult result = module.walk([&](Operation *op) { + if (failed(verifyOperationBoundary(op, diagOS))) + return WalkResult::interrupt(); + return WalkResult::advance(); + }); + return failure(result.wasInterrupted()); +} + +LogicalResult mlir::pto::validateVMILayoutAssignedIR( + ModuleOp module, llvm::raw_ostream *diagOS, bool verifyHelperSupports) { + WalkResult result = module.walk([&](Operation *op) { + if (failed(verifyLayoutAssignedOperation(op, diagOS, verifyHelperSupports))) + return WalkResult::interrupt(); + return WalkResult::advance(); + }); + return failure(result.wasInterrupted()); +} + +std::unique_ptr mlir::pto::createPTOValidateVMIIRPass() { + return std::make_unique(); +} + +std::unique_ptr mlir::pto::createPTOValidateVMILayoutIRPass() { + return std::make_unique(); +} diff --git a/lib/PTO/Transforms/PTOValidateVPTOIR.cpp b/lib/PTO/Transforms/PTOValidateVPTOIR.cpp index 30529514f2..d414e4abcb 100644 --- a/lib/PTO/Transforms/PTOValidateVPTOIR.cpp +++ b/lib/PTO/Transforms/PTOValidateVPTOIR.cpp @@ -483,6 +483,16 @@ class VPTOLegalityValidator { return VPTOMaskGranularity::B32; return std::nullopt; } + if (dist == "PK_B64") { + if (width == 32) + return VPTOMaskGranularity::B32; + return std::nullopt; + } + if (dist == "PK4_B32") { + if (width == 8) + return VPTOMaskGranularity::B32; + return std::nullopt; + } if (dist == "MRG4CHN_B8") { if (width == 8) return VPTOMaskGranularity::B32; diff --git a/lib/PTO/Transforms/VMILayoutAssignment.cpp b/lib/PTO/Transforms/VMILayoutAssignment.cpp new file mode 
100644 index 0000000000..736bc28924 --- /dev/null +++ b/lib/PTO/Transforms/VMILayoutAssignment.cpp @@ -0,0 +1,2017 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +//===- VMILayoutAssignment.cpp - Assign VMI layouts ----------------------===// +//===----------------------------------------------------------------------===// + +#include "PTO/IR/PTO.h" +#include "PTO/IR/PTOTypeUtils.h" +#include "PTO/IR/VMIUtils.h" +#include "PTO/Transforms/Passes.h" +#include "PTO/Transforms/VMILayoutSupport.h" +#include "PTO/Transforms/VMITargetCapabilities.h" + +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/IR/Value.h" +#include "mlir/Pass/Pass.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallVector.h" + +namespace mlir { +namespace pto { +#define GEN_PASS_DEF_VMILAYOUTASSIGNMENT +#include "PTO/Transforms/Passes.h.inc" +} // namespace pto +} // namespace mlir + +using namespace mlir; +using namespace mlir::pto; + +namespace { + +struct DataNode { + Value value; + VMIVRegType type; + unsigned parent = 0; + VMILayoutAttr naturalLayout; +}; + +struct MaskNode { + Value value; + VMIMaskType type; + unsigned parent = 0; + VMILayoutAttr requestedLayout; + std::string requestedGranularity; +}; + +struct 
DataUseRequest { + OpOperand *operand; + VMILayoutAttr layout; +}; + +struct MaskUseRequest { + OpOperand *operand; + VMILayoutAttr layout; + std::string granularity; +}; + +static unsigned getElementBitWidth(Type type) { + if (isa(type)) + return 64; + return pto::getPTOStorageElemBitWidth(type); +} + +static StringRef getMaskGranularityForElement(Type elementType) { + switch (getElementBitWidth(elementType)) { + case 8: + return "b8"; + case 16: + return "b16"; + case 32: + return "b32"; + default: + return ""; + } +} + +static std::optional getConstantIndexValue(Value value) { + if (auto constant = value.getDefiningOp()) + return constant.value(); + if (auto constant = value.getDefiningOp()) + if (auto integerAttr = dyn_cast(constant.getValue())) + return integerAttr.getInt(); + return std::nullopt; +} + +static bool isLane0SplatShuffle(VMIShuffleOp op) { + auto sourceType = cast(op.getSource().getType()); + ArrayRef indices = op.getIndices(); + return sourceType.getElementCount() == 1 && !indices.empty() && + llvm::all_of(indices, [](int64_t index) { return index == 0; }); +} + +bool containsVMIType(Type type) { + if (isa(type)) + return true; + if (auto functionType = dyn_cast(type)) { + return llvm::any_of(functionType.getInputs(), + [](Type input) { return containsVMIType(input); }) || + llvm::any_of(functionType.getResults(), + [](Type result) { return containsVMIType(result); }); + } + if (auto shapedType = dyn_cast(type)) + return containsVMIType(shapedType.getElementType()); + return false; +} + +struct LayoutSolver { + explicit LayoutSolver(ModuleOp module, + const VMITargetCapabilityRegistry &capabilities) + : module(module), ctx(module.getContext()), capabilities(capabilities) {} + + unsigned addDataValue(Value value) { + auto type = dyn_cast(value.getType()); + if (!type) + return ~0u; + auto [it, inserted] = dataIds.try_emplace(value, dataNodes.size()); + if (inserted) + dataNodes.push_back( + DataNode{value, type, it->second, 
type.getLayoutAttr()}); + return it->second; + } + + unsigned addMaskValue(Value value) { + auto type = dyn_cast(value.getType()); + if (!type) + return ~0u; + auto [it, inserted] = maskIds.try_emplace(value, maskNodes.size()); + if (inserted) { + std::string granularity; + if (VMIMaskType::isConcreteGranularity(type.getGranularity())) + granularity = type.getGranularity().str(); + maskNodes.push_back( + MaskNode{value, type, it->second, type.getLayoutAttr(), granularity}); + } + return it->second; + } + + unsigned find(unsigned id) { + if (dataNodes[id].parent == id) + return id; + dataNodes[id].parent = find(dataNodes[id].parent); + return dataNodes[id].parent; + } + + unsigned findMask(unsigned id) { + if (maskNodes[id].parent == id) + return id; + maskNodes[id].parent = findMask(maskNodes[id].parent); + return maskNodes[id].parent; + } + + LogicalResult unite(Value lhs, Value rhs, Operation *op) { + unsigned lhsId = addDataValue(lhs); + unsigned rhsId = addDataValue(rhs); + if (lhsId == ~0u || rhsId == ~0u) + return success(); + unsigned lhsRoot = find(lhsId); + unsigned rhsRoot = find(rhsId); + if (lhsRoot == rhsRoot) + return success(); + dataNodes[rhsRoot].parent = lhsRoot; + VMILayoutAttr lhsNatural = dataNodes[lhsRoot].naturalLayout; + VMILayoutAttr rhsNatural = dataNodes[rhsRoot].naturalLayout; + if (lhsNatural && rhsNatural && lhsNatural != rhsNatural) + return op->emitError() + << kVMIDiagLayoutContractPrefix << "conflicting natural layouts " + << lhsNatural << " and " << rhsNatural; + if (!lhsNatural) + dataNodes[lhsRoot].naturalLayout = rhsNatural; + return success(); + } + + LogicalResult uniteMask(Value lhs, Value rhs, Operation *op) { + unsigned lhsId = addMaskValue(lhs); + unsigned rhsId = addMaskValue(rhs); + if (lhsId == ~0u || rhsId == ~0u) + return success(); + unsigned lhsRoot = findMask(lhsId); + unsigned rhsRoot = findMask(rhsId); + if (lhsRoot == rhsRoot) + return success(); + + MaskNode &lhsNode = maskNodes[lhsRoot]; + MaskNode &rhsNode = 
maskNodes[rhsRoot]; + if (lhsNode.requestedLayout && rhsNode.requestedLayout && + lhsNode.requestedLayout != rhsNode.requestedLayout) + return op->emitError() + << kVMIDiagLayoutContractPrefix << "conflicting mask layouts " + << lhsNode.requestedLayout << " and " << rhsNode.requestedLayout; + if (!lhsNode.requestedGranularity.empty() && + !rhsNode.requestedGranularity.empty() && + lhsNode.requestedGranularity != rhsNode.requestedGranularity) + return op->emitError() << kVMIDiagLayoutContractPrefix + << "conflicting mask granularities " + << lhsNode.requestedGranularity << " and " + << rhsNode.requestedGranularity; + + rhsNode.parent = lhsRoot; + if (!lhsNode.requestedLayout) + lhsNode.requestedLayout = rhsNode.requestedLayout; + if (lhsNode.requestedGranularity.empty()) + lhsNode.requestedGranularity = rhsNode.requestedGranularity; + return success(); + } + + LogicalResult setNaturalLayout(Value value, VMILayoutAttr layout, + Operation *op) { + unsigned id = addDataValue(value); + if (id == ~0u || !layout) + return success(); + unsigned root = find(id); + VMILayoutAttr existing = dataNodes[root].naturalLayout; + if (existing && existing != layout) + return op->emitError() + << kVMIDiagLayoutContractPrefix << "conflicting natural layouts " + << existing << " and " << layout; + dataNodes[root].naturalLayout = layout; + return success(); + } + + VMILayoutAttr getContiguousLayout() { + return VMILayoutAttr::getContiguous(ctx); + } + + VMILayoutAttr getGroupSlotsLayout(int64_t numGroups) { + return VMILayoutAttr::getGroupSlots(ctx, numGroups); + } + + std::optional getVLaneElems(Type elementType) { + FailureOr lanesPerPart = getDataLanesPerPart(elementType); + if (failed(lanesPerPart) || *lanesPerPart % 8 != 0) + return std::nullopt; + return *lanesPerPart / 8; + } + + VMILayoutAttr getPreferredGroupSlotsLayout(VMIVRegType type, + int64_t numGroups) { + if (VMILayoutAttr existing = type.getLayoutAttr()) + if (existing.isGroupSlots() && existing.getSlots() > 0) + return 
existing; + VMILayoutSupport supports; + FailureOr fact = + supports.getPreferredGroupReduceLayoutFact(type, numGroups); + if (succeeded(fact)) + return fact->resultLayout; + return getGroupSlotsLayout(numGroups); + } + + VMILayoutAttr getPreferredGroupReduceSourceLayout(VMIVRegType type, + int64_t numGroups) { + if (VMILayoutAttr existing = type.getLayoutAttr()) + return existing; + VMILayoutSupport supports; + FailureOr fact = + supports.getPreferredGroupReduceLayoutFact(type, numGroups); + if (succeeded(fact)) + return fact->sourceLayout; + return getContiguousLayout(); + } + + VMILayoutAttr getPreferredGroupSlotLoadLayout(VMIGroupSlotLoadOp op) { + auto type = cast(op.getResult().getType()); + int64_t numGroups = op.getNumGroupsAttr().getInt(); + if (VMILayoutAttr existing = type.getLayoutAttr()) + if (existing.isGroupSlots() && existing.getSlots() > 0) + return existing; + std::optional sourceGroupStride = + getConstantIndexValue(op.getSourceGroupStride()); + if (sourceGroupStride && *sourceGroupStride == 1) + return VMILayoutAttr::getGroupSlots(ctx, numGroups, /*slots=*/8); + return VMILayoutAttr::getGroupSlots(ctx, numGroups, /*slots=*/1); + } + + bool isE2BGroupBroadcastLoadCandidate(VMIVRegType type, Type sourceType, + Value sourceGroupStride, + int64_t numGroups) { + if (numGroups <= 0 || type.getElementCount() % numGroups != 0) + return false; + int64_t groupSize = type.getElementCount() / numGroups; + if (numGroups % 8 != 0) + return false; + + if (!isa(sourceType)) + return false; + unsigned elementBits = getElementBitWidth(type.getElementType()); + if (elementBits != 16 && elementBits != 32) + return false; + int64_t directGroupSize = 256 / elementBits; + if (groupSize != directGroupSize && groupSize != 2 * directGroupSize) + return false; + std::optional strideValue = + getConstantIndexValue(sourceGroupStride); + if (!strideValue || *strideValue != 1) + return false; + + VMILayoutAttr existing = type.getLayoutAttr(); + if (!existing) + return true; + 
if (groupSize == directGroupSize) + return existing.isContiguous(); + return existing.isDeinterleaved() && existing.getFactor() == 2 && + existing.getBlockElems() == 1; + } + + bool isE2BGroupBroadcastLoadCandidate(VMIGroupBroadcastLoadOp op) { + return isE2BGroupBroadcastLoadCandidate( + cast(op.getResult().getType()), op.getSource().getType(), + op.getSourceGroupStride(), op.getNumGroupsAttr().getInt()); + } + + VMILayoutAttr getPreferredGroupBroadcastLoadLayout( + VMIGroupBroadcastLoadOp op) { + auto type = cast(op.getResult().getType()); + if (VMILayoutAttr existing = type.getLayoutAttr()) + return existing; + + if (!isE2BGroupBroadcastLoadCandidate(op)) + return {}; + int64_t numGroups = op.getNumGroupsAttr().getInt(); + int64_t groupSize = type.getElementCount() / numGroups; + int64_t directGroupSize = 256 / getElementBitWidth(type.getElementType()); + if (groupSize == directGroupSize) + return getContiguousLayout(); + return VMILayoutAttr::getDeinterleaved(ctx, 2, /*blockElems=*/1); + } + + VMILayoutAttr getPreferredGroupBroadcastSourceLayout(Value value, + int64_t numGroups) { + auto type = dyn_cast(value.getType()); + if (!type) + return getContiguousLayout(); + if (VMILayoutAttr existing = type.getLayoutAttr()) + if (existing.isGroupSlots() && existing.getSlots() > 0) + return existing; + VMILayoutAttr solved = getDataLayout(value); + if (solved && solved.isGroupSlots() && solved.getNumGroups() == numGroups && + solved.getSlots() > 0) + return solved; + if (auto load = value.getDefiningOp()) + return getPreferredGroupSlotLoadLayout(load); + return getPreferredGroupSlotsLayout(type, numGroups); + } + + VMILayoutAttr getPreferredGroupLoadResultLayout(VMIGroupLoadOp op) { + auto type = cast(op.getResult().getType()); + if (VMILayoutAttr existing = type.getLayoutAttr()) + return existing; + + int64_t numGroups = op.getNumGroupsAttr().getInt(); + if (numGroups <= 0 || type.getElementCount() % numGroups != 0) + return getContiguousLayout(); + + if 
(!type.getElementType().isF32()) + return getContiguousLayout(); + + int64_t groupSize = type.getElementCount() / numGroups; + std::optional rowStride = getConstantIndexValue(op.getRowStride()); + if (!rowStride || *rowStride <= 0 || *rowStride % 8 != 0) + return getContiguousLayout(); + + if (groupSize == 16) + return VMILayoutAttr::getDeinterleaved(ctx, 2, /*blockElems=*/8); + if (groupSize == 32) + return VMILayoutAttr::getDeinterleaved(ctx, 4, /*blockElems=*/8); + + return getContiguousLayout(); + } + + LogicalResult validateGroupLoadLayoutPlan(VMIGroupLoadOp op) { + auto type = cast(op.getResult().getType()); + if (type.getLayoutAttr()) + return success(); + + int64_t numGroups = op.getNumGroupsAttr().getInt(); + if (numGroups <= 0 || type.getElementCount() % numGroups != 0) + return success(); + if (!type.getElementType().isF32()) + return success(); + + int64_t groupSize = type.getElementCount() / numGroups; + if (groupSize != 16 && groupSize != 32) + return success(); + + std::optional rowStride = getConstantIndexValue(op.getRowStride()); + if (rowStride && *rowStride > 0 && *rowStride % 8 == 0) + return success(); + + return op.emitError() + << kVMIDiagLayoutContractPrefix << "pto.vmi.group_load group_size " + << groupSize + << " requires constant positive row_stride divisible by 8 f32 " + "elements for the block8 stride plan; stable gather fallback is " + "not implemented"; + } + + VMILayoutAttr getPreferredGroupStoreUseLayout(Value value, + int64_t numGroups) { + auto type = dyn_cast(value.getType()); + if (!type) + return getContiguousLayout(); + if (VMILayoutAttr existing = type.getLayoutAttr()) + if (existing.isGroupSlots() && existing.getSlots() > 0) + return existing; + VMILayoutAttr solved = getDataLayout(value); + if (solved && solved.isGroupSlots() && solved.getNumGroups() == numGroups && + solved.getSlots() > 0) + return solved; + if (value.getDefiningOp() || + value.getDefiningOp() || + value.getDefiningOp() || + value.getDefiningOp()) + return 
getPreferredGroupSlotsLayout(type, numGroups); + if (auto load = value.getDefiningOp()) + return getPreferredGroupSlotLoadLayout(load); + return getContiguousLayout(); + } + + VMILayoutAttr getPreferredDenseStoreUseLayout(Value value) { + auto type = dyn_cast(value.getType()); + if (!type) + return getContiguousLayout(); + + VMILayoutAttr layout = getExplicitDataLayout(value); + if (!layout || !layout.hasDenseLaneStride()) + layout = type.getLayoutAttr(); + if (!layout || !layout.hasDenseLaneStride()) + return getContiguousLayout(); + + auto candidateType = + VMIVRegType::get(ctx, type.getElementCount(), type.getElementType(), + layout); + VMILayoutSupport supports; + if (succeeded(supports.getContiguousStoreSupport(candidateType))) + return layout; + return getContiguousLayout(); + } + + VMILayoutAttr getDataLayout(Value value) { + unsigned id = addDataValue(value); + if (id == ~0u) + return {}; + unsigned root = find(id); + if (dataNodes[root].naturalLayout) + return dataNodes[root].naturalLayout; + return getContiguousLayout(); + } + + VMILayoutAttr getExplicitDataLayout(Value value) { + unsigned id = addDataValue(value); + if (id == ~0u) + return {}; + return dataNodes[find(id)].naturalLayout; + } + + bool hasCompatibleTruncFUseForGroupReduce(Value value, int64_t groupSize) { + auto sourceType = dyn_cast(value.getType()); + if (!sourceType || !sourceType.getElementType().isF32()) + return false; + + for (OpOperand &use : value.getUses()) { + auto truncf = dyn_cast(use.getOwner()); + if (!truncf || use.getOperandNumber() != 0) + continue; + + auto resultType = dyn_cast(truncf.getResult().getType()); + if (!resultType) + continue; + unsigned resultBits = getElementBitWidth(resultType.getElementType()); + std::optional vlaneElems = + getVLaneElems(sourceType.getElementType()); + if (vlaneElems && groupSize == 2 * *vlaneElems && resultBits == 16) + return true; + if (vlaneElems && groupSize == 4 * *vlaneElems && resultBits == 8) + return true; + } + return false; + 
} + + bool isCompatibleGroupReduceSourceLayout(VMIGroupReduceLayoutFact fact, + VMILayoutAttr layout) { + if (!layout) + return false; + if (fact.kind == VMIGroupReduceLayoutKind::OneVLane || + fact.kind == VMIGroupReduceLayoutKind::RowLocal) + return layout.isContiguous(); + int64_t factor = fact.kind == VMIGroupReduceLayoutKind::TwoVLane ? 2 : 4; + return layout.isDeinterleaved() && layout.getFactor() == factor && + (layout.getBlockElems() == 1 || layout.getBlockElems() == 8); + } + + VMILayoutAttr + getTruncFCompatibleGroupReduceSourceLayout(VMIGroupReduceLayoutFact fact) { + if (fact.kind == VMIGroupReduceLayoutKind::TwoVLane) + return VMILayoutAttr::getDeinterleaved(ctx, 2, /*blockElems=*/1); + if (fact.kind == VMIGroupReduceLayoutKind::FourVLane) + return VMILayoutAttr::getDeinterleaved(ctx, 4, /*blockElems=*/1); + return {}; + } + + LogicalResult requestMask(Value mask, VMILayoutAttr layout, + StringRef granularity, Operation *op) { + unsigned id = addMaskValue(mask); + if (id == ~0u) + return success(); + if (!layout || granularity.empty()) + return op->emitError() + << kVMIDiagLayoutContractPrefix + << "cannot infer concrete mask layout or granularity"; + MaskNode &node = maskNodes[findMask(id)]; + if (node.requestedLayout && node.requestedLayout != layout) + return op->emitError() + << kVMIDiagLayoutContractPrefix << "conflicting mask layouts " + << node.requestedLayout << " and " << layout; + if (!node.requestedGranularity.empty() && + node.requestedGranularity != granularity) + return op->emitError() + << kVMIDiagLayoutContractPrefix + << "conflicting mask granularities " << node.requestedGranularity + << " and " << granularity; + node.requestedLayout = layout; + node.requestedGranularity = granularity.str(); + return success(); + } + + void requestDataUse(OpOperand &operand, VMILayoutAttr layout) { + if (isa(operand.get().getType())) + dataUseRequests.push_back(DataUseRequest{&operand, layout}); + } + + bool canProducerAdoptConsumerLayout(Operation 
*op) { + if (!op) + return false; + return isa(op); + } + + bool canGroupBroadcastProduceLayout(VMIGroupBroadcastOp broadcast, + VMILayoutAttr resultLayout) { + if (!resultLayout) + return false; + auto sourceType = cast(broadcast.getSource().getType()); + auto resultType = cast(broadcast.getResult().getType()); + int64_t numGroups = broadcast.getNumGroupsAttr().getInt(); + auto assignedSourceType = VMIVRegType::get( + ctx, sourceType.getElementCount(), sourceType.getElementType(), + getPreferredGroupSlotsLayout(sourceType, numGroups)); + auto assignedResultType = + VMIVRegType::get(ctx, resultType.getElementCount(), + resultType.getElementType(), resultLayout); + VMILayoutSupport supports; + return succeeded(supports.getGroupBroadcastSupport( + capabilities, assignedSourceType, assignedResultType, numGroups)); + } + + bool canGroupBroadcastLoadProduceLayout(VMIGroupBroadcastLoadOp load, + VMILayoutAttr resultLayout) { + if (!resultLayout) + return false; + auto resultType = cast(load.getResult().getType()); + int64_t numGroups = load.getNumGroupsAttr().getInt(); + unsigned elementBits = getElementBitWidth(resultType.getElementType()); + if (elementBits == 0 || 256 % elementBits != 0) + return false; + std::optional stride = + getConstantIndexValue(load.getSourceGroupStride()); + int64_t alignedStrideElems = 256 / elementBits; + int64_t slots = 0; + if (stride && *stride == 1) + slots = 8; + else if (stride && *stride > 0 && *stride % alignedStrideElems == 0) + slots = 1; + else + return false; + + auto assignedSourceType = + VMIVRegType::get(ctx, numGroups, resultType.getElementType(), + VMILayoutAttr::getGroupSlots(ctx, numGroups, slots)); + auto assignedResultType = + VMIVRegType::get(ctx, resultType.getElementCount(), + resultType.getElementType(), resultLayout); + VMILayoutSupport supports; + return succeeded(supports.getGroupBroadcastSupport( + capabilities, assignedSourceType, assignedResultType, numGroups)); + } + + bool 
canEquivalenceClassAdoptConsumerLayout(Value value, + VMILayoutAttr requestedLayout) { + unsigned id = addDataValue(value); + if (id == ~0u) + return true; + unsigned root = find(id); + for (DataNode &node : dataNodes) { + if (find(dataIds.lookup(node.value)) != root) + continue; + if (auto broadcast = node.value.getDefiningOp()) { + if (node.value == broadcast.getResult() && + !canGroupBroadcastProduceLayout(broadcast, requestedLayout)) + return false; + } + if (auto load = node.value.getDefiningOp()) { + if (node.value == load.getResult() && + !canGroupBroadcastLoadProduceLayout(load, requestedLayout)) + return false; + } + } + return true; + } + + bool isUnsupportedGroupBroadcastResultForLayout(Value value, + VMILayoutAttr layout) { + auto broadcast = value.getDefiningOp(); + if (broadcast) + return !canGroupBroadcastProduceLayout(broadcast, layout); + auto load = value.getDefiningOp(); + return load && !canGroupBroadcastLoadProduceLayout(load, layout); + } + + LogicalResult constrainElementwiseBinary(OpOperand &lhs, OpOperand &rhs, + Value result, Operation *op) { + VMILayoutAttr lhsLayout = getExplicitDataLayout(lhs.get()); + VMILayoutAttr rhsLayout = getExplicitDataLayout(rhs.get()); + VMILayoutAttr fallback = getContiguousLayout(); + if ((lhsLayout && + isUnsupportedGroupBroadcastResultForLayout(rhs.get(), lhsLayout)) || + (rhsLayout && + isUnsupportedGroupBroadcastResultForLayout(lhs.get(), rhsLayout))) { + requestDataUse(lhs, fallback); + requestDataUse(rhs, fallback); + return setNaturalLayout(result, fallback, op); + } + + if (failed(unite(lhs.get(), rhs.get(), op))) + return failure(); + return unite(lhs.get(), result, op); + } + + bool canAdoptConsumerRequestedLayout(Value value, + VMILayoutAttr requestedLayout) { + Operation *definingOp = value.getDefiningOp(); + if (!definingOp) + return false; + if (isa(definingOp)) { + if (requestedLayout && requestedLayout.hasDenseLaneStride()) { + auto type = dyn_cast(value.getType()); + if (!type) + return 
false; + auto candidateType = + VMIVRegType::get(ctx, type.getElementCount(), type.getElementType(), + requestedLayout); + VMILayoutSupport supports; + if (failed(supports.getContiguousLoadSupport(candidateType))) + return false; + } + } else { + if (!requestedLayout || requestedLayout.isContiguous()) + return false; + if (!canProducerAdoptConsumerLayout(definingOp)) + return false; + } + if (!canEquivalenceClassAdoptConsumerLayout(value, requestedLayout)) + return false; + if (value.hasOneUse()) + return true; + + unsigned matchingRequests = 0; + unsigned totalUses = 0; + for (OpOperand &use : value.getUses()) { + ++totalUses; + bool foundRequest = false; + for (DataUseRequest request : dataUseRequests) { + if (request.operand != &use) + continue; + if (request.layout != requestedLayout) + return false; + foundRequest = true; + } + if (!foundRequest) + return false; + ++matchingRequests; + } + return matchingRequests == totalUses; + } + + LogicalResult applyConsumerDrivenDataLayouts() { + for (DataUseRequest request : dataUseRequests) { + Value value = request.operand->get(); + if (!canAdoptConsumerRequestedLayout(value, request.layout)) + continue; + unsigned id = addDataValue(value); + if (id == ~0u) + continue; + unsigned root = find(id); + VMILayoutAttr existing = dataNodes[root].naturalLayout; + if (existing && existing != request.layout) + continue; + dataNodes[root].naturalLayout = request.layout; + } + return success(); + } + + LogicalResult requestMaskUse(OpOperand &operand, VMILayoutAttr layout, + StringRef granularity, Operation *op) { + if (!isa(operand.get().getType())) + return success(); + if (!layout || granularity.empty()) + return op->emitError() + << kVMIDiagLayoutContractPrefix + << "cannot infer concrete mask use layout or granularity"; + maskUseRequests.push_back( + MaskUseRequest{&operand, layout, granularity.str()}); + return success(); + } + + LogicalResult collect() { + module.walk([&](Operation *op) { + for (Value result : 
op->getResults()) { + addDataValue(result); + addMaskValue(result); + } + for (Region ®ion : op->getRegions()) + for (Block &block : region) + for (BlockArgument arg : block.getArguments()) { + addDataValue(arg); + addMaskValue(arg); + } + }); + return success(); + } + + LogicalResult addConstraints() { + WalkResult result = module.walk([&](Operation *op) -> WalkResult { + if (auto maskAnd = dyn_cast(op)) { + if (failed(uniteMask(maskAnd.getLhs(), maskAnd.getRhs(), op)) || + failed(uniteMask(maskAnd.getLhs(), maskAnd.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto maskOr = dyn_cast(op)) { + if (failed(uniteMask(maskOr.getLhs(), maskOr.getRhs(), op)) || + failed(uniteMask(maskOr.getLhs(), maskOr.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto maskXor = dyn_cast(op)) { + if (failed(uniteMask(maskXor.getLhs(), maskXor.getRhs(), op)) || + failed(uniteMask(maskXor.getLhs(), maskXor.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto maskNot = dyn_cast(op)) { + if (failed(uniteMask(maskNot.getSource(), maskNot.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto addf = dyn_cast(op)) { + if (failed(constrainElementwiseBinary(addf.getLhsMutable(), + addf.getRhsMutable(), + addf.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto addi = dyn_cast(op)) { + if (failed(constrainElementwiseBinary(addi.getLhsMutable(), + addi.getRhsMutable(), + addi.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto subf = dyn_cast(op)) { + if (failed(constrainElementwiseBinary(subf.getLhsMutable(), + subf.getRhsMutable(), + subf.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto subi = dyn_cast(op)) { + if 
(failed(constrainElementwiseBinary(subi.getLhsMutable(), + subi.getRhsMutable(), + subi.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto mulf = dyn_cast(op)) { + if (failed(constrainElementwiseBinary(mulf.getLhsMutable(), + mulf.getRhsMutable(), + mulf.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto muli = dyn_cast(op)) { + if (failed(constrainElementwiseBinary(muli.getLhsMutable(), + muli.getRhsMutable(), + muli.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto fma = dyn_cast(op)) { + if (failed(unite(fma.getLhs(), fma.getRhs(), op)) || + failed(unite(fma.getLhs(), fma.getAcc(), op)) || + failed(unite(fma.getLhs(), fma.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto divf = dyn_cast(op)) { + if (failed(constrainElementwiseBinary(divf.getLhsMutable(), + divf.getRhsMutable(), + divf.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto minf = dyn_cast(op)) { + if (failed(constrainElementwiseBinary(minf.getLhsMutable(), + minf.getRhsMutable(), + minf.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto maxf = dyn_cast(op)) { + if (failed(constrainElementwiseBinary(maxf.getLhsMutable(), + maxf.getRhsMutable(), + maxf.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto negf = dyn_cast(op)) { + if (failed(unite(negf.getSource(), negf.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto absf = dyn_cast(op)) { + if (failed(unite(absf.getSource(), absf.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto absi = dyn_cast(op)) { + if (failed(unite(absi.getSource(), absi.getResult(), op))) + return WalkResult::interrupt(); + return 
WalkResult::advance(); + } + if (auto sqrt = dyn_cast(op)) { + if (failed(unite(sqrt.getSource(), sqrt.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto exp = dyn_cast(op)) { + if (failed(unite(exp.getSource(), exp.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto ln = dyn_cast(op)) { + if (failed(unite(ln.getSource(), ln.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto relu = dyn_cast(op)) { + if (failed(unite(relu.getSource(), relu.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto fptosi = dyn_cast(op)) { + if (failed(unite(fptosi.getSource(), fptosi.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto sitofp = dyn_cast(op)) { + if (failed(unite(sitofp.getSource(), sitofp.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto andi = dyn_cast(op)) { + if (failed(constrainElementwiseBinary(andi.getLhsMutable(), + andi.getRhsMutable(), + andi.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto ori = dyn_cast(op)) { + if (failed(constrainElementwiseBinary( + ori.getLhsMutable(), ori.getRhsMutable(), ori.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto xori = dyn_cast(op)) { + if (failed(constrainElementwiseBinary(xori.getLhsMutable(), + xori.getRhsMutable(), + xori.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto shli = dyn_cast(op)) { + if (failed(constrainElementwiseBinary(shli.getLhsMutable(), + shli.getRhsMutable(), + shli.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto shrui = dyn_cast(op)) { + if (failed(constrainElementwiseBinary(shrui.getLhsMutable(), + 
shrui.getRhsMutable(), + shrui.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto notOp = dyn_cast(op)) { + if (failed(unite(notOp.getSource(), notOp.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto cmpf = dyn_cast(op)) { + if (failed(unite(cmpf.getLhs(), cmpf.getRhs(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto cmpi = dyn_cast(op)) { + if (failed(unite(cmpi.getLhs(), cmpi.getRhs(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto select = dyn_cast(op)) { + if (failed(unite(select.getTrueValue(), select.getFalseValue(), op)) || + failed(unite(select.getTrueValue(), select.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto activePrefix = dyn_cast(op)) { + if (failed(setNaturalLayout(activePrefix.getResult(), + getContiguousLayout(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto compress = dyn_cast(op)) { + requestDataUse(compress.getSourceMutable(), getContiguousLayout()); + if (failed(setNaturalLayout(compress.getResult(), getContiguousLayout(), + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto reduce = dyn_cast(op)) { + requestDataUse(reduce.getSourceMutable(), getContiguousLayout()); + requestDataUse(reduce.getInitMutable(), getContiguousLayout()); + if (failed(setNaturalLayout(reduce.getResult(), getContiguousLayout(), + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto reduce = dyn_cast(op)) { + requestDataUse(reduce.getSourceMutable(), getContiguousLayout()); + requestDataUse(reduce.getInitMutable(), getContiguousLayout()); + if (failed(setNaturalLayout(reduce.getResult(), getContiguousLayout(), + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto reduce = dyn_cast(op)) { + 
requestDataUse(reduce.getSourceMutable(), getContiguousLayout()); + requestDataUse(reduce.getInitMutable(), getContiguousLayout()); + if (failed(setNaturalLayout(reduce.getResult(), getContiguousLayout(), + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto reduce = dyn_cast(op)) { + requestDataUse(reduce.getSourceMutable(), getContiguousLayout()); + requestDataUse(reduce.getInitMutable(), getContiguousLayout()); + if (failed(setNaturalLayout(reduce.getResult(), getContiguousLayout(), + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto reduce = dyn_cast(op)) { + auto sourceType = cast(reduce.getSource().getType()); + auto resultType = cast(reduce.getResult().getType()); + int64_t numGroups = reduce.getNumGroupsAttr().getInt(); + VMILayoutSupport supports; + FailureOr fact = + supports.getPreferredGroupReduceLayoutFact(sourceType, numGroups); + VMILayoutAttr sourceLayout = + getPreferredGroupReduceSourceLayout(sourceType, numGroups); + VMILayoutAttr solvedSourceLayout = + getExplicitDataLayout(reduce.getSource()); + if (solvedSourceLayout && succeeded(fact) && + isCompatibleGroupReduceSourceLayout(*fact, solvedSourceLayout)) { + sourceLayout = solvedSourceLayout; + } else if (!sourceType.getLayoutAttr() && succeeded(fact)) { + if (hasCompatibleTruncFUseForGroupReduce(reduce.getSource(), + fact->groupSize)) { + if (VMILayoutAttr truncLayout = + getTruncFCompatibleGroupReduceSourceLayout(*fact)) + sourceLayout = truncLayout; + } + } + requestDataUse(reduce.getSourceMutable(), sourceLayout); + if (failed(requestMaskUse( + reduce.getMaskMutable(), sourceLayout, + getMaskGranularityForElement(sourceType.getElementType()), op))) + return WalkResult::interrupt(); + if (failed(setNaturalLayout( + reduce.getResult(), + succeeded(fact) + ? 
fact->resultLayout + : getPreferredGroupSlotsLayout(resultType, numGroups), + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto reduce = dyn_cast(op)) { + auto sourceType = cast(reduce.getSource().getType()); + auto resultType = cast(reduce.getResult().getType()); + int64_t numGroups = reduce.getNumGroupsAttr().getInt(); + VMILayoutSupport supports; + FailureOr fact = + supports.getPreferredGroupReduceLayoutFact(sourceType, numGroups); + VMILayoutAttr sourceLayout = + getPreferredGroupReduceSourceLayout(sourceType, numGroups); + VMILayoutAttr solvedSourceLayout = + getExplicitDataLayout(reduce.getSource()); + if (solvedSourceLayout && succeeded(fact) && + isCompatibleGroupReduceSourceLayout(*fact, solvedSourceLayout)) { + sourceLayout = solvedSourceLayout; + } else if (!sourceType.getLayoutAttr() && succeeded(fact)) { + if (hasCompatibleTruncFUseForGroupReduce(reduce.getSource(), + fact->groupSize)) { + if (VMILayoutAttr truncLayout = + getTruncFCompatibleGroupReduceSourceLayout(*fact)) + sourceLayout = truncLayout; + } + } + requestDataUse(reduce.getSourceMutable(), sourceLayout); + if (failed(requestMaskUse( + reduce.getMaskMutable(), sourceLayout, + getMaskGranularityForElement(sourceType.getElementType()), op))) + return WalkResult::interrupt(); + if (failed(setNaturalLayout( + reduce.getResult(), + succeeded(fact) + ? 
fact->resultLayout + : getPreferredGroupSlotsLayout(resultType, numGroups), + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto reduce = dyn_cast(op)) { + auto sourceType = cast(reduce.getSource().getType()); + auto resultType = cast(reduce.getResult().getType()); + int64_t numGroups = reduce.getNumGroupsAttr().getInt(); + VMILayoutSupport supports; + FailureOr fact = + supports.getPreferredGroupReduceLayoutFact(sourceType, numGroups); + VMILayoutAttr sourceLayout = + getPreferredGroupReduceSourceLayout(sourceType, numGroups); + VMILayoutAttr solvedSourceLayout = + getExplicitDataLayout(reduce.getSource()); + if (solvedSourceLayout && succeeded(fact) && + isCompatibleGroupReduceSourceLayout(*fact, solvedSourceLayout)) + sourceLayout = solvedSourceLayout; + requestDataUse(reduce.getSourceMutable(), sourceLayout); + if (failed(requestMaskUse( + reduce.getMaskMutable(), sourceLayout, + getMaskGranularityForElement(sourceType.getElementType()), op))) + return WalkResult::interrupt(); + if (failed(setNaturalLayout( + reduce.getResult(), + succeeded(fact) + ? 
fact->resultLayout + : getPreferredGroupSlotsLayout(resultType, numGroups), + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto reduce = dyn_cast(op)) { + auto sourceType = cast(reduce.getSource().getType()); + auto resultType = cast(reduce.getResult().getType()); + int64_t numGroups = reduce.getNumGroupsAttr().getInt(); + VMILayoutSupport supports; + FailureOr fact = + supports.getPreferredGroupReduceLayoutFact(sourceType, numGroups); + VMILayoutAttr sourceLayout = + getPreferredGroupReduceSourceLayout(sourceType, numGroups); + VMILayoutAttr solvedSourceLayout = + getExplicitDataLayout(reduce.getSource()); + if (solvedSourceLayout && succeeded(fact) && + isCompatibleGroupReduceSourceLayout(*fact, solvedSourceLayout)) + sourceLayout = solvedSourceLayout; + requestDataUse(reduce.getSourceMutable(), sourceLayout); + if (failed(requestMaskUse( + reduce.getMaskMutable(), sourceLayout, + getMaskGranularityForElement(sourceType.getElementType()), op))) + return WalkResult::interrupt(); + if (failed(setNaturalLayout( + reduce.getResult(), + succeeded(fact) + ? 
fact->resultLayout + : getPreferredGroupSlotsLayout(resultType, numGroups), + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto broadcast = dyn_cast(op)) { + requestDataUse( + broadcast.getSourceMutable(), + getPreferredGroupBroadcastSourceLayout( + broadcast.getSource(), broadcast.getNumGroupsAttr().getInt())); + return WalkResult::advance(); + } + if (auto hist = dyn_cast(op)) { + requestDataUse(hist.getAccMutable(), getContiguousLayout()); + requestDataUse(hist.getSourceMutable(), getContiguousLayout()); + if (failed(requestMaskUse(hist.getMaskMutable(), getContiguousLayout(), + "b8", op))) + return WalkResult::interrupt(); + if (failed( + setNaturalLayout(hist.getResult(), getContiguousLayout(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto hist = dyn_cast(op)) { + requestDataUse(hist.getAccMutable(), getContiguousLayout()); + requestDataUse(hist.getSourceMutable(), getContiguousLayout()); + if (failed(requestMaskUse(hist.getMaskMutable(), getContiguousLayout(), + "b8", op))) + return WalkResult::interrupt(); + if (failed( + setNaturalLayout(hist.getResult(), getContiguousLayout(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto extf = dyn_cast(op)) { + auto sourceType = cast(extf.getSource().getType()); + auto resultType = cast(extf.getResult().getType()); + VMILayoutSupport supports; + FailureOr fact = + supports.getPreferredCastLayoutFact(sourceType, resultType); + if (succeeded(fact) && (fact->kind == VMICastLayoutKind::Widen2x || + fact->kind == VMICastLayoutKind::Widen4x)) { + requestDataUse(extf.getSourceMutable(), fact->sourceLayout); + if (failed( + setNaturalLayout(extf.getResult(), fact->resultLayout, op))) + return WalkResult::interrupt(); + } + return WalkResult::advance(); + } + if (auto extsi = dyn_cast(op)) { + auto sourceType = cast(extsi.getSource().getType()); + auto resultType = cast(extsi.getResult().getType()); + 
VMILayoutAttr sourceLayout = getDataLayout(extsi.getSource()); + if (sourceLayout && sourceLayout.isGroupSlots() && + sourceLayout.getSlots() == 8 && + getElementBitWidth(sourceType.getElementType()) < 32 && + getElementBitWidth(resultType.getElementType()) == 32) { + requestDataUse(extsi.getSourceMutable(), sourceLayout); + if (failed(setNaturalLayout(extsi.getResult(), sourceLayout, op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + VMILayoutSupport supports; + FailureOr fact = + supports.getPreferredCastLayoutFact(sourceType, resultType); + if (succeeded(fact) && (fact->kind == VMICastLayoutKind::Widen2x || + fact->kind == VMICastLayoutKind::Widen4x)) { + requestDataUse(extsi.getSourceMutable(), fact->sourceLayout); + if (failed( + setNaturalLayout(extsi.getResult(), fact->resultLayout, op))) + return WalkResult::interrupt(); + } + return WalkResult::advance(); + } + if (auto extui = dyn_cast(op)) { + auto sourceType = cast(extui.getSource().getType()); + auto resultType = cast(extui.getResult().getType()); + VMILayoutAttr sourceLayout = getDataLayout(extui.getSource()); + if (sourceLayout && sourceLayout.isGroupSlots() && + sourceLayout.getSlots() == 8 && + getElementBitWidth(sourceType.getElementType()) < 32 && + getElementBitWidth(resultType.getElementType()) == 32) { + requestDataUse(extui.getSourceMutable(), sourceLayout); + if (failed(setNaturalLayout(extui.getResult(), sourceLayout, op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + VMILayoutSupport supports; + FailureOr fact = + supports.getPreferredCastLayoutFact(sourceType, resultType); + if (succeeded(fact) && (fact->kind == VMICastLayoutKind::Widen2x || + fact->kind == VMICastLayoutKind::Widen4x)) { + requestDataUse(extui.getSourceMutable(), fact->sourceLayout); + if (failed( + setNaturalLayout(extui.getResult(), fact->resultLayout, op))) + return WalkResult::interrupt(); + } + return WalkResult::advance(); + } + if (auto truncf = dyn_cast(op)) { 
+ auto sourceType = cast(truncf.getSource().getType()); + auto resultType = cast(truncf.getResult().getType()); + VMILayoutSupport supports; + FailureOr fact = + supports.getPreferredCastLayoutFact(sourceType, resultType); + VMILayoutAttr sourceLayout = getDataLayout(truncf.getSource()); + if (succeeded(fact) && fact->kind == VMICastLayoutKind::Narrow2x && + sourceLayout && sourceLayout.isGroupSlots() && + sourceLayout.getSlots() == 1) { + requestDataUse(truncf.getSourceMutable(), sourceLayout); + if (failed(setNaturalLayout(truncf.getResult(), sourceLayout, op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + VMILayoutAttr resultLayout = getContiguousLayout(); + if (succeeded(fact) && (fact->kind == VMICastLayoutKind::Narrow2x || + fact->kind == VMICastLayoutKind::Narrow4x)) { + resultLayout = fact->resultLayout; + requestDataUse(truncf.getSourceMutable(), fact->sourceLayout); + } + if (failed(setNaturalLayout(truncf.getResult(), resultLayout, op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto trunci = dyn_cast(op)) { + auto sourceType = cast(trunci.getSource().getType()); + auto resultType = cast(trunci.getResult().getType()); + VMILayoutSupport supports; + FailureOr fact = + supports.getPreferredCastLayoutFact(sourceType, resultType); + VMILayoutAttr sourceLayout = getDataLayout(trunci.getSource()); + if (succeeded(fact) && sourceLayout && sourceLayout.isGroupSlots() && + (sourceLayout.getSlots() == 1 || sourceLayout.getSlots() == 8) && + (fact->kind == VMICastLayoutKind::Narrow2x || + fact->kind == VMICastLayoutKind::Narrow4x)) { + requestDataUse(trunci.getSourceMutable(), sourceLayout); + VMILayoutAttr resultLayout = sourceLayout; + if (sourceLayout.getSlots() == 8 && + fact->kind == VMICastLayoutKind::Narrow4x) + resultLayout = VMILayoutAttr::getGroupSlots( + ctx, sourceLayout.getNumGroups(), sourceLayout.getSlots(), + /*laneStride=*/4); + if (failed(setNaturalLayout(trunci.getResult(), 
resultLayout, op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + VMILayoutAttr resultLayout = getContiguousLayout(); + if (succeeded(fact) && (fact->kind == VMICastLayoutKind::Narrow2x || + fact->kind == VMICastLayoutKind::Narrow4x)) { + resultLayout = fact->resultLayout; + requestDataUse(trunci.getSourceMutable(), fact->sourceLayout); + } + if (failed(setNaturalLayout(trunci.getResult(), resultLayout, op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto bitcast = dyn_cast(op)) { + if (failed(unite(bitcast.getSource(), bitcast.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto load = dyn_cast(op)) { + if (failed(setNaturalLayout(load.getLow(), getContiguousLayout(), op)) || + failed(setNaturalLayout(load.getHigh(), getContiguousLayout(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto load = dyn_cast(op)) { + requestDataUse(load.getPassthruMutable(), getContiguousLayout()); + if (failed( + setNaturalLayout(load.getResult(), getContiguousLayout(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto gather = dyn_cast(op)) { + auto resultType = cast(gather.getResult().getType()); + requestDataUse(gather.getIndicesMutable(), getContiguousLayout()); + requestDataUse(gather.getPassthruMutable(), getContiguousLayout()); + if (failed(requestMaskUse( + gather.getMaskMutable(), getContiguousLayout(), + getMaskGranularityForElement(resultType.getElementType()), op))) + return WalkResult::interrupt(); + if (failed(setNaturalLayout(gather.getResult(), getContiguousLayout(), + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto load = dyn_cast(op)) { + requestDataUse(load.getPassthruMutable(), getContiguousLayout()); + if (failed( + setNaturalLayout(load.getResult(), getContiguousLayout(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } 
+ if (auto load = dyn_cast(op)) { + if (failed(validateGroupLoadLayoutPlan(load))) + return WalkResult::interrupt(); + if (failed(setNaturalLayout( + load.getResult(), getPreferredGroupLoadResultLayout(load), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto load = dyn_cast(op)) { + if (failed(setNaturalLayout( + load.getResult(), getPreferredGroupSlotLoadLayout(load), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto load = dyn_cast(op)) { + if (failed(setNaturalLayout( + load.getResult(), + getPreferredGroupBroadcastLoadLayout(load), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto load = dyn_cast(op)) { + auto resultType = cast(load.getResult().getType()); + if (failed( + setNaturalLayout(load.getResult(), getContiguousLayout(), op))) + return WalkResult::interrupt(); + if (failed(requestMaskUse( + load.getMaskMutable(), getContiguousLayout(), + getMaskGranularityForElement(resultType.getElementType()), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto store = dyn_cast(op)) { + requestDataUse(store.getValueMutable(), + getPreferredDenseStoreUseLayout(store.getValue())); + return WalkResult::advance(); + } + if (auto store = dyn_cast(op)) { + requestDataUse(store.getLowMutable(), getContiguousLayout()); + requestDataUse(store.getHighMutable(), getContiguousLayout()); + return WalkResult::advance(); + } + if (auto store = dyn_cast(op)) { + requestDataUse( + store.getValueMutable(), + getPreferredGroupStoreUseLayout(store.getValue(), + store.getNumGroupsAttr().getInt())); + return WalkResult::advance(); + } + if (auto store = dyn_cast(op)) { + auto valueType = cast(store.getValue().getType()); + requestDataUse(store.getValueMutable(), getContiguousLayout()); + if (failed(requestMaskUse( + store.getMaskMutable(), getContiguousLayout(), + getMaskGranularityForElement(valueType.getElementType()), op))) + return 
WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto store = dyn_cast(op)) { + auto valueType = cast(store.getValue().getType()); + requestDataUse(store.getValueMutable(), getContiguousLayout()); + if (failed(requestMaskUse( + store.getMaskMutable(), getContiguousLayout(), + getMaskGranularityForElement(valueType.getElementType()), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto scatter = dyn_cast(op)) { + auto valueType = cast(scatter.getValue().getType()); + requestDataUse(scatter.getValueMutable(), getContiguousLayout()); + requestDataUse(scatter.getIndicesMutable(), getContiguousLayout()); + if (failed(requestMaskUse( + scatter.getMaskMutable(), getContiguousLayout(), + getMaskGranularityForElement(valueType.getElementType()), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto store = dyn_cast(op)) { + auto valueType = cast(store.getValue().getType()); + requestDataUse(store.getValueMutable(), getContiguousLayout()); + if (failed(requestMaskUse( + store.getMaskMutable(), getContiguousLayout(), + getMaskGranularityForElement(valueType.getElementType()), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto split = dyn_cast(op)) { + int64_t channels = split.getNumResults(); + VMICapabilityResult capability = capabilities.supportsChannelCount( + "pto.vmi.channel_split", channels); + if (!capability.isSupported()) { + split.emitError() << kVMIDiagUnsupportedPrefix << capability.reason; + return WalkResult::interrupt(); + } + requestDataUse(split.getSourceMutable(), + VMILayoutAttr::getDeinterleaved(ctx, channels)); + for (Value result : split.getResults()) + if (failed(setNaturalLayout(result, getContiguousLayout(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto merge = dyn_cast(op)) { + int64_t channels = merge.getInputs().size(); + VMICapabilityResult capability = capabilities.supportsChannelCount( 
+ "pto.vmi.channel_merge", channels); + if (!capability.isSupported()) { + merge.emitError() << kVMIDiagUnsupportedPrefix << capability.reason; + return WalkResult::interrupt(); + } + for (OpOperand &input : merge.getInputsMutable()) + requestDataUse(input, getContiguousLayout()); + if (failed(setNaturalLayout( + merge.getResult(), + VMILayoutAttr::getDeinterleaved(ctx, channels), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto shuffle = dyn_cast(op)) { + auto sourceType = cast(shuffle.getSource().getType()); + auto resultType = cast(shuffle.getResult().getType()); + if (sourceType.hasLayout() || resultType.hasLayout()) + return WalkResult::advance(); + + requestDataUse(shuffle.getSourceMutable(), getContiguousLayout()); + if (isLane0SplatShuffle(shuffle)) + return WalkResult::advance(); + if (failed(setNaturalLayout(shuffle.getResult(), getContiguousLayout(), + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto ifOp = dyn_cast(op)) { + if (failed(addIfConstraints(ifOp))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto executeRegionOp = dyn_cast(op)) { + if (failed(addExecuteRegionConstraints(executeRegionOp))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto indexSwitchOp = dyn_cast(op)) { + if (failed(addIndexSwitchConstraints(indexSwitchOp))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto whileOp = dyn_cast(op)) { + if (failed(addWhileConstraints(whileOp))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto forOp = dyn_cast(op)) { + if (failed(addForConstraints(forOp))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto branchOp = dyn_cast(op)) { + if (failed(addBranchConstraints(branchOp.getDest(), + branchOp.getDestOperands(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto condBranchOp = 
dyn_cast(op)) { + if (failed(addBranchConstraints(condBranchOp.getTrueDest(), + condBranchOp.getTrueDestOperands(), + op)) || + failed(addBranchConstraints(condBranchOp.getFalseDest(), + condBranchOp.getFalseDestOperands(), + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto switchOp = dyn_cast(op)) { + if (failed(addBranchConstraints(switchOp.getDefaultDestination(), + switchOp.getDefaultOperands(), op))) + return WalkResult::interrupt(); + for (auto [dest, operands] : llvm::zip(switchOp.getCaseDestinations(), + switchOp.getCaseOperands())) { + if (failed(addBranchConstraints(dest, operands, op))) + return WalkResult::interrupt(); + } + return WalkResult::advance(); + } + if (auto returnOp = dyn_cast(op)) { + if (failed(addReturnConstraints(returnOp))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto callOp = dyn_cast(op)) { + if (failed(addCallConstraints(callOp))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (op->getName().getStringRef() == "func.call_indirect") { + if (hasVMIValueTypes(op)) { + op->emitError() + << kVMIDiagLayoutContractPrefix + << "VMI typed call requires a direct internal callee with a body"; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + } + if (auto funcOp = dyn_cast(op)) { + if (funcOp.empty() && hasVMIFunctionType(funcOp)) { + funcOp.emitError() + << kVMIDiagLayoutContractPrefix + << "VMI typed function declaration requires an explicit " + "external ABI materialization plan"; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + } + return WalkResult::advance(); + }); + return failure(result.wasInterrupted()); + } + + LogicalResult uniteEquivalentValues(Value lhs, Value rhs, Operation *op) { + if (failed(unite(lhs, rhs, op))) + return failure(); + return uniteMask(lhs, rhs, op); + } + + LogicalResult addIfConstraints(scf::IfOp ifOp) { + for (OpResult result : ifOp->getResults()) { + unsigned 
resultNo = result.getResultNumber(); + for (Region *region : {&ifOp.getThenRegion(), &ifOp.getElseRegion()}) { + if (region->empty()) + continue; + auto yieldOp = dyn_cast(region->front().getTerminator()); + if (!yieldOp || resultNo >= yieldOp.getNumOperands()) + continue; + if (failed(uniteEquivalentValues(result, yieldOp.getOperand(resultNo), + ifOp))) + return failure(); + } + } + return success(); + } + + LogicalResult addYieldConstraints(ResultRange results, scf::YieldOp yieldOp, + Operation *op) { + for (auto [index, result] : llvm::enumerate(results)) { + if (index >= yieldOp.getNumOperands()) + break; + if (failed(uniteEquivalentValues(result, yieldOp.getOperand(index), op))) + return failure(); + } + return success(); + } + + LogicalResult addExecuteRegionConstraints(scf::ExecuteRegionOp executeOp) { + WalkResult result = executeOp.getRegion().walk([&](scf::YieldOp yieldOp) { + if (yieldOp->getParentOp() != executeOp.getOperation()) + return WalkResult::advance(); + if (failed( + addYieldConstraints(executeOp->getResults(), yieldOp, executeOp))) + return WalkResult::interrupt(); + return WalkResult::advance(); + }); + return failure(result.wasInterrupted()); + } + + LogicalResult addIndexSwitchConstraints(scf::IndexSwitchOp indexSwitchOp) { + auto addBlockTerminator = [&](Block &block) -> LogicalResult { + auto yieldOp = dyn_cast(block.getTerminator()); + if (!yieldOp) + return success(); + return addYieldConstraints(indexSwitchOp->getResults(), yieldOp, + indexSwitchOp); + }; + + if (failed(addBlockTerminator(indexSwitchOp.getDefaultBlock()))) + return failure(); + for (unsigned idx = 0, e = indexSwitchOp.getNumCases(); idx < e; ++idx) + if (failed(addBlockTerminator(indexSwitchOp.getCaseBlock(idx)))) + return failure(); + return success(); + } + + LogicalResult addWhileConstraints(scf::WhileOp whileOp) { + auto inits = whileOp.getInits(); + auto beforeArgs = whileOp.getBeforeArguments(); + Block &afterBlock = whileOp.getAfter().front(); + auto 
conditionOp = + dyn_cast(whileOp.getBefore().front().getTerminator()); + auto yieldOp = dyn_cast(afterBlock.getTerminator()); + + for (auto [index, init] : llvm::enumerate(inits)) { + Value anchor = init; + if (index < beforeArgs.size() && + failed(uniteEquivalentValues(anchor, beforeArgs[index], whileOp))) + return failure(); + if (conditionOp && index < conditionOp.getArgs().size() && + failed(uniteEquivalentValues(anchor, conditionOp.getArgs()[index], + whileOp))) + return failure(); + if (index < afterBlock.getNumArguments() && + failed(uniteEquivalentValues(anchor, afterBlock.getArgument(index), + whileOp))) + return failure(); + if (yieldOp && index < yieldOp.getNumOperands() && + failed(uniteEquivalentValues(anchor, yieldOp.getOperand(index), + whileOp))) + return failure(); + if (index < whileOp.getNumResults() && + failed( + uniteEquivalentValues(anchor, whileOp.getResult(index), whileOp))) + return failure(); + } + return success(); + } + + LogicalResult addForConstraints(scf::ForOp forOp) { + auto initArgs = forOp.getInitArgs(); + auto regionIterArgs = forOp.getRegionIterArgs(); + auto results = forOp.getResults(); + scf::YieldOp yieldOp = nullptr; + if (Block *body = forOp.getBody()) + yieldOp = dyn_cast(body->getTerminator()); + + for (auto [index, initArg] : llvm::enumerate(initArgs)) { + Value anchor = initArg; + if (index < regionIterArgs.size() && + failed(uniteEquivalentValues(anchor, regionIterArgs[index], forOp))) + return failure(); + if (index < results.size() && + failed(uniteEquivalentValues(anchor, results[index], forOp))) + return failure(); + if (yieldOp && index < yieldOp.getNumOperands() && + failed( + uniteEquivalentValues(anchor, yieldOp.getOperand(index), forOp))) + return failure(); + } + return success(); + } + + LogicalResult addBranchConstraints(Block *dest, OperandRange operands, + Operation *op) { + if (!dest) + return success(); + for (auto [index, operand] : llvm::enumerate(operands)) { + if (index >= dest->getNumArguments()) 
+ break; + if (failed(uniteEquivalentValues(operand, dest->getArgument(index), op))) + return failure(); + } + return success(); + } + + LogicalResult addReturnConstraints(func::ReturnOp returnOp) { + auto func = returnOp->getParentOfType(); + if (!func) + return success(); + + auto it = firstReturnOperandsByFunc.find(func); + if (it == firstReturnOperandsByFunc.end()) { + SmallVector operands(returnOp.getOperands()); + firstReturnOperandsByFunc.try_emplace(func, std::move(operands)); + return success(); + } + + ArrayRef firstOperands = it->second; + for (auto [index, operand] : llvm::enumerate(returnOp.getOperands())) { + if (index >= firstOperands.size()) + break; + if (failed( + uniteEquivalentValues(firstOperands[index], operand, returnOp))) + return failure(); + } + return success(); + } + + bool hasVMIValueTypes(Operation *op) { + return llvm::any_of(op->getOperandTypes(), containsVMIType) || + llvm::any_of(op->getResultTypes(), containsVMIType); + } + + bool hasVMIFunctionType(func::FuncOp func) { + FunctionType type = func.getFunctionType(); + return llvm::any_of(type.getInputs(), containsVMIType) || + llvm::any_of(type.getResults(), containsVMIType); + } + + LogicalResult addCallConstraints(func::CallOp callOp) { + if (!hasVMIValueTypes(callOp)) + return success(); + + auto callee = SymbolTable::lookupNearestSymbolFrom( + callOp, callOp.getCalleeAttr()); + if (!callee || callee.empty()) + return callOp.emitError() + << kVMIDiagLayoutContractPrefix + << "VMI typed call requires a direct internal callee with a body"; + + for (auto [operand, argument] : + llvm::zip(callOp.getOperands(), callee.getArguments())) { + if (failed(uniteEquivalentValues(operand, argument, callOp))) + return failure(); + } + + SmallVector returns; + callee.walk([&](func::ReturnOp returnOp) { returns.push_back(returnOp); }); + for (func::ReturnOp returnOp : returns) { + for (auto [index, result] : llvm::enumerate(callOp.getResults())) { + if (index >= returnOp.getNumOperands()) + 
break; + if (failed(uniteEquivalentValues(result, returnOp.getOperand(index), + callOp))) + return failure(); + } + } + return success(); + } + + void rewriteDataTypes() { + for (DataNode &node : dataNodes) { + VMILayoutAttr layout = getDataLayout(node.value); + node.value.setType(VMIVRegType::get(ctx, node.type.getElementCount(), + node.type.getElementType(), layout)); + } + } + + LogicalResult insertDataUseMaterializations() { + OpBuilder builder(ctx); + for (DataUseRequest request : dataUseRequests) { + Value value = request.operand->get(); + auto sourceType = dyn_cast(value.getType()); + if (!sourceType) + continue; + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + if (!sourceLayout) + return request.operand->getOwner()->emitError() + << kVMIDiagLayoutContractPrefix + << "data use materialization requires layout-assigned source " + "type"; + if (sourceLayout == request.layout) + continue; + + auto resultType = + VMIVRegType::get(ctx, sourceType.getElementCount(), + sourceType.getElementType(), request.layout); + builder.setInsertionPoint(request.operand->getOwner()); + auto ensure = builder.create( + request.operand->getOwner()->getLoc(), resultType, value); + request.operand->set(ensure.getResult()); + } + return success(); + } + + LogicalResult inferMaskRequests() { + WalkResult result = module.walk([&](Operation *op) -> WalkResult { + if (auto cmpf = dyn_cast(op)) { + auto lhsType = cast(cmpf.getLhs().getType()); + if (failed(requestMask( + cmpf.getResult(), lhsType.getLayoutAttr(), + getMaskGranularityForElement(lhsType.getElementType()), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto cmpi = dyn_cast(op)) { + auto lhsType = cast(cmpi.getLhs().getType()); + if (failed(requestMask( + cmpi.getResult(), lhsType.getLayoutAttr(), + getMaskGranularityForElement(lhsType.getElementType()), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto select = dyn_cast(op)) { + auto 
resultType = cast(select.getResult().getType()); + if (failed(requestMaskUse( + select.getMaskMutable(), resultType.getLayoutAttr(), + getMaskGranularityForElement(resultType.getElementType()), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto activePrefix = dyn_cast(op)) { + auto resultType = cast(activePrefix.getResult().getType()); + if (failed(requestMaskUse( + activePrefix.getMaskMutable(), resultType.getLayoutAttr(), + getMaskGranularityForElement(resultType.getElementType()), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto compress = dyn_cast(op)) { + auto resultType = cast(compress.getResult().getType()); + if (failed(requestMaskUse( + compress.getMaskMutable(), resultType.getLayoutAttr(), + getMaskGranularityForElement(resultType.getElementType()), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto reduce = dyn_cast(op)) { + auto sourceType = cast(reduce.getSource().getType()); + if (failed(requestMaskUse( + reduce.getMaskMutable(), sourceType.getLayoutAttr(), + getMaskGranularityForElement(sourceType.getElementType()), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto reduce = dyn_cast(op)) { + auto sourceType = cast(reduce.getSource().getType()); + if (failed(requestMaskUse( + reduce.getMaskMutable(), sourceType.getLayoutAttr(), + getMaskGranularityForElement(sourceType.getElementType()), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto reduce = dyn_cast(op)) { + auto sourceType = cast(reduce.getSource().getType()); + if (failed(requestMaskUse( + reduce.getMaskMutable(), sourceType.getLayoutAttr(), + getMaskGranularityForElement(sourceType.getElementType()), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto reduce = dyn_cast(op)) { + auto sourceType = cast(reduce.getSource().getType()); + if (failed(requestMaskUse( + 
reduce.getMaskMutable(), sourceType.getLayoutAttr(), + getMaskGranularityForElement(sourceType.getElementType()), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto reduce = dyn_cast(op)) { + auto sourceType = cast(reduce.getSource().getType()); + if (failed(requestMaskUse( + reduce.getMaskMutable(), sourceType.getLayoutAttr(), + getMaskGranularityForElement(sourceType.getElementType()), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto reduce = dyn_cast(op)) { + auto sourceType = cast(reduce.getSource().getType()); + if (failed(requestMaskUse( + reduce.getMaskMutable(), sourceType.getLayoutAttr(), + getMaskGranularityForElement(sourceType.getElementType()), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto reduce = dyn_cast(op)) { + auto sourceType = cast(reduce.getSource().getType()); + if (failed(requestMaskUse( + reduce.getMaskMutable(), sourceType.getLayoutAttr(), + getMaskGranularityForElement(sourceType.getElementType()), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto reduce = dyn_cast(op)) { + auto sourceType = cast(reduce.getSource().getType()); + if (failed(requestMaskUse( + reduce.getMaskMutable(), sourceType.getLayoutAttr(), + getMaskGranularityForElement(sourceType.getElementType()), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto load = dyn_cast(op)) { + auto resultType = cast(load.getResult().getType()); + if (failed(requestMaskUse( + load.getMaskMutable(), resultType.getLayoutAttr(), + getMaskGranularityForElement(resultType.getElementType()), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto load = dyn_cast(op)) { + auto resultType = cast(load.getResult().getType()); + if (failed(requestMaskUse( + load.getMaskMutable(), resultType.getLayoutAttr(), + getMaskGranularityForElement(resultType.getElementType()), op))) + return 
WalkResult::interrupt(); + return WalkResult::advance(); + } + return WalkResult::advance(); + }); + return failure(result.wasInterrupted()); + } + + void rewriteMaskTypes() { + for (MaskNode &node : maskNodes) { + MaskNode &root = maskNodes[findMask(maskIds.lookup(node.value))]; + VMILayoutAttr layout = + root.requestedLayout ? root.requestedLayout : getContiguousLayout(); + StringRef granularity = root.requestedGranularity.empty() + ? StringRef("b32") + : StringRef(root.requestedGranularity); + node.value.setType(VMIMaskType::get(ctx, node.type.getElementCount(), + granularity, layout)); + } + } + + LogicalResult insertMaskUseMaterializations() { + OpBuilder builder(ctx); + for (MaskUseRequest request : maskUseRequests) { + Value value = request.operand->get(); + auto sourceType = dyn_cast(value.getType()); + if (!sourceType) + continue; + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + if (!sourceLayout) + return request.operand->getOwner()->emitError() + << kVMIDiagLayoutContractPrefix + << "mask use materialization requires layout-assigned source " + "type"; + + builder.setInsertionPoint(request.operand->getOwner()); + Value current = value; + VMIMaskType currentType = sourceType; + if (sourceLayout != request.layout) { + auto layoutType = + VMIMaskType::get(ctx, currentType.getElementCount(), + currentType.getGranularity(), request.layout); + auto ensureLayout = builder.create( + request.operand->getOwner()->getLoc(), layoutType, current); + current = ensureLayout.getResult(); + currentType = layoutType; + } + + if (currentType.getGranularity() != request.granularity) { + auto granularityType = + VMIMaskType::get(ctx, currentType.getElementCount(), + request.granularity, request.layout); + auto ensureGranularity = builder.create( + request.operand->getOwner()->getLoc(), granularityType, current); + current = ensureGranularity.getResult(); + } + + if (current != value) + request.operand->set(current); + } + return success(); + } + + void 
rewriteFunctionType() { + module.walk([&](func::FuncOp func) { + if (func.empty()) + return; + + SmallVector inputs; + inputs.reserve(func.getNumArguments()); + for (BlockArgument arg : func.getArguments()) + inputs.push_back(arg.getType()); + + SmallVector results; + auto it = firstReturnOperandsByFunc.find(func); + if (it != firstReturnOperandsByFunc.end()) { + for (Value operand : it->second) + results.push_back(operand.getType()); + } else { + FunctionType functionType = func.getFunctionType(); + for (Type type : functionType.getResults()) { + if (auto vregType = dyn_cast(type)) { + results.push_back(VMIVRegType::get(ctx, vregType.getElementCount(), + vregType.getElementType(), + getContiguousLayout())); + } else if (auto maskType = dyn_cast(type)) { + results.push_back(VMIMaskType::get(ctx, maskType.getElementCount(), + "b32", getContiguousLayout())); + } else { + results.push_back(type); + } + } + } + + func.setFunctionType(FunctionType::get(ctx, inputs, results)); + }); + } + + LogicalResult run() { + if (failed(collect())) + return failure(); + if (failed(addConstraints())) + return failure(); + if (failed(applyConsumerDrivenDataLayouts())) + return failure(); + rewriteDataTypes(); + if (failed(insertDataUseMaterializations())) + return failure(); + if (failed(inferMaskRequests())) + return failure(); + rewriteMaskTypes(); + if (failed(insertMaskUseMaterializations())) + return failure(); + rewriteFunctionType(); + return validateVMILayoutAssignedIR(module, /*diagOS=*/nullptr, + /*verifyHelperSupport=*/false); + } + + ModuleOp module; + MLIRContext *ctx; + const VMITargetCapabilityRegistry &capabilities; + DenseMap dataIds; + DenseMap maskIds; + DenseMap> firstReturnOperandsByFunc; + SmallVector dataNodes; + SmallVector maskNodes; + SmallVector dataUseRequests; + SmallVector maskUseRequests; +}; + +struct VMILayoutAssignmentPass + : public mlir::pto::impl::VMILayoutAssignmentBase { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VMILayoutAssignmentPass) + + 
void runOnOperation() override { + VMITargetCapabilityRegistry capabilities; + if (failed(LayoutSolver(getOperation(), capabilities).run())) + signalPassFailure(); + } +}; + +} // namespace + +std::unique_ptr mlir::pto::createVMILayoutAssignmentPass() { + return std::make_unique(); +} diff --git a/lib/PTO/Transforms/VMILayoutFold.cpp b/lib/PTO/Transforms/VMILayoutFold.cpp new file mode 100644 index 0000000000..253ab7c3fc --- /dev/null +++ b/lib/PTO/Transforms/VMILayoutFold.cpp @@ -0,0 +1,231 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. 
+ +//===- VMILayoutFold.cpp - Fold VMI layout materializations --------------===// +//===----------------------------------------------------------------------===// + +#include "PTO/IR/PTO.h" +#include "PTO/IR/PTOTypeUtils.h" +#include "PTO/IR/VMIUtils.h" +#include "PTO/Transforms/Passes.h" +#include "PTO/Transforms/VMILayoutSupport.h" + +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/Operation.h" +#include "mlir/Pass/Pass.h" +#include "llvm/ADT/STLExtras.h" + +namespace mlir { +namespace pto { +#define GEN_PASS_DEF_VMILAYOUTFOLD +#include "PTO/Transforms/Passes.h.inc" +} // namespace pto +} // namespace mlir + +using namespace mlir; +using namespace mlir::pto; + +namespace { + +static bool hasSameDataShapeAndElementType(VMIVRegType lhs, VMIVRegType rhs) { + return lhs && rhs && lhs.getElementCount() == rhs.getElementCount() && + lhs.getElementType() == rhs.getElementType(); +} + +static bool isLoadProducerLayout(VMIVRegType type) { + if (!type) + return false; + VMILayoutAttr layout = type.getLayoutAttr(); + if (!layout) + return false; + if (layout.isContiguous() && layout.getLaneStride() == 1) + return true; + if (layout.isContiguous() && layout.getLaneStride() == 2) { + unsigned elementBits = pto::getPTOStorageElemBitWidth(type.getElementType()); + return elementBits == 8 || elementBits == 16 || elementBits == 32; + } + if (layout.isContiguous() && layout.getLaneStride() == 4) { + unsigned elementBits = pto::getPTOStorageElemBitWidth(type.getElementType()); + return elementBits == 8; + } + if (!layout.isDeinterleaved() || layout.getBlockElems() != 1 || + layout.getLaneStride() != 1 || + (layout.getFactor() != 2 && layout.getFactor() != 4)) + return false; + unsigned elementBits = pto::getPTOStorageElemBitWidth(type.getElementType()); + return elementBits == 8 || elementBits == 16 || 
elementBits == 32; +} + +static bool isFoldableLoadEnsure(VMIEnsureLayoutOp ensure) { + auto load = ensure.getSource().getDefiningOp(); + if (!load) + return false; + + auto sourceType = dyn_cast(ensure.getSource().getType()); + auto resultType = dyn_cast(ensure.getResult().getType()); + if (!hasSameDataShapeAndElementType(sourceType, resultType)) + return false; + + return isLoadProducerLayout(resultType); +} + +static void tryFoldLoadEnsures( + VMILoadOp load, SmallVectorImpl &maybeDeadEnsures) { + auto sourceType = dyn_cast(load.getResult().getType()); + if (!sourceType) + return; + + VMIVRegType targetType; + SmallVector ensures; + for (OpOperand &use : load.getResult().getUses()) { + auto ensure = dyn_cast(use.getOwner()); + if (!ensure || use.getOperandNumber() != 0 || !isFoldableLoadEnsure(ensure)) + return; + + auto resultType = cast(ensure.getResult().getType()); + if (!targetType) { + targetType = resultType; + } else if (targetType != resultType) { + return; + } + ensures.push_back(ensure); + } + + if (ensures.empty() || targetType == sourceType) + return; + + load.getResult().setType(targetType); + for (VMIEnsureLayoutOp ensure : ensures) { + ensure.getResult().replaceAllUsesWith(load.getResult()); + maybeDeadEnsures.push_back(ensure); + } +} + +static void +tryFoldNestedEnsureLayout(VMIEnsureLayoutOp ensure, + SmallVectorImpl &maybeDeadEnsures) { + auto inner = ensure.getSource().getDefiningOp(); + if (!inner) + return; + + if (inner.getSource().getType() != ensure.getResult().getType()) + return; + ensure.getResult().replaceAllUsesWith(inner.getSource()); + // Queue `inner` before `ensure`: the pass sweeps this list in reverse, so the user is erased before its producer is tested for uses. + maybeDeadEnsures.push_back(inner); + maybeDeadEnsures.push_back(ensure); +} + +static bool isFoldableStoreEnsure(VMIEnsureLayoutOp ensure) { + auto sourceType = dyn_cast(ensure.getSource().getType()); + auto resultType = dyn_cast(ensure.getResult().getType()); + if (!sourceType || !resultType) + return false; + + VMILayoutSupport supports; + return succeeded( +
supports.canFoldContiguousStoreMaterialization(sourceType, resultType)); +} + +static void tryFoldEnsureLayoutIntoOperand( + OpOperand &operand, SmallVectorImpl &maybeDeadEnsures) { + auto ensure = operand.get().getDefiningOp(); + if (!ensure || !isFoldableStoreEnsure(ensure)) + return; + operand.set(ensure.getSource()); + if (!llvm::is_contained(maybeDeadEnsures, ensure)) // one ensure may feed several stores; queue it once so the sweep never revisits an erased op + maybeDeadEnsures.push_back(ensure); +} + +static void tryFoldEnsureLayoutIntoMaskedStore( + VMIMaskedStoreOp store, + SmallVectorImpl &maybeDeadEnsures, + SmallVectorImpl &maybeDeadMaskEnsures) { + auto ensure = store.getValue().getDefiningOp(); + if (!ensure || !isFoldableStoreEnsure(ensure)) + return; + auto maskEnsure = store.getMask().getDefiningOp(); + if (!maskEnsure) + return; + + auto sourceType = dyn_cast(ensure.getSource().getType()); + auto maskSourceType = dyn_cast(maskEnsure.getSource().getType()); + auto maskResultType = dyn_cast(maskEnsure.getResult().getType()); + if (!sourceType || !maskSourceType || !maskResultType) + return; + + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr maskSourceLayout = maskSourceType.getLayoutAttr(); + VMILayoutAttr maskResultLayout = maskResultType.getLayoutAttr(); + if (!sourceLayout || !maskSourceLayout || !maskResultLayout) + return; + if (sourceLayout != maskSourceLayout || !maskResultLayout.isContiguous()) + return; + auto resultType = dyn_cast(ensure.getResult().getType()); + if (!resultType) + return; + VMILayoutSupport supports; + if (failed(supports.canFoldContiguousMaskedStoreMaterialization( + sourceType, maskSourceType, resultType, maskResultType))) + return; + // Both helpers can be shared by several masked stores; queue each at most once so the final sweep cannot touch an already-erased op. + store.getValueMutable().set(ensure.getSource()); + store.getMaskMutable().set(maskEnsure.getSource()); + if (!llvm::is_contained(maybeDeadEnsures, ensure)) + maybeDeadEnsures.push_back(ensure); + if (!llvm::is_contained(maybeDeadMaskEnsures, maskEnsure)) + maybeDeadMaskEnsures.push_back(maskEnsure); +} + +struct VMILayoutFoldPass + : public mlir::pto::impl::VMILayoutFoldBase { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VMILayoutFoldPass) + + void runOnOperation() override { + ModuleOp module = getOperation();
+ SmallVector maybeDeadEnsures; + SmallVector maybeDeadMaskEnsures; + + module.walk([&](VMILoadOp load) { + tryFoldLoadEnsures(load, maybeDeadEnsures); + }); + + module.walk([&](VMIEnsureLayoutOp ensure) { + tryFoldNestedEnsureLayout(ensure, maybeDeadEnsures); + }); + + module.walk([&](Operation *op) { + if (auto store = dyn_cast(op)) + tryFoldEnsureLayoutIntoOperand(store.getValueMutable(), + maybeDeadEnsures); + if (auto maskedStore = dyn_cast(op)) + tryFoldEnsureLayoutIntoMaskedStore(maskedStore, maybeDeadEnsures, + maybeDeadMaskEnsures); + }); + + for (VMIEnsureMaskLayoutOp ensure : llvm::reverse(maybeDeadMaskEnsures)) { + if (ensure->use_empty()) + ensure.erase(); + } + for (VMIEnsureLayoutOp ensure : llvm::reverse(maybeDeadEnsures)) { + if (ensure->use_empty()) + ensure.erase(); + } + } +}; + +} // namespace + +std::unique_ptr mlir::pto::createVMILayoutFoldPass() { + return std::make_unique(); +} diff --git a/lib/PTO/Transforms/VMILayoutRematerialize.cpp b/lib/PTO/Transforms/VMILayoutRematerialize.cpp new file mode 100644 index 0000000000..be4842ad5c --- /dev/null +++ b/lib/PTO/Transforms/VMILayoutRematerialize.cpp @@ -0,0 +1,400 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. 
+ +//===- VMILayoutRematerialize.cpp - Rematerialize VMI producers ----------===// +//===----------------------------------------------------------------------===// + +#include "PTO/IR/PTO.h" +#include "PTO/IR/PTOTypeUtils.h" +#include "PTO/Transforms/Passes.h" +#include "PTO/Transforms/VMILayoutSupport.h" + +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" +#include "llvm/ADT/STLExtras.h" + +#include + +namespace mlir { +namespace pto { +#define GEN_PASS_DEF_VMILAYOUTREMATERIALIZE +#include "PTO/Transforms/Passes.h.inc" +} // namespace pto +} // namespace mlir + +using namespace mlir; +using namespace mlir::pto; + +namespace { + +static bool hasConcreteLayout(VMIVRegType type) { + return type && static_cast(type.getLayoutAttr()); +} + +static bool hasConcreteLayout(VMIMaskType type) { + return type && static_cast(type.getLayoutAttr()); +} + +static Value materializeDataLayout(Value value, VMIVRegType resultType, + Location loc, OpBuilder &builder) { + auto sourceType = dyn_cast(value.getType()); + if (!sourceType || sourceType == resultType) + return value; + + return builder.create(loc, resultType, value).getResult(); +} + +template +static std::optional rematerializeWidenExt(ExtOp op, + VMIVRegType resultType, + Location loc, + OpBuilder &builder) { + auto sourceType = dyn_cast(op.getSource().getType()); + if (!sourceType || !hasConcreteLayout(resultType)) + return std::nullopt; + + VMILayoutSupport supports; + FailureOr sourceLayout = + supports.getWidenSourceLayoutForResultLayout( + sourceType, resultType, resultType.getLayoutAttr()); + if (failed(sourceLayout)) + return std::nullopt; + + auto rematSourceType = + VMIVRegType::get(sourceType.getContext(), sourceType.getElementCount(), + sourceType.getElementType(), *sourceLayout); + Value rematSource = materializeDataLayout(op.getSource(), rematSourceType, + loc, 
builder); + return builder.create(loc, resultType, rematSource).getResult(); +} + +static std::optional +rematerializeBinaryDataOp(Operation *op, VMIVRegType resultType, Location loc, + OpBuilder &builder) { + auto rebuild = [&](auto typedOp) -> std::optional { + auto lhsType = dyn_cast(typedOp.getLhs().getType()); + auto rhsType = dyn_cast(typedOp.getRhs().getType()); + if (!lhsType || !rhsType) + return std::nullopt; + auto lhsResultType = + VMIVRegType::get(lhsType.getContext(), lhsType.getElementCount(), + lhsType.getElementType(), resultType.getLayoutAttr()); + auto rhsResultType = + VMIVRegType::get(rhsType.getContext(), rhsType.getElementCount(), + rhsType.getElementType(), resultType.getLayoutAttr()); + Value lhs = + materializeDataLayout(typedOp.getLhs(), lhsResultType, loc, builder); + Value rhs = + materializeDataLayout(typedOp.getRhs(), rhsResultType, loc, builder); + return builder + .create>(loc, resultType, lhs, rhs) + .getResult(); + }; + + if (auto addf = dyn_cast(op)) + return rebuild(addf); + if (auto addi = dyn_cast(op)) + return rebuild(addi); + if (auto subf = dyn_cast(op)) + return rebuild(subf); + if (auto subi = dyn_cast(op)) + return rebuild(subi); + if (auto mulf = dyn_cast(op)) + return rebuild(mulf); + if (auto muli = dyn_cast(op)) + return rebuild(muli); + if (auto divf = dyn_cast(op)) + return rebuild(divf); + if (auto minf = dyn_cast(op)) + return rebuild(minf); + if (auto maxf = dyn_cast(op)) + return rebuild(maxf); + if (auto andi = dyn_cast(op)) + return rebuild(andi); + if (auto ori = dyn_cast(op)) + return rebuild(ori); + if (auto xori = dyn_cast(op)) + return rebuild(xori); + if (auto shli = dyn_cast(op)) + return rebuild(shli); + if (auto shrui = dyn_cast(op)) + return rebuild(shrui); + return std::nullopt; +} + +static std::optional +rematerializeUnaryDataOp(Operation *op, VMIVRegType resultType, Location loc, + OpBuilder &builder) { + auto rebuild = [&](auto typedOp) -> std::optional { + auto sourceType = 
dyn_cast(typedOp.getSource().getType()); + if (!sourceType) + return std::nullopt; + auto sourceResultType = VMIVRegType::get( + sourceType.getContext(), sourceType.getElementCount(), + sourceType.getElementType(), resultType.getLayoutAttr()); + Value source = materializeDataLayout(typedOp.getSource(), sourceResultType, + loc, builder); + return builder + .create>(loc, resultType, source) + .getResult(); + }; + + if (auto negf = dyn_cast(op)) + return rebuild(negf); + if (auto absf = dyn_cast(op)) + return rebuild(absf); + if (auto absi = dyn_cast(op)) + return rebuild(absi); + if (auto sqrt = dyn_cast(op)) + return rebuild(sqrt); + if (auto exp = dyn_cast(op)) + return rebuild(exp); + if (auto ln = dyn_cast(op)) + return rebuild(ln); + if (auto relu = dyn_cast(op)) + return rebuild(relu); + if (auto notOp = dyn_cast(op)) + return rebuild(notOp); + return std::nullopt; +} + +static std::optional +rematerializeFma(VMIFmaOp fma, VMIVRegType resultType, Location loc, + OpBuilder &builder) { + auto lhsType = dyn_cast(fma.getLhs().getType()); + auto rhsType = dyn_cast(fma.getRhs().getType()); + auto accType = dyn_cast(fma.getAcc().getType()); + if (!lhsType || !rhsType || !accType) + return std::nullopt; + auto makeType = [&](VMIVRegType type) { + return VMIVRegType::get(type.getContext(), type.getElementCount(), + type.getElementType(), resultType.getLayoutAttr()); + }; + Value lhs = materializeDataLayout(fma.getLhs(), makeType(lhsType), loc, + builder); + Value rhs = materializeDataLayout(fma.getRhs(), makeType(rhsType), loc, + builder); + Value acc = materializeDataLayout(fma.getAcc(), makeType(accType), loc, + builder); + return builder.create(loc, resultType, lhs, rhs, acc).getResult(); +} + +static std::optional rematerializeDataProducer(Value value, + VMIVRegType resultType, + Location loc, + OpBuilder &builder) { + if (!hasConcreteLayout(resultType)) + return std::nullopt; + + if (auto extf = value.getDefiningOp()) + return rematerializeWidenExt(extf, 
resultType, loc, builder); + if (auto extsi = value.getDefiningOp()) + return rematerializeWidenExt(extsi, resultType, loc, builder); + if (auto extui = value.getDefiningOp()) + return rematerializeWidenExt(extui, resultType, loc, builder); + + if (Operation *op = value.getDefiningOp()) { + if (auto fma = dyn_cast(op)) + return rematerializeFma(fma, resultType, loc, builder); + if (auto result = rematerializeBinaryDataOp(op, resultType, loc, builder)) + return result; + if (auto result = rematerializeUnaryDataOp(op, resultType, loc, builder)) + return result; + } + + if (auto constant = value.getDefiningOp()) { + auto denseAttr = dyn_cast(constant.getValue()); + if (denseAttr && denseAttr.isSplat()) + return builder + .create(loc, resultType, constant.getValue()) + .getResult(); + } + + if (auto broadcast = value.getDefiningOp()) + return builder.create(loc, resultType, + broadcast.getValue()) + .getResult(); + + if (auto iota = value.getDefiningOp()) + return builder + .create(loc, resultType, iota.getBase(), + iota.getOrderAttr()) + .getResult(); + + return std::nullopt; +} + +static std::optional rematerializeMaskProducer(Value value, + VMIMaskType resultType, + Location loc, + OpBuilder &builder) { + if (!hasConcreteLayout(resultType)) + return std::nullopt; + + if (auto createMask = value.getDefiningOp()) + return builder + .create(loc, resultType, createMask.getActiveLanes()) + .getResult(); + + if (auto createGroupMask = value.getDefiningOp()) + return builder + .create( + loc, resultType, createGroupMask.getActiveElemsPerGroup(), + createGroupMask.getNumGroupsAttr(), createGroupMask.getGroupSizeAttr()) + .getResult(); + + if (auto constantMask = value.getDefiningOp()) + return builder + .create(loc, resultType, + constantMask.getValueAttr()) + .getResult(); + + return std::nullopt; +} + +static bool tryReplaceDataEnsure(VMIEnsureLayoutOp ensure) { + auto resultType = dyn_cast(ensure.getResult().getType()); + if (!resultType) + return false; + + OpBuilder 
builder(ensure); + auto result = rematerializeDataProducer(ensure.getSource(), resultType, + ensure->getLoc(), builder); + if (!result) + return false; + + ensure.getResult().replaceAllUsesWith(*result); + ensure.erase(); + return true; +} + +static bool tryRematerializeTruncIThroughSourceEnsure(VMITruncIOp trunc) { + auto resultType = dyn_cast(trunc.getResult().getType()); + if (!resultType || !hasConcreteLayout(resultType)) + return false; + + auto ensure = trunc.getSource().getDefiningOp(); + if (!ensure) + return false; + + auto originalSourceType = dyn_cast(ensure.getSource().getType()); + if (!originalSourceType || !hasConcreteLayout(originalSourceType)) + return false; + VMILayoutAttr originalSourceLayout = originalSourceType.getLayoutAttr(); + if (!originalSourceLayout.isDeinterleaved() || + originalSourceLayout.getBlockElems() != 1) + return false; + + VMILayoutSupport supports; + FailureOr fact = + supports.getPreferredCastLayoutFact(originalSourceType, resultType); + if (failed(fact) || (fact->kind != VMICastLayoutKind::Narrow2x && + fact->kind != VMICastLayoutKind::Narrow4x)) + return false; + if (originalSourceLayout.getFactor() % fact->factor != 0) + return false; + + unsigned resultBits = + pto::getPTOStorageElemBitWidth(resultType.getElementType()); + if (resultBits == 8 && + !cast(resultType.getElementType()).isUnsigned()) + return false; + + int64_t rematResultFactor = originalSourceLayout.getFactor() / fact->factor; + VMILayoutAttr rematResultLayout = + rematResultFactor == 1 + ? 
VMILayoutAttr::getContiguous(resultType.getContext()) + : VMILayoutAttr::getDeinterleaved(resultType.getContext(), + rematResultFactor, + /*blockElems=*/1); + auto rematResultType = + VMIVRegType::get(resultType.getContext(), resultType.getElementCount(), + resultType.getElementType(), rematResultLayout); + if (rematResultType == resultType) + return false; + + OpBuilder builder(trunc); + Value remat = + builder.create(trunc->getLoc(), rematResultType, + ensure.getSource()) + .getResult(); + Value replacement = + materializeDataLayout(remat, resultType, trunc->getLoc(), builder); + trunc.getResult().replaceAllUsesWith(replacement); + trunc.erase(); + return true; +} + +template +static bool tryReplaceMaskEnsure(EnsureOp ensure) { + auto resultType = dyn_cast(ensure.getResult().getType()); + if (!resultType) + return false; + + OpBuilder builder(ensure); + auto result = rematerializeMaskProducer(ensure.getSource(), resultType, + ensure->getLoc(), builder); + if (!result) + return false; + + ensure.getResult().replaceAllUsesWith(*result); + ensure.erase(); + return true; +} + +struct VMILayoutRematerializePass + : public mlir::pto::impl::VMILayoutRematerializeBase< + VMILayoutRematerializePass> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VMILayoutRematerializePass) + + void runOnOperation() override { + ModuleOp module = getOperation(); + bool changed = true; + while (changed) { + changed = false; + SmallVector helpers; + module.walk([&](Operation *op) { + if (isa(op)) + helpers.push_back(op); + }); + + for (Operation *op : helpers) { + if (op->getBlock() == nullptr) + continue; + + if (auto ensure = dyn_cast(op)) { + changed |= tryReplaceDataEnsure(ensure); + continue; + } + + if (auto ensure = dyn_cast(op)) { + changed |= tryReplaceMaskEnsure(ensure); + continue; + } + + if (auto ensure = dyn_cast(op)) + changed |= tryReplaceMaskEnsure(ensure); + // A successful replacement above erases `op`; chain with `else` so it is never inspected again. + else if (auto trunc = dyn_cast(op)) + changed |= tryRematerializeTruncIThroughSourceEnsure(trunc); + } + } + } +}; +
+} // namespace + +std::unique_ptr mlir::pto::createVMILayoutRematerializePass() { + return std::make_unique(); +} diff --git a/lib/PTO/Transforms/VMILayoutSinkMaterialization.cpp b/lib/PTO/Transforms/VMILayoutSinkMaterialization.cpp new file mode 100644 index 0000000000..3027d919f7 --- /dev/null +++ b/lib/PTO/Transforms/VMILayoutSinkMaterialization.cpp @@ -0,0 +1,629 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. 
+ +//===- VMILayoutSinkMaterialization.cpp - Sink VMI layout helpers --------===// +//===----------------------------------------------------------------------===// + +#include "PTO/IR/PTO.h" +#include "PTO/Transforms/Passes.h" +#include "PTO/Transforms/VMILayoutSupport.h" + +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Operation.h" +#include "mlir/Pass/Pass.h" +#include "llvm/ADT/STLExtras.h" + +namespace mlir { +namespace pto { +#define GEN_PASS_DEF_VMILAYOUTSINKMATERIALIZATION +#include "PTO/Transforms/Passes.h.inc" +} // namespace pto +} // namespace mlir + +using namespace mlir; +using namespace mlir::pto; + +namespace { + +struct BinaryVRegOperands { + OpOperand *lhs = nullptr; + OpOperand *rhs = nullptr; +}; + +struct TernaryVRegOperands { + OpOperand *lhs = nullptr; + OpOperand *rhs = nullptr; + OpOperand *acc = nullptr; +}; + +struct SelectOperands { + OpOperand *mask = nullptr; + OpOperand *trueValue = nullptr; + OpOperand *falseValue = nullptr; +}; + +struct UnaryVRegOperand { + OpOperand *source = nullptr; +}; + +struct BinaryMaskOperands { + OpOperand *lhs = nullptr; + OpOperand *rhs = nullptr; +}; + +struct UnaryMaskOperand { + OpOperand *source = nullptr; +}; + +static std::optional getSinkableBinaryOperands(Operation *op) { + if (auto addf = dyn_cast(op)) + return BinaryVRegOperands{&addf.getLhsMutable(), &addf.getRhsMutable()}; + if (auto addi = dyn_cast(op)) + return BinaryVRegOperands{&addi.getLhsMutable(), &addi.getRhsMutable()}; + if (auto subf = dyn_cast(op)) + return BinaryVRegOperands{&subf.getLhsMutable(), &subf.getRhsMutable()}; + if (auto subi = dyn_cast(op)) + return BinaryVRegOperands{&subi.getLhsMutable(), &subi.getRhsMutable()}; + if (auto mulf = dyn_cast(op)) + return BinaryVRegOperands{&mulf.getLhsMutable(), &mulf.getRhsMutable()}; + if (auto muli = dyn_cast(op)) + return 
BinaryVRegOperands{&muli.getLhsMutable(), &muli.getRhsMutable()}; + if (auto divf = dyn_cast(op)) + return BinaryVRegOperands{&divf.getLhsMutable(), &divf.getRhsMutable()}; + if (auto minf = dyn_cast(op)) + return BinaryVRegOperands{&minf.getLhsMutable(), &minf.getRhsMutable()}; + if (auto maxf = dyn_cast(op)) + return BinaryVRegOperands{&maxf.getLhsMutable(), &maxf.getRhsMutable()}; + if (auto andi = dyn_cast(op)) + return BinaryVRegOperands{&andi.getLhsMutable(), &andi.getRhsMutable()}; + if (auto ori = dyn_cast(op)) + return BinaryVRegOperands{&ori.getLhsMutable(), &ori.getRhsMutable()}; + if (auto xori = dyn_cast(op)) + return BinaryVRegOperands{&xori.getLhsMutable(), &xori.getRhsMutable()}; + if (auto shli = dyn_cast(op)) + return BinaryVRegOperands{&shli.getLhsMutable(), &shli.getRhsMutable()}; + if (auto shrui = dyn_cast(op)) + return BinaryVRegOperands{&shrui.getLhsMutable(), &shrui.getRhsMutable()}; + return std::nullopt; +} + +static std::optional +getSinkableCompareOperands(Operation *op) { + if (auto cmpf = dyn_cast(op)) + return BinaryVRegOperands{&cmpf.getLhsMutable(), &cmpf.getRhsMutable()}; + if (auto cmpi = dyn_cast(op)) + return BinaryVRegOperands{&cmpi.getLhsMutable(), &cmpi.getRhsMutable()}; + return std::nullopt; +} + +static std::optional getSinkableSelectOperands(Operation *op) { + if (auto select = dyn_cast(op)) + return SelectOperands{&select.getMaskMutable(), + &select.getTrueValueMutable(), + &select.getFalseValueMutable()}; + return std::nullopt; +} + +static std::optional +getSinkableTernaryOperands(Operation *op) { + if (auto fma = dyn_cast(op)) + return TernaryVRegOperands{&fma.getLhsMutable(), &fma.getRhsMutable(), + &fma.getAccMutable()}; + return std::nullopt; +} + +static std::optional getSinkableUnaryOperand(Operation *op) { + if (auto negf = dyn_cast(op)) + return UnaryVRegOperand{&negf.getSourceMutable()}; + if (auto absf = dyn_cast(op)) + return UnaryVRegOperand{&absf.getSourceMutable()}; + if (auto absi = dyn_cast(op)) + 
return UnaryVRegOperand{&absi.getSourceMutable()}; + if (auto sqrt = dyn_cast(op)) + return UnaryVRegOperand{&sqrt.getSourceMutable()}; + if (auto exp = dyn_cast(op)) + return UnaryVRegOperand{&exp.getSourceMutable()}; + if (auto ln = dyn_cast(op)) + return UnaryVRegOperand{&ln.getSourceMutable()}; + if (auto relu = dyn_cast(op)) + return UnaryVRegOperand{&relu.getSourceMutable()}; + if (auto notOp = dyn_cast(op)) + return UnaryVRegOperand{¬Op.getSourceMutable()}; + return std::nullopt; +} + +static std::optional +getSinkableBinaryMaskOperands(Operation *op) { + if (auto maskAnd = dyn_cast(op)) + return BinaryMaskOperands{&maskAnd.getLhsMutable(), + &maskAnd.getRhsMutable()}; + if (auto maskOr = dyn_cast(op)) + return BinaryMaskOperands{&maskOr.getLhsMutable(), + &maskOr.getRhsMutable()}; + if (auto maskXor = dyn_cast(op)) + return BinaryMaskOperands{&maskXor.getLhsMutable(), + &maskXor.getRhsMutable()}; + return std::nullopt; +} + +static std::optional +getSinkableUnaryMaskOperand(Operation *op) { + if (auto maskNot = dyn_cast(op)) + return UnaryMaskOperand{&maskNot.getSourceMutable()}; + return std::nullopt; +} + +static bool isSameMaterialization(VMIEnsureLayoutOp ensure, + VMIVRegType resultType) { + if (!ensure || !resultType) + return false; + + auto sourceType = dyn_cast(ensure.getSource().getType()); + auto ensureResultType = dyn_cast(ensure.getResult().getType()); + if (!sourceType || !ensureResultType) + return false; + + return ensureResultType == resultType && sourceType != resultType; +} + +static bool isSameMaterialization(VMIEnsureLayoutOp lhsEnsure, + VMIEnsureLayoutOp rhsEnsure, + VMIVRegType resultType) { + if (!lhsEnsure || !rhsEnsure || !resultType) + return false; + + auto lhsSourceType = dyn_cast(lhsEnsure.getSource().getType()); + auto rhsSourceType = dyn_cast(rhsEnsure.getSource().getType()); + auto lhsResultType = dyn_cast(lhsEnsure.getResult().getType()); + auto rhsResultType = dyn_cast(rhsEnsure.getResult().getType()); + if 
(!lhsSourceType || !rhsSourceType || !lhsResultType || !rhsResultType) + return false; + + return lhsSourceType == rhsSourceType && lhsResultType == rhsResultType && + lhsResultType == resultType && lhsSourceType != resultType; +} + +static bool isSameMaterialization(VMIEnsureLayoutOp lhsEnsure, + VMIEnsureLayoutOp rhsEnsure, + VMIEnsureLayoutOp accEnsure, + VMIVRegType resultType) { + if (!lhsEnsure || !rhsEnsure || !accEnsure || !resultType) + return false; + + auto lhsSourceType = dyn_cast(lhsEnsure.getSource().getType()); + auto rhsSourceType = dyn_cast(rhsEnsure.getSource().getType()); + auto accSourceType = dyn_cast(accEnsure.getSource().getType()); + auto lhsResultType = dyn_cast(lhsEnsure.getResult().getType()); + auto rhsResultType = dyn_cast(rhsEnsure.getResult().getType()); + auto accResultType = dyn_cast(accEnsure.getResult().getType()); + if (!lhsSourceType || !rhsSourceType || !accSourceType || !lhsResultType || + !rhsResultType || !accResultType) + return false; + + return lhsSourceType == rhsSourceType && lhsSourceType == accSourceType && + lhsResultType == rhsResultType && lhsResultType == accResultType && + lhsResultType == resultType && lhsSourceType != resultType; +} + +static bool canMaterializeDataLayout(VMIVRegType sourceType, + VMIVRegType resultType) { + VMILayoutSupport supports; + return succeeded(supports.canMaterializeDataLayout(sourceType, resultType)); +} + +template +static bool isSameMaskMaterialization(EnsureOp ensure, VMIMaskType resultType) { + if (!ensure || !resultType) + return false; + + auto sourceType = dyn_cast(ensure.getSource().getType()); + auto ensureResultType = dyn_cast(ensure.getResult().getType()); + if (!sourceType || !ensureResultType) + return false; + + return ensureResultType == resultType && sourceType != resultType; +} + +template +static bool isSameMaskMaterialization(EnsureOp lhsEnsure, EnsureOp rhsEnsure, + VMIMaskType resultType) { + if (!lhsEnsure || !rhsEnsure || !resultType) + return false; + + auto 
lhsSourceType = dyn_cast(lhsEnsure.getSource().getType()); + auto rhsSourceType = dyn_cast(rhsEnsure.getSource().getType()); + auto lhsResultType = dyn_cast(lhsEnsure.getResult().getType()); + auto rhsResultType = dyn_cast(rhsEnsure.getResult().getType()); + if (!lhsSourceType || !rhsSourceType || !lhsResultType || !rhsResultType) + return false; + + return lhsSourceType == rhsSourceType && lhsResultType == rhsResultType && + lhsResultType == resultType && lhsSourceType != resultType; +} + +static bool canMaterializeMask(VMIEnsureMaskLayoutOp, VMIMaskType sourceType, + VMIMaskType resultType) { + VMILayoutSupport supports; + return succeeded(supports.canMaterializeMaskLayout(sourceType, resultType)); +} + +static bool canMaterializeMask(VMIEnsureMaskGranularityOp, + VMIMaskType sourceType, + VMIMaskType resultType) { + VMILayoutSupport supports; + return succeeded( + supports.canMaterializeMaskGranularity(sourceType, resultType)); +} + +static bool trySinkBinaryMaterialization(Operation *op) { + std::optional operands = getSinkableBinaryOperands(op); + if (!operands || op->getNumResults() != 1) + return false; + + auto resultType = dyn_cast(op->getResult(0).getType()); + if (!resultType) + return false; + + auto lhsEnsure = operands->lhs->get().getDefiningOp(); + auto rhsEnsure = operands->rhs->get().getDefiningOp(); + if (!isSameMaterialization(lhsEnsure, rhsEnsure, resultType)) + return false; + + auto sourceType = cast(lhsEnsure.getSource().getType()); + if (!canMaterializeDataLayout(sourceType, resultType)) + return false; + + OpBuilder builder(op); + OperationState state(op->getLoc(), op->getName()); + state.addOperands({lhsEnsure.getSource(), rhsEnsure.getSource()}); + state.addTypes(sourceType); + state.addAttributes(op->getAttrs()); + Operation *newOp = builder.create(state); + + builder.setInsertionPointAfter(newOp); + auto resultEnsure = builder.create( + op->getLoc(), resultType, newOp->getResult(0)); + 
op->getResult(0).replaceAllUsesWith(resultEnsure.getResult()); + op->erase(); + + if (lhsEnsure->use_empty()) + lhsEnsure.erase(); + if (rhsEnsure != lhsEnsure && rhsEnsure->use_empty()) + rhsEnsure.erase(); + return true; +} + +static bool trySinkSelectMaterialization(Operation *op) { + std::optional operands = getSinkableSelectOperands(op); + if (!operands || op->getNumResults() != 1) + return false; + + auto resultType = dyn_cast(op->getResult(0).getType()); + if (!resultType) + return false; + + auto maskEnsure = + operands->mask->get().getDefiningOp(); + auto trueEnsure = + operands->trueValue->get().getDefiningOp(); + auto falseEnsure = + operands->falseValue->get().getDefiningOp(); + if (!maskEnsure || !trueEnsure || !falseEnsure) + return false; + + auto trueSourceType = dyn_cast(trueEnsure.getSource().getType()); + auto falseSourceType = + dyn_cast(falseEnsure.getSource().getType()); + auto trueResultType = dyn_cast(trueEnsure.getResult().getType()); + auto falseResultType = + dyn_cast(falseEnsure.getResult().getType()); + auto maskSourceType = dyn_cast(maskEnsure.getSource().getType()); + auto maskResultType = dyn_cast(maskEnsure.getResult().getType()); + if (!trueSourceType || !falseSourceType || !trueResultType || + !falseResultType || !maskSourceType || !maskResultType) + return false; + + if (trueSourceType != falseSourceType || trueResultType != falseResultType || + trueResultType != resultType || trueSourceType == resultType) + return false; + if (maskResultType != operands->mask->get().getType()) + return false; + if (maskResultType.getLayoutAttr() != resultType.getLayoutAttr() || + maskSourceType.getLayoutAttr() != trueSourceType.getLayoutAttr()) + return false; + if (maskSourceType.getElementCount() != trueSourceType.getElementCount() || + maskResultType.getElementCount() != resultType.getElementCount() || + maskSourceType.getGranularity() != maskResultType.getGranularity()) + return false; + if (!canMaterializeDataLayout(trueSourceType, 
resultType) || + !canMaterializeMask(maskEnsure, maskSourceType, maskResultType)) + return false; + + OpBuilder builder(op); + OperationState state(op->getLoc(), op->getName()); + state.addOperands({maskEnsure.getSource(), trueEnsure.getSource(), + falseEnsure.getSource()}); + state.addTypes(trueSourceType); + state.addAttributes(op->getAttrs()); + Operation *newOp = builder.create(state); + + builder.setInsertionPointAfter(newOp); + auto resultEnsure = builder.create( + op->getLoc(), resultType, newOp->getResult(0)); + op->getResult(0).replaceAllUsesWith(resultEnsure.getResult()); + op->erase(); + + if (maskEnsure->use_empty()) + maskEnsure.erase(); + if (trueEnsure->use_empty()) + trueEnsure.erase(); + if (falseEnsure != trueEnsure && falseEnsure->use_empty()) + falseEnsure.erase(); + return true; +} + +static bool trySinkCompareMaterialization(Operation *op) { + std::optional operands = getSinkableCompareOperands(op); + if (!operands || op->getNumResults() != 1) + return false; + + auto resultMaskType = dyn_cast(op->getResult(0).getType()); + if (!resultMaskType) + return false; + + auto lhsEnsure = operands->lhs->get().getDefiningOp(); + auto rhsEnsure = operands->rhs->get().getDefiningOp(); + if (!lhsEnsure || !rhsEnsure) + return false; + + auto lhsSourceType = dyn_cast(lhsEnsure.getSource().getType()); + auto rhsSourceType = dyn_cast(rhsEnsure.getSource().getType()); + auto lhsResultType = dyn_cast(lhsEnsure.getResult().getType()); + auto rhsResultType = dyn_cast(rhsEnsure.getResult().getType()); + if (!lhsSourceType || !rhsSourceType || !lhsResultType || !rhsResultType) + return false; + if (lhsSourceType != rhsSourceType || lhsResultType != rhsResultType || + lhsSourceType == lhsResultType) + return false; + if (lhsResultType.getElementCount() != resultMaskType.getElementCount() || + lhsResultType.getLayoutAttr() != resultMaskType.getLayoutAttr()) + return false; + + auto sourceMaskType = VMIMaskType::get( + op->getContext(), 
resultMaskType.getElementCount(), + resultMaskType.getGranularity(), lhsSourceType.getLayoutAttr()); + VMILayoutSupport supports; + if (failed(supports.canMaterializeMaskLayout(sourceMaskType, resultMaskType))) + return false; + + OpBuilder builder(op); + OperationState state(op->getLoc(), op->getName()); + state.addOperands({lhsEnsure.getSource(), rhsEnsure.getSource()}); + state.addTypes(sourceMaskType); + state.addAttributes(op->getAttrs()); + Operation *newOp = builder.create(state); + + builder.setInsertionPointAfter(newOp); + auto resultEnsure = builder.create( + op->getLoc(), resultMaskType, newOp->getResult(0)); + op->getResult(0).replaceAllUsesWith(resultEnsure.getResult()); + op->erase(); + + if (lhsEnsure->use_empty()) + lhsEnsure.erase(); + if (rhsEnsure != lhsEnsure && rhsEnsure->use_empty()) + rhsEnsure.erase(); + return true; +} + +static bool trySinkTernaryMaterialization(Operation *op) { + std::optional operands = getSinkableTernaryOperands(op); + if (!operands || op->getNumResults() != 1) + return false; + + auto resultType = dyn_cast(op->getResult(0).getType()); + if (!resultType) + return false; + + auto lhsEnsure = operands->lhs->get().getDefiningOp(); + auto rhsEnsure = operands->rhs->get().getDefiningOp(); + auto accEnsure = operands->acc->get().getDefiningOp(); + if (!isSameMaterialization(lhsEnsure, rhsEnsure, accEnsure, resultType)) + return false; + + auto sourceType = cast(lhsEnsure.getSource().getType()); + if (!canMaterializeDataLayout(sourceType, resultType)) + return false; + + OpBuilder builder(op); + OperationState state(op->getLoc(), op->getName()); + state.addOperands( + {lhsEnsure.getSource(), rhsEnsure.getSource(), accEnsure.getSource()}); + state.addTypes(sourceType); + state.addAttributes(op->getAttrs()); + Operation *newOp = builder.create(state); + + builder.setInsertionPointAfter(newOp); + auto resultEnsure = builder.create( + op->getLoc(), resultType, newOp->getResult(0)); + 
op->getResult(0).replaceAllUsesWith(resultEnsure.getResult()); + op->erase(); + + if (lhsEnsure->use_empty()) + lhsEnsure.erase(); + if (rhsEnsure != lhsEnsure && rhsEnsure->use_empty()) + rhsEnsure.erase(); + if (accEnsure != lhsEnsure && accEnsure != rhsEnsure && + accEnsure->use_empty()) + accEnsure.erase(); + return true; +} + +template +static bool trySinkBinaryMaskMaterialization(Operation *op) { + std::optional operands = getSinkableBinaryMaskOperands(op); + if (!operands || op->getNumResults() != 1) + return false; + + auto resultType = dyn_cast(op->getResult(0).getType()); + if (!resultType) + return false; + + auto lhsEnsure = operands->lhs->get().getDefiningOp(); + auto rhsEnsure = operands->rhs->get().getDefiningOp(); + if (!isSameMaskMaterialization(lhsEnsure, rhsEnsure, resultType)) + return false; + + auto sourceType = cast(lhsEnsure.getSource().getType()); + if (!canMaterializeMask(lhsEnsure, sourceType, resultType)) + return false; + + OpBuilder builder(op); + OperationState state(op->getLoc(), op->getName()); + state.addOperands({lhsEnsure.getSource(), rhsEnsure.getSource()}); + state.addTypes(sourceType); + state.addAttributes(op->getAttrs()); + Operation *newOp = builder.create(state); + + builder.setInsertionPointAfter(newOp); + auto resultEnsure = + builder.create(op->getLoc(), resultType, newOp->getResult(0)); + op->getResult(0).replaceAllUsesWith(resultEnsure.getResult()); + op->erase(); + + if (lhsEnsure->use_empty()) + lhsEnsure.erase(); + if (rhsEnsure != lhsEnsure && rhsEnsure->use_empty()) + rhsEnsure.erase(); + return true; +} + +static bool trySinkUnaryMaterialization(Operation *op) { + std::optional operand = getSinkableUnaryOperand(op); + if (!operand || op->getNumResults() != 1) + return false; + + auto resultType = dyn_cast(op->getResult(0).getType()); + if (!resultType) + return false; + + auto sourceEnsure = + operand->source->get().getDefiningOp(); + if (!isSameMaterialization(sourceEnsure, resultType)) + return false; + + auto 
sourceType = cast(sourceEnsure.getSource().getType()); + if (!canMaterializeDataLayout(sourceType, resultType)) + return false; + + OpBuilder builder(op); + OperationState state(op->getLoc(), op->getName()); + state.addOperands(sourceEnsure.getSource()); + state.addTypes(sourceType); + state.addAttributes(op->getAttrs()); + Operation *newOp = builder.create(state); + + builder.setInsertionPointAfter(newOp); + auto resultEnsure = builder.create( + op->getLoc(), resultType, newOp->getResult(0)); + op->getResult(0).replaceAllUsesWith(resultEnsure.getResult()); + op->erase(); + + if (sourceEnsure->use_empty()) + sourceEnsure.erase(); + return true; +} + +template +static bool trySinkUnaryMaskMaterialization(Operation *op) { + std::optional operand = getSinkableUnaryMaskOperand(op); + if (!operand || op->getNumResults() != 1) + return false; + + auto resultType = dyn_cast(op->getResult(0).getType()); + if (!resultType) + return false; + + auto sourceEnsure = + operand->source->get().getDefiningOp(); + if (!isSameMaskMaterialization(sourceEnsure, resultType)) + return false; + + auto sourceType = cast(sourceEnsure.getSource().getType()); + if (!canMaterializeMask(sourceEnsure, sourceType, resultType)) + return false; + + OpBuilder builder(op); + OperationState state(op->getLoc(), op->getName()); + state.addOperands(sourceEnsure.getSource()); + state.addTypes(sourceType); + state.addAttributes(op->getAttrs()); + Operation *newOp = builder.create(state); + + builder.setInsertionPointAfter(newOp); + auto resultEnsure = + builder.create(op->getLoc(), resultType, newOp->getResult(0)); + op->getResult(0).replaceAllUsesWith(resultEnsure.getResult()); + op->erase(); + + if (sourceEnsure->use_empty()) + sourceEnsure.erase(); + return true; +} + +static bool trySinkMaskMaterialization(Operation *op) { + return trySinkBinaryMaskMaterialization(op) || + trySinkBinaryMaskMaterialization(op) || + trySinkUnaryMaskMaterialization(op) || + trySinkUnaryMaskMaterialization(op); +} + 
+struct VMILayoutSinkMaterializationPass + : public mlir::pto::impl::VMILayoutSinkMaterializationBase< + VMILayoutSinkMaterializationPass> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + VMILayoutSinkMaterializationPass) + + void runOnOperation() override { + ModuleOp module = getOperation(); + SmallVector candidates; + module.walk([&](Operation *op) { + if (getSinkableBinaryOperands(op) || getSinkableCompareOperands(op) || + getSinkableSelectOperands(op) || getSinkableTernaryOperands(op) || + getSinkableUnaryOperand(op) || getSinkableBinaryMaskOperands(op) || + getSinkableUnaryMaskOperand(op)) + candidates.push_back(op); + }); + + for (Operation *op : candidates) { + if (op->getBlock() == nullptr) + continue; + if (!trySinkBinaryMaterialization(op)) { + if (!trySinkCompareMaterialization(op)) { + if (!trySinkSelectMaterialization(op)) { + if (!trySinkTernaryMaterialization(op)) { + if (!trySinkUnaryMaterialization(op)) + trySinkMaskMaterialization(op); + } + } + } + } + } + } +}; + +} // namespace + +std::unique_ptr mlir::pto::createVMILayoutSinkMaterializationPass() { + return std::make_unique(); +} diff --git a/lib/PTO/Transforms/VMILayoutSupport.cpp b/lib/PTO/Transforms/VMILayoutSupport.cpp new file mode 100644 index 0000000000..0a3649fd32 --- /dev/null +++ b/lib/PTO/Transforms/VMILayoutSupport.cpp @@ -0,0 +1,2057 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. 
See LICENSE in the root of the software repository +// for the full text of the License. + +//===- VMILayoutSupport.cpp - VMI layout support queries --------------===// +//===----------------------------------------------------------------------===// + +#include "PTO/Transforms/VMILayoutSupport.h" + +#include "PTO/IR/PTOTypeUtils.h" +#include "PTO/IR/VMIUtils.h" +#include "PTO/Transforms/VMITargetCapabilities.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "llvm/ADT/Twine.h" + +#include + +using namespace mlir; +using namespace mlir::pto; + +namespace { + +static LogicalResult failWithReason(const Twine &message, std::string *reason) { + if (reason) + *reason = message.str(); + return failure(); +} + +static LogicalResult checkFullDataPhysicalChunks(VMIVRegType type, + std::string *reason) { + FailureOr lanesPerPart = getDataLanesPerPart(type.getElementType()); + if (failed(lanesPerPart)) + return failWithReason("requires known physical lanes per part", reason); + + FailureOr arity = getVMIPhysicalArity(type); + if (failed(arity)) + return failWithReason("requires computable physical arity", reason); + + VMILayoutAttr layout = type.getLayoutAttr(); + if (!layout) + return failWithReason("requires assigned layout", reason); + int64_t factor = layout.isDeinterleaved() ? 
layout.getFactor() : 1; + if (factor <= 0 || *arity % factor != 0) + return failWithReason("requires arity divisible by layout factor", reason); + + int64_t chunksPerPart = *arity / factor; + for (int64_t part = 0; part < factor; ++part) { + for (int64_t chunk = 0; chunk < chunksPerPart; ++chunk) { + for (int64_t lane = 0; lane < *lanesPerPart; ++lane) { + FailureOr padding = isPaddingLane(type, part, chunk, lane); + if (failed(padding)) + return failWithReason("failed to map physical padding lane", reason); + if (*padding) + return failWithReason("found padding lane in physical chunk", reason); + } + } + } + + return success(); +} + +static bool hasX2MemoryDistToken(Type elementType) { + unsigned elementBits = pto::getPTOStorageElemBitWidth(elementType); + return elementBits == 8 || elementBits == 16 || elementBits == 32; +} + +static bool hasDenseLaneStride2UnpackedLoad(Type elementType) { + unsigned elementBits = pto::getPTOStorageElemBitWidth(elementType); + return elementBits == 8 || elementBits == 16 || elementBits == 32; +} + +static bool hasDenseLaneStride4UnpackedLoad(Type elementType) { + unsigned elementBits = pto::getPTOStorageElemBitWidth(elementType); + return elementBits == 8; +} + +static bool hasDenseLaneStride2PackedStore(Type elementType) { + unsigned elementBits = pto::getPTOStorageElemBitWidth(elementType); + return elementBits == 8 || elementBits == 16 || elementBits == 32; +} + +static bool hasDenseLaneStride4PackedStore(Type elementType) { + unsigned elementBits = pto::getPTOStorageElemBitWidth(elementType); + return elementBits == 8; +} + +static bool hasDenseLaneStridePackUnpackElement(Type elementType, + int64_t laneStride) { + unsigned elementBits = pto::getPTOStorageElemBitWidth(elementType); + if (elementBits == 0 || (!isa(elementType) && + !isa(elementType))) + return false; + if (laneStride == 2) + return elementBits == 8 || elementBits == 16; + if (laneStride == 4) + return elementBits == 8; + return false; +} + +static 
std::optional +getDenseLaneStrideMaskedStoreMaskGranularity(VMIVRegType valueType) { + VMILayoutAttr layout = valueType.getLayoutAttr(); + if (!layout || !layout.isContiguous()) + return std::nullopt; + unsigned elementBits = pto::getPTOStorageElemBitWidth(valueType.getElementType()); + if (layout.getLaneStride() == 2 && elementBits == 8) + return StringRef("b16"); + if (layout.getLaneStride() == 2 && elementBits == 16) + return StringRef("b32"); + if (layout.getLaneStride() == 4 && elementBits == 8) + return StringRef("b32"); + return std::nullopt; +} + +static StringRef getMaskGranularityForElementBits(unsigned elementBits) { + switch (elementBits) { + case 8: + return "b8"; + case 16: + return "b16"; + case 32: + return "b32"; + default: + return ""; + } +} + +static std::optional getConstantIndexValue(Value value) { + if (auto constant = value.getDefiningOp()) + return constant.value(); + if (auto constant = value.getDefiningOp()) { + if (constant.getType().isIndex()) + return constant.value(); + } + return std::nullopt; +} + +static int64_t ceilDivNonNegative(int64_t lhs, int64_t rhs) { + assert(lhs >= 0 && rhs > 0); + return (lhs + rhs - 1) / rhs; +} + +static FailureOr getVMITypeElementCount(Type type) { + if (auto vregType = dyn_cast(type)) + return vregType.getElementCount(); + if (auto maskType = dyn_cast(type)) + return maskType.getElementCount(); + return failure(); +} + +static FailureOr getVMITypeLayoutFactor(Type type) { + VMILayoutAttr layout; + if (auto vregType = dyn_cast(type)) + layout = vregType.getLayoutAttr(); + else if (auto maskType = dyn_cast(type)) + layout = maskType.getLayoutAttr(); + else + return failure(); + if (!layout) + return failure(); + return layout.isDeinterleaved() ? 
layout.getFactor() : 1; +} + +static FailureOr getVMITypeLanesPerPart(Type type) { + if (auto vregType = dyn_cast(type)) + return getDataLanesPerPart(vregType.getElementType()); + if (auto maskType = dyn_cast(type)) + return getMaskLanesPerPart(maskType.getGranularity()); + return failure(); +} + +static FailureOr getVMITypeChunksInPart(Type type, int64_t part) { + FailureOr elementCount = getVMITypeElementCount(type); + FailureOr factor = getVMITypeLayoutFactor(type); + FailureOr lanesPerPart = getVMITypeLanesPerPart(type); + if (failed(elementCount) || failed(factor) || failed(lanesPerPart) || + part < 0 || part >= *factor) + return failure(); + + VMILayoutAttr layout; + if (auto vregType = dyn_cast(type)) + layout = vregType.getLayoutAttr(); + else if (auto maskType = dyn_cast(type)) + layout = maskType.getLayoutAttr(); + if (!layout) + return failure(); + + int64_t logicalLanesInPart = (*elementCount + *factor - 1 - part) / *factor; + int64_t laneStride = layout.isDense() ? layout.getLaneStride() : 1; + int64_t physicalLanes = + logicalLanesInPart == 0 ? 
0 : (logicalLanesInPart - 1) * laneStride + 1; + return ceilDivNonNegative(physicalLanes, *lanesPerPart); +} + +static LogicalResult checkFullVMIPhysicalChunks(Type type, + std::string *reason) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + FailureOr factor = getVMITypeLayoutFactor(type); + FailureOr lanesPerPart = getVMITypeLanesPerPart(type); + if (failed(factor) || failed(lanesPerPart)) + return fail("requires assigned layout with known physical lanes per part"); + + for (int64_t part = 0; part < *factor; ++part) { + FailureOr chunks = getVMITypeChunksInPart(type, part); + if (failed(chunks)) + return fail("requires known physical chunks"); + for (int64_t chunk = 0; chunk < *chunks; ++chunk) { + for (int64_t lane = 0; lane < *lanesPerPart; ++lane) { + FailureOr padding = isPaddingLane(type, part, chunk, lane); + if (failed(padding)) + return fail("failed to map physical padding lane"); + if (*padding) + return fail("found padding lane in physical chunk"); + } + } + } + + return success(); +} + +static FailureOr +getContiguousMaterializationPartCount(Type type, std::string *reason) { + auto fail = [&](const Twine &message) -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + FailureOr arity = getVMIPhysicalArity(type); + FailureOr factor = getVMITypeLayoutFactor(type); + if (failed(arity) || failed(factor)) + return fail("requires computable physical arity and assigned layout"); + + VMILayoutAttr layout; + if (auto vregType = dyn_cast(type)) + layout = vregType.getLayoutAttr(); + else if (auto maskType = dyn_cast(type)) + layout = maskType.getLayoutAttr(); + else + return fail("requires VMI data or mask type"); + + if (!layout) + return fail("requires assigned layout"); + if (layout.isContiguous() && layout.getLaneStride() == 1) + return *arity; + if (!layout.isDeinterleaved() || + (layout.getFactor() != 2 && layout.getFactor() != 4)) + return 
fail("requires contiguous or deinterleaved=2/4 layout"); + + FailureOr chunksPerGroup = getVMITypeChunksInPart(type, 0); + if (failed(chunksPerGroup)) + return fail("requires known physical chunks per part"); + if (*chunksPerGroup == 0) + return fail("requires at least one physical chunk per part"); + + for (int64_t part = 1; part < *factor; ++part) { + FailureOr chunks = getVMITypeChunksInPart(type, part); + if (failed(chunks)) + return fail("requires known physical chunks per part"); + if (*chunks != *chunksPerGroup) + return fail("requires every deinterleaved part to have the same " + "physical chunk count"); + } + + return *arity; +} + +static LogicalResult checkLayoutMaterializationShape(Type sourceType, + Type resultType, + VMILayoutAttr sourceLayout, + VMILayoutAttr resultLayout, + std::string *reason) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + FailureOr sourceArity = getVMIPhysicalArity(sourceType); + FailureOr resultArity = getVMIPhysicalArity(resultType); + if (failed(sourceArity) || failed(resultArity)) + return fail("requires computable source/result physical arity"); + if (*sourceArity != *resultArity) + return fail("requires source and result to have the same physical arity"); + + if (sourceLayout == resultLayout) + return success(); + + std::string sourceReason; + std::string resultReason; + LogicalResult sourceFull = + checkFullVMIPhysicalChunks(sourceType, &sourceReason); + LogicalResult resultFull = + checkFullVMIPhysicalChunks(resultType, &resultReason); + if (succeeded(sourceFull) && succeeded(resultFull)) + return success(); + + std::string sourceMaterializationReason; + FailureOr sourceMaterializedParts = + getContiguousMaterializationPartCount(sourceType, + &sourceMaterializationReason); + std::string resultMaterializationReason; + FailureOr resultMaterializedParts = + getContiguousMaterializationPartCount(resultType, + &resultMaterializationReason); + if 
(succeeded(sourceMaterializedParts) && + succeeded(resultMaterializedParts) && + *sourceMaterializedParts == *sourceArity && + *resultMaterializedParts == *resultArity) + return success(); + + if (failed(sourceFull)) + return fail(Twine("source ") + sourceReason + "; source materialization " + + sourceMaterializationReason); + return fail(Twine("result ") + resultReason + "; result materialization " + + resultMaterializationReason); +} + +static FailureOr getGroupSizeFromNumGroups(VMIVRegType type, + int64_t numGroups, + std::string *reason) { + auto fail = [&](const Twine &message) -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + if (numGroups <= 0) + return fail("requires num_groups to be positive"); + if (type.getElementCount() % numGroups != 0) + return fail("requires num_groups to evenly divide logical lane count"); + return type.getElementCount() / numGroups; +} + +static FailureOr getDataLayoutFactor(VMIVRegType type) { + VMILayoutAttr layout = type.getLayoutAttr(); + if (!layout) + return failure(); + return layout.isDeinterleaved() ? 
layout.getFactor() : 1; +} + +static FailureOr> +getPhysicalLogicalBitFootprint(VMIVRegType type) { + unsigned elementBits = pto::getPTOStorageElemBitWidth(type.getElementType()); + if (elementBits == 0) + return failure(); + + FailureOr factor = getDataLayoutFactor(type); + FailureOr lanesPerPart = getDataLanesPerPart(type.getElementType()); + FailureOr arity = getVMIPhysicalArity(type); + if (failed(factor) || failed(lanesPerPart) || failed(arity) || *factor <= 0) + return failure(); + + SmallVector bits; + bits.reserve(*arity); + for (int64_t part = 0; part < *factor; ++part) { + for (int64_t chunk = 0; chunk < *arity; ++chunk) { + int64_t activeLanes = 0; + for (int64_t lane = 0; lane < *lanesPerPart; ++lane) { + FailureOr padding = isPaddingLane(type, part, chunk, lane); + if (failed(padding)) + return failure(); + if (!*padding) + ++activeLanes; + } + if (activeLanes > 0) + bits.push_back(activeLanes * static_cast(elementBits)); + } + } + if (static_cast(bits.size()) != *arity) + return failure(); + return bits; +} + +static FailureOr +getLayoutMaterializationSupport(VMILayoutAttr sourceLayout, + VMILayoutAttr resultLayout, + std::string *reason) { + auto fail = + [&](const Twine &message) -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + if (!sourceLayout || !resultLayout) + return fail("requires assigned source/result layouts"); + if (sourceLayout == resultLayout) + return VMILayoutMaterializationSupport{ + VMILayoutMaterializationSupportKind::Identity}; + if (sourceLayout.isContiguous() && resultLayout.isDeinterleaved() && + sourceLayout.getLaneStride() == 1 && resultLayout.getLaneStride() == 1 && + (resultLayout.getFactor() == 2 || resultLayout.getFactor() == 4)) + return VMILayoutMaterializationSupport{ + VMILayoutMaterializationSupportKind::ContiguousToDeinterleaved}; + if (sourceLayout.isDeinterleaved() && resultLayout.isContiguous() && + sourceLayout.getLaneStride() == 1 && resultLayout.getLaneStride() == 1 && + 
(sourceLayout.getFactor() == 2 || sourceLayout.getFactor() == 4)) + return VMILayoutMaterializationSupport{ + VMILayoutMaterializationSupportKind::DeinterleavedToContiguous}; + if (sourceLayout.isDeinterleaved() && resultLayout.isDeinterleaved() && + sourceLayout.getLaneStride() == 1 && resultLayout.getLaneStride() == 1 && + (sourceLayout.getFactor() == 2 || sourceLayout.getFactor() == 4) && + (resultLayout.getFactor() == 2 || resultLayout.getFactor() == 4)) + return VMILayoutMaterializationSupport{ + VMILayoutMaterializationSupportKind:: + DeinterleavedToDeinterleavedViaContiguous}; + return fail("unsupported source/result layout pair"); +} + +} // namespace + +FailureOr +VMILayoutSupport::getPreferredGroupReduceLayoutFact(VMIVRegType sourceType, + int64_t numGroups, + std::string *reason) const { + auto fail = [&](const Twine &message) -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + FailureOr groupSize = + getGroupSizeFromNumGroups(sourceType, numGroups, reason); + if (failed(groupSize)) + return failure(); + FailureOr lanesPerPart = + getDataLanesPerPart(sourceType.getElementType()); + if (failed(lanesPerPart) || *lanesPerPart % 8 != 0) + return fail("requires element type with known physical VLane width"); + + MLIRContext *ctx = sourceType.getContext(); + int64_t vlaneElems = *lanesPerPart / 8; + VMIGroupReduceLayoutFact fact; + fact.groupSize = *groupSize; + fact.lanesPerPart = *lanesPerPart; + fact.vlaneElems = vlaneElems; + + if (*groupSize == vlaneElems) { + fact.kind = VMIGroupReduceLayoutKind::OneVLane; + fact.sourceLayout = VMILayoutAttr::getContiguous(ctx); + fact.maskLayout = fact.sourceLayout; + fact.resultLayout = + VMILayoutAttr::getGroupSlots(ctx, numGroups, /*slots=*/8); + return fact; + } + + if (*groupSize == 2 * vlaneElems) { + fact.kind = VMIGroupReduceLayoutKind::TwoVLane; + fact.sourceLayout = + VMILayoutAttr::getDeinterleaved(ctx, 2, /*blockElems=*/8); + fact.maskLayout = fact.sourceLayout; + 
fact.resultLayout = + VMILayoutAttr::getGroupSlots(ctx, numGroups, /*slots=*/8); + return fact; + } + + if (*groupSize == 4 * vlaneElems) { + fact.kind = VMIGroupReduceLayoutKind::FourVLane; + fact.sourceLayout = + VMILayoutAttr::getDeinterleaved(ctx, 4, /*blockElems=*/8); + fact.maskLayout = fact.sourceLayout; + fact.resultLayout = + VMILayoutAttr::getGroupSlots(ctx, numGroups, /*slots=*/8); + return fact; + } + + if (*groupSize >= *lanesPerPart && *groupSize % *lanesPerPart == 0) { + fact.kind = VMIGroupReduceLayoutKind::RowLocal; + fact.sourceLayout = VMILayoutAttr::getContiguous(ctx); + fact.maskLayout = fact.sourceLayout; + fact.resultLayout = + VMILayoutAttr::getGroupSlots(ctx, numGroups, /*slots=*/1); + return fact; + } + + return fail("group_reduce layout supports group sizes VLaneElems, " + "2*VLaneElems, 4*VLaneElems, or full physical chunk multiples"); +} + +static FailureOr +getBaselineCastLayoutFact(VMIVRegType sourceType, VMIVRegType resultType, + std::string *reason) { + auto fail = [&](const Twine &message) -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + unsigned sourceBits = + pto::getPTOStorageElemBitWidth(sourceType.getElementType()); + unsigned resultBits = + pto::getPTOStorageElemBitWidth(resultType.getElementType()); + if (sourceBits == 0 || resultBits == 0) + return fail( + "requires source/result element types with known storage width"); + if (sourceType.getElementCount() != resultType.getElementCount()) + return fail("requires source/result lane count to match"); + + MLIRContext *ctx = sourceType.getContext(); + VMICastLayoutFact fact; + fact.sourceBits = sourceBits; + fact.resultBits = resultBits; + + if ((sourceBits == 8 || sourceBits == 16) && + resultBits == sourceBits * 2) { + fact.kind = VMICastLayoutKind::Widen2x; + fact.factor = 2; + fact.sourceLayout = VMILayoutAttr::getContiguous(ctx); + fact.resultLayout = + VMILayoutAttr::getDeinterleaved(ctx, fact.factor, /*blockElems=*/1); + return fact; + 
} + + if (resultBits == 32 && sourceBits == 8) { + fact.kind = VMICastLayoutKind::Widen4x; + fact.factor = 4; + fact.sourceLayout = VMILayoutAttr::getContiguous(ctx); + fact.resultLayout = + VMILayoutAttr::getDeinterleaved(ctx, fact.factor, /*blockElems=*/1); + return fact; + } + + if ((resultBits == 8 || resultBits == 16) && + sourceBits == resultBits * 2) { + fact.kind = VMICastLayoutKind::Narrow2x; + fact.factor = 2; + fact.sourceLayout = + VMILayoutAttr::getDeinterleaved(ctx, fact.factor, /*blockElems=*/1); + fact.resultLayout = VMILayoutAttr::getContiguous(ctx); + return fact; + } + + if (sourceBits == 32 && resultBits == 8) { + fact.kind = VMICastLayoutKind::Narrow4x; + fact.factor = 4; + fact.sourceLayout = + VMILayoutAttr::getDeinterleaved(ctx, fact.factor, /*blockElems=*/1); + fact.resultLayout = VMILayoutAttr::getContiguous(ctx); + return fact; + } + + return fail("supports only 8/16-bit integer widening and 32-bit integer " + "narrowing dense cast layout facts"); +} + +FailureOr VMILayoutSupport::getPreferredCastLayoutFact( + VMIVRegType sourceType, VMIVRegType resultType, std::string *reason) const { + FailureOr baseline = + getBaselineCastLayoutFact(sourceType, resultType, reason); + if (failed(baseline)) + return baseline; + + bool isWiden = baseline->kind == VMICastLayoutKind::Widen2x || + baseline->kind == VMICastLayoutKind::Widen4x; + bool isNarrow = baseline->kind == VMICastLayoutKind::Narrow2x || + baseline->kind == VMICastLayoutKind::Narrow4x; + if (!isWiden && !isNarrow) + return baseline; + + MLIRContext *ctx = sourceType.getContext(); + VMILayoutAttr compactSourceLayout = isWiden + ? VMILayoutAttr::getContiguous( + ctx, baseline->factor) + : VMILayoutAttr::getContiguous(ctx); + VMILayoutAttr compactResultLayout = isWiden + ? 
VMILayoutAttr::getContiguous(ctx) + : VMILayoutAttr::getContiguous( + ctx, baseline->factor); + VMIVRegType compactSourceType = + VMIVRegType::get(ctx, sourceType.getElementCount(), + sourceType.getElementType(), compactSourceLayout); + VMIVRegType compactResultType = + VMIVRegType::get(ctx, resultType.getElementCount(), + resultType.getElementType(), compactResultLayout); + FailureOr compactSourceArity = + getVMIPhysicalArity(compactSourceType); + FailureOr compactResultArity = + getVMIPhysicalArity(compactResultType); + if (failed(compactSourceArity) || failed(compactResultArity) || + *compactSourceArity != *compactResultArity) + return baseline; + + if (isWiden) { + VMIVRegType baselineSourceType = + VMIVRegType::get(ctx, sourceType.getElementCount(), + sourceType.getElementType(), baseline->sourceLayout); + VMIVRegType baselineResultType = + VMIVRegType::get(ctx, resultType.getElementCount(), + resultType.getElementType(), baseline->resultLayout); + FailureOr baselineSourceArity = + getVMIPhysicalArity(baselineSourceType); + FailureOr baselineResultArity = + getVMIPhysicalArity(baselineResultType); + if (failed(baselineSourceArity) || failed(baselineResultArity) || + *compactSourceArity > *baselineSourceArity || + *compactResultArity >= *baselineResultArity) + return baseline; + } else if (!sourceType.getElementType().isF32() || + !isa(resultType.getElementType())) { + VMIVRegType baselineSourceType = + VMIVRegType::get(ctx, sourceType.getElementCount(), + sourceType.getElementType(), baseline->sourceLayout); + VMIVRegType baselineResultType = + VMIVRegType::get(ctx, resultType.getElementCount(), + resultType.getElementType(), baseline->resultLayout); + FailureOr baselineSourceArity = + getVMIPhysicalArity(baselineSourceType); + FailureOr baselineResultArity = + getVMIPhysicalArity(baselineResultType); + if (failed(baselineSourceArity) || failed(baselineResultArity) || + *compactResultArity > *baselineResultArity || + *compactSourceArity >= 
*baselineSourceArity) + return baseline; + } else { + VMICastLayoutFact best = *baseline; + VMIVRegType baselineSourceType = + VMIVRegType::get(ctx, sourceType.getElementCount(), + sourceType.getElementType(), baseline->sourceLayout); + VMIVRegType baselineResultType = + VMIVRegType::get(ctx, resultType.getElementCount(), + resultType.getElementType(), baseline->resultLayout); + FailureOr bestSourceArity = getVMIPhysicalArity(baselineSourceType); + FailureOr bestResultArity = getVMIPhysicalArity(baselineResultType); + if (failed(bestSourceArity) || failed(bestResultArity)) + return baseline; + int64_t bestCost = *bestSourceArity + *bestResultArity; + + for (int64_t sourceFactor = 1; sourceFactor <= baseline->factor; + sourceFactor *= 2) { + if (baseline->factor % sourceFactor != 0) + continue; + int64_t resultLaneStride = baseline->factor / sourceFactor; + if (resultLaneStride != 1 && + !hasDenseLaneStridePackUnpackElement(resultType.getElementType(), + resultLaneStride)) + continue; + + VMILayoutAttr sourceLayout = + sourceFactor == 1 + ? 
VMILayoutAttr::getContiguous(ctx) + : VMILayoutAttr::getDeinterleaved(ctx, sourceFactor, + /*blockElems=*/1); + VMILayoutAttr resultLayout = + VMILayoutAttr::getContiguous(ctx, resultLaneStride); + VMIVRegType candidateSourceType = + VMIVRegType::get(ctx, sourceType.getElementCount(), + sourceType.getElementType(), sourceLayout); + VMIVRegType candidateResultType = + VMIVRegType::get(ctx, resultType.getElementCount(), + resultType.getElementType(), resultLayout); + FailureOr candidateSourceArity = + getVMIPhysicalArity(candidateSourceType); + FailureOr candidateResultArity = + getVMIPhysicalArity(candidateResultType); + if (failed(candidateSourceArity) || failed(candidateResultArity) || + *candidateSourceArity != sourceFactor * *candidateResultArity) + continue; + + int64_t candidateCost = *candidateSourceArity + *candidateResultArity; + if (candidateCost >= bestCost) + continue; + + best = *baseline; + best.sourceLayout = sourceLayout; + best.resultLayout = resultLayout; + bestCost = candidateCost; + } + return best; + } + + VMICastLayoutFact compact = *baseline; + compact.sourceLayout = compactSourceLayout; + compact.resultLayout = compactResultLayout; + return compact; +} + +FailureOr +VMILayoutSupport::getWidenSourceLayoutForResultLayout( + VMIVRegType sourceType, VMIVRegType resultType, + VMILayoutAttr requestedResultLayout, std::string *reason) const { + auto fail = [&](const Twine &message) -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + if (!requestedResultLayout) + return fail("requires requested result layout"); + if (sourceType.getElementCount() != resultType.getElementCount()) + return fail("requires source/result lane count to match"); + if (requestedResultLayout.isGroupSlots()) + return fail("dense widen relation does not support group_slots layout"); + if (!requestedResultLayout.isContiguous() && + (!requestedResultLayout.isDeinterleaved() || + requestedResultLayout.getBlockElems() != 1)) + return fail("requires 
contiguous or deinterleaved result layout with " + "block_elems=1"); + + FailureOr fact = + getPreferredCastLayoutFact(sourceType, resultType, reason); + if (failed(fact) || (fact->kind != VMICastLayoutKind::Widen2x && + fact->kind != VMICastLayoutKind::Widen4x)) + return fail("requires supported 8/16-bit to 32-bit widen cast"); + + if (requestedResultLayout.isContiguous()) { + if (!fact->resultLayout.isContiguous() || + fact->resultLayout.getLaneStride() != + requestedResultLayout.getLaneStride()) + return fail("requested contiguous result layout is not the natural " + "compact widen result layout"); + return VMILayoutAttr::getContiguous(sourceType.getContext(), + /*laneStride=*/fact->factor); + } + + int64_t resultFactor = requestedResultLayout.isDeinterleaved() + ? requestedResultLayout.getFactor() + : 1; + if (resultFactor % fact->factor != 0) + return fail("requested result layout factor is not divisible by widen " + "factor"); + + int64_t sourceFactor = resultFactor / fact->factor; + if (sourceFactor == 1) + return VMILayoutAttr::getContiguous(sourceType.getContext()); + if (sourceFactor == 2 || sourceFactor == 4) + return VMILayoutAttr::getDeinterleaved(sourceType.getContext(), + sourceFactor, /*blockElems=*/1); + return fail("derived source layout factor is unsupported"); +} + +FailureOr +VMILayoutSupport::getContiguousLoadSupport(VMIVRegType resultType, + std::string *reason) const { + auto fail = [&](const Twine &message) -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + VMILayoutAttr layout = resultType.getLayoutAttr(); + if (!layout) + return fail("requires assigned result layout"); + if (!layout.isContiguous()) + return fail("requires contiguous result layout"); + if (layout.getLaneStride() == 1) + return VMIContiguousLoadSupport{ + VMIContiguousLoadSupportKind::ContiguousVlds}; + if (layout.getLaneStride() == 2) { + if (!hasDenseLaneStride2UnpackedLoad(resultType.getElementType())) + return fail("requires 8/16/32-bit 
element type for dense lane_stride=2 " + "unpacked load"); + return VMIContiguousLoadSupport{ + VMIContiguousLoadSupportKind::LaneStride2UnpackedVlds}; + } + if (layout.getLaneStride() == 4) { + if (!hasDenseLaneStride4UnpackedLoad(resultType.getElementType())) + return fail("requires 8-bit element type for dense lane_stride=4 " + "unpacked load"); + return VMIContiguousLoadSupport{ + VMIContiguousLoadSupportKind::LaneStride4UnpackedVlds}; + } + return fail("requires lane_stride 1, 2, or 4 for contiguous load"); +} + +FailureOr +VMILayoutSupport::getContiguousStoreSupport(VMIVRegType valueType, + std::string *reason) const { + auto fail = + [&](const Twine &message) -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + VMILayoutAttr layout = valueType.getLayoutAttr(); + if (!layout) + return fail("requires assigned value layout"); + if (layout.isContiguous() && layout.getLaneStride() == 1) + return VMIContiguousStoreSupport{ + VMIContiguousStoreSupportKind::ContiguousVsts}; + if (layout.isContiguous() && layout.getLaneStride() == 2) { + if (!hasDenseLaneStride2PackedStore(valueType.getElementType())) + return fail("requires 8/16/32-bit element type for dense lane_stride=2 " + "packed store"); + return VMIContiguousStoreSupport{ + VMIContiguousStoreSupportKind::LaneStride2PackedVsts}; + } + if (layout.isContiguous() && layout.getLaneStride() == 4) { + if (!hasDenseLaneStride4PackedStore(valueType.getElementType())) + return fail("requires 8-bit element type for dense lane_stride=4 " + "packed store"); + return VMIContiguousStoreSupport{ + VMIContiguousStoreSupportKind::LaneStride4PackedVsts}; + } + if (!layout.isDeinterleaved()) + return fail("requires contiguous or deinterleaved value layout"); + if (layout.getLaneStride() != 1) + return fail("deinterleaved packed store requires lane_stride=1"); + if (layout.getBlockElems() != 1) + return fail("requires block_elems=1 deinterleaved value layout"); + if 
(failed(checkFullDataPhysicalChunks(valueType, reason))) + return failure(); + + if (layout.getFactor() == 2) { + if (!hasX2MemoryDistToken(valueType.getElementType())) + return fail("requires 8/16/32-bit element type for vstsx2 INTLV"); + return VMIContiguousStoreSupport{ + VMIContiguousStoreSupportKind::Deinterleaved2Vstsx2}; + } + + if (layout.getFactor() == 4) + return VMIContiguousStoreSupport{ + VMIContiguousStoreSupportKind::DeinterleavedMaterializeThenVsts}; + + return fail("requires deinterleaved factor 2 or 4"); +} + +LogicalResult VMILayoutSupport::canFoldContiguousStoreMaterialization( + VMIVRegType sourceType, VMIVRegType resultType, std::string *reason) const { + if (sourceType.getElementType() != resultType.getElementType()) + return failWithReason("source/result element types must match", reason); + if (sourceType.getElementCount() != resultType.getElementCount()) + return failWithReason("source/result element counts must match", reason); + + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + if (!resultLayout || !resultLayout.isContiguous()) + return failWithReason("result layout must be contiguous", reason); + + FailureOr support = + getContiguousStoreSupport(sourceType, reason); + if (failed(support)) + return failure(); + if (support->kind == VMIContiguousStoreSupportKind::ContiguousVsts) + return failWithReason("source layout is already contiguous", reason); + + return success(); +} + +LogicalResult VMILayoutSupport::canFoldContiguousMaskedStoreMaterialization( + VMIVRegType sourceType, VMIMaskType maskSourceType, + VMIVRegType resultType, VMIMaskType maskResultType, + std::string *reason) const { + if (sourceType.getElementType() != resultType.getElementType()) + return failWithReason("source/result element types must match", reason); + if (sourceType.getElementCount() != resultType.getElementCount()) + return failWithReason("source/result element counts must match", reason); + if (maskSourceType.getElementCount() != 
sourceType.getElementCount() || + maskResultType.getElementCount() != resultType.getElementCount()) + return failWithReason("value/mask element counts must match", reason); + if (maskSourceType.getGranularity() != maskResultType.getGranularity()) + return failWithReason("mask layout fold cannot change granularity", reason); + + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + VMILayoutAttr maskSourceLayout = maskSourceType.getLayoutAttr(); + VMILayoutAttr maskResultLayout = maskResultType.getLayoutAttr(); + if (!sourceLayout || !resultLayout || !maskSourceLayout || !maskResultLayout) + return failWithReason("requires assigned value/mask layouts", reason); + if (!resultLayout.isContiguous() || !maskResultLayout.isContiguous()) + return failWithReason("result value/mask layouts must be contiguous", + reason); + if (sourceLayout != maskSourceLayout) + return failWithReason("source value/mask layouts must match", reason); + + FailureOr sourceArity = getVMIPhysicalArity(sourceType); + FailureOr maskArity = getVMIPhysicalArity(maskSourceType); + if (failed(sourceArity) || failed(maskArity) || *sourceArity != *maskArity) + return failWithReason("source value/mask physical arity must match", + reason); + + if (!sourceLayout.hasDenseLaneStride()) + return canFoldContiguousStoreMaterialization(sourceType, resultType, + reason); + + std::optional packedGranularity = + getDenseLaneStrideMaskedStoreMaskGranularity(sourceType); + if (!packedGranularity) + return failWithReason("dense lane_stride masked store supports only " + "LS=2 b8/b16 and LS=4 b8 compact masks", + reason); + + unsigned elementBits = pto::getPTOStorageElemBitWidth(sourceType.getElementType()); + StringRef expectedSourceGranularity = + getMaskGranularityForElementBits(elementBits); + if (expectedSourceGranularity.empty() || + maskSourceType.getGranularity() != expectedSourceGranularity) + return failWithReason("mask granularity must match 
source element width", + reason); + + return success(); +} + +FailureOr +VMILayoutSupport::getDataLayoutMaterializationSupport( + VMIVRegType sourceType, VMIVRegType resultType, std::string *reason) const { + auto fail = + [&](const Twine &message) -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + if (sourceType.getElementType() != resultType.getElementType()) + return fail("source/result element types must match"); + if (sourceType.getElementCount() != resultType.getElementCount()) + return fail("source/result element counts must match"); + + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + FailureOr support = + getLayoutMaterializationSupport(sourceLayout, resultLayout, reason); + if (succeeded(support)) { + if (failed(checkLayoutMaterializationShape( + sourceType, resultType, sourceLayout, resultLayout, reason))) + return failure(); + return support; + } + + if (!sourceLayout || !resultLayout) + return fail("requires assigned source/result layouts"); + + if (sourceLayout.isContiguous() && sourceLayout.getLaneStride() == 1 && + resultLayout.isContiguous() && resultLayout.getLaneStride() != 1) { + FailureOr sourceArity = getVMIPhysicalArity(sourceType); + FailureOr resultArity = getVMIPhysicalArity(resultType); + if (failed(sourceArity) || failed(resultArity)) + return fail("requires computable source/result physical arity"); + if (*sourceArity != *resultArity) + return fail("dense lane_stride register materialization currently " + "requires source and result to have the same physical arity"); + if (!hasDenseLaneStridePackUnpackElement(sourceType.getElementType(), + resultLayout.getLaneStride())) + return fail("requires bitcastable 8/16-bit element type for dense " + "lane_stride register unpack materialization"); + return VMILayoutMaterializationSupport{ + VMILayoutMaterializationSupportKind::ContiguousToLaneStrideViaUnpack}; + } + + if 
(sourceLayout.isContiguous() && sourceLayout.getLaneStride() != 1 && + resultLayout.isContiguous() && resultLayout.getLaneStride() == 1) { + FailureOr sourceArity = getVMIPhysicalArity(sourceType); + FailureOr resultArity = getVMIPhysicalArity(resultType); + if (failed(sourceArity) || failed(resultArity)) + return fail("requires computable source/result physical arity"); + if (*sourceArity != *resultArity) + return fail("dense lane_stride register materialization currently " + "requires source and result to have the same physical arity"); + if (!hasDenseLaneStridePackUnpackElement(sourceType.getElementType(), + sourceLayout.getLaneStride())) + return fail("requires bitcastable 8/16-bit element type for dense " + "lane_stride register pack materialization"); + return VMILayoutMaterializationSupport{ + VMILayoutMaterializationSupportKind::LaneStrideToContiguousViaPack}; + } + + return failure(); +} + +LogicalResult VMILayoutSupport::canMaterializeDataLayout( + VMIVRegType sourceType, VMIVRegType resultType, std::string *reason) const { + if (failed( + getDataLayoutMaterializationSupport(sourceType, resultType, reason))) + return failure(); + return success(); +} + +FailureOr +VMILayoutSupport::getMaskLayoutMaterializationSupport( + VMIMaskType sourceType, VMIMaskType resultType, std::string *reason) const { + auto fail = + [&](const Twine &message) -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + if (sourceType.getElementCount() != resultType.getElementCount()) + return fail("source/result mask element counts must match"); + if (sourceType.getGranularity() != resultType.getGranularity()) + return fail("source/result mask granularities must match"); + + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + FailureOr support = + getLayoutMaterializationSupport(sourceLayout, resultLayout, reason); + if (failed(support)) + return failure(); + if 
(failed(checkLayoutMaterializationShape( + sourceType, resultType, sourceLayout, resultLayout, reason))) + return failure(); + return support; +} + +LogicalResult VMILayoutSupport::canMaterializeMaskLayout( + VMIMaskType sourceType, VMIMaskType resultType, std::string *reason) const { + if (failed( + getMaskLayoutMaterializationSupport(sourceType, resultType, reason))) + return failure(); + return success(); +} + +FailureOr +VMILayoutSupport::getMaskGranularityMaterializationSupport( + VMIMaskType sourceType, VMIMaskType resultType, std::string *reason) const { + auto fail = [&](const Twine &message) + -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + if (sourceType.getElementCount() != resultType.getElementCount()) + return fail("source/result mask element counts must match"); + if (sourceType.getLayoutAttr() != resultType.getLayoutAttr()) + return fail("source/result mask layouts must match"); + if (!VMIMaskType::isConcreteGranularity(sourceType.getGranularity()) || + !VMIMaskType::isConcreteGranularity(resultType.getGranularity())) + return fail("requires concrete b8/b16/b32 source and result granularities"); + if (sourceType.getGranularity() == resultType.getGranularity()) + return VMIMaskGranularityMaterializationSupport{ + VMIMaskGranularityMaterializationSupportKind::Identity}; + + return VMIMaskGranularityMaterializationSupport{ + VMIMaskGranularityMaterializationSupportKind::PredicateCast}; +} + +LogicalResult VMILayoutSupport::canMaterializeMaskGranularity( + VMIMaskType sourceType, VMIMaskType resultType, std::string *reason) const { + if (failed(getMaskGranularityMaterializationSupport(sourceType, resultType, + reason))) + return failure(); + return success(); +} + +FailureOr VMILayoutSupport::getGroupSlotLoadSupport( + const VMITargetCapabilityRegistry &capabilities, VMIGroupSlotLoadOp op, + std::string *reason) const { + auto fail = [&](const Twine &message) -> FailureOr { + if (reason) + *reason = message.str(); + 
return failure(); + }; + + auto resultType = cast(op.getResult().getType()); + int64_t numGroups = op.getNumGroupsAttr().getInt(); + if (resultType.getElementCount() != numGroups) + return fail("requires result logical lane count to match num_groups"); + VMILayoutAttr layout = resultType.getLayoutAttr(); + if (!layout || !layout.isGroupSlots() || + layout.getNumGroups() != numGroups || layout.getSlots() <= 0) + return fail("requires explicit group_slots result layout matching " + "num_groups"); + + if (layout.getSlots() != 8 && layout.getSlots() != 1) + return fail("supports only slots=8 or slots=1 group_slot_load layouts"); + + if (!capabilities.supportsDirectMemory(op.getSource().getType(), "source") + .isSupported()) + return fail("requires supported direct memory source"); + if (!isa(op.getSource().getType())) + return fail("requires !pto.ptr source for vsldb lowering"); + + std::optional stride = + getConstantIndexValue(op.getSourceGroupStride()); + if (layout.getSlots() == 8) { + if (!stride || *stride != 1) + return fail("slots=8 group_slot_load requires constant unit " + "source_group_stride"); + return VMIGroupSlotLoadSupport{ + VMIGroupSlotLoadSupportKind::Slots8UnitStrideVsldb}; + } + + unsigned elementBits = + pto::getPTOStorageElemBitWidth(resultType.getElementType()); + if (elementBits == 0 || 256 % elementBits != 0) + return fail("slots=1 group_slot_load requires an 8/16/32-bit element " + "type"); + int64_t alignedStrideElems = 256 / elementBits; + if (!stride || *stride <= 0 || *stride % alignedStrideElems != 0) + return fail(Twine("slots=1 group_slot_load currently lowers as one " + "lane-0 vsldb per group and requires constant " + "positive source_group_stride divisible by ") + + Twine(alignedStrideElems) + + " elements for 32B load alignment; packed or unaligned " + "scalar load lowering is not implemented"); + + return VMIGroupSlotLoadSupport{ + VMIGroupSlotLoadSupportKind::Slots1AlignedLane0Vsldb}; +} + +FailureOr 
VMILayoutSupport::getGroupLoadSupport( + const VMITargetCapabilityRegistry &capabilities, VMIGroupLoadOp op, + std::string *reason) const { + auto fail = [&](const Twine &message) -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto resultType = cast(op.getResult().getType()); + VMILayoutAttr layout = resultType.getLayoutAttr(); + if (!layout || !layout.isDeinterleaved() || layout.getBlockElems() != 8 || + !resultType.getElementType().isF32()) + return fail("requires deinterleaved block8 f32 result layout"); + + FailureOr groupSize = getGroupSizeFromNumGroups( + resultType, op.getNumGroupsAttr().getInt(), reason); + if (failed(groupSize)) + return failure(); + + if ((*groupSize != 16 || layout.getFactor() != 2) && + (*groupSize != 32 || layout.getFactor() != 4)) + return fail("block8 strided group_load requires S=16/factor=2 or " + "S=32/factor=4"); + + if (!capabilities.supportsDirectMemory(op.getSource().getType(), "source") + .isSupported()) + return fail("requires supported direct memory source"); + if (!isa(op.getSource().getType())) + return fail("block8 strided group_load requires !pto.ptr source"); + + if (op.getNumGroupsAttr().getInt() % 8 != 0) + return fail("block8 strided group_load requires num_groups multiple of 8"); + + std::optional rowStride = getConstantIndexValue(op.getRowStride()); + if (!rowStride || *rowStride <= 0 || *rowStride % 8 != 0) + return fail("block8 strided group_load requires constant positive " + "row_stride divisible by 8 f32 elements"); + + std::string fullChunkReason; + if (failed(checkFullDataPhysicalChunks(resultType, &fullChunkReason))) + return fail(Twine("block8 strided group_load requires full physical " + "result chunks; ") + + fullChunkReason); + + if (*groupSize == 16) + return VMIGroupLoadSupport{VMIGroupLoadSupportKind::S16Block8Vsldb}; + return VMIGroupLoadSupport{VMIGroupLoadSupportKind::S32Block8Vsldb}; +} + +FailureOr +VMILayoutSupport::getGroupSlotsStoreSupport( + const 
VMITargetCapabilityRegistry &capabilities, VMIGroupStoreOp op, + std::string *reason) const { + auto fail = + [&](const Twine &message) -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto valueType = cast(op.getValue().getType()); + VMILayoutAttr layout = valueType.getLayoutAttr(); + if (!layout || !layout.isGroupSlots()) + return fail("requires group_slots value layout"); + + int64_t numGroups = op.getNumGroupsAttr().getInt(); + if (layout.getNumGroups() != numGroups) + return fail("group_slots group_store requires layout num_groups to " + "match op num_groups"); + + VMICapabilityResult elementCapability = capabilities.supportsElementType( + valueType.getElementType(), VMIElementPurpose::PredicateMask); + if (!elementCapability.isSupported()) + return fail(elementCapability.reason); + + FailureOr arity = getVMIPhysicalArity(valueType); + if (failed(arity) || *arity < 1) + return fail("requires computable non-empty physical vreg parts"); + + if (layout.getSlots() == 1) { + if (*arity != numGroups) + return fail("slots=1 group_store requires one physical part per " + "group"); + unsigned elementBits = + pto::getPTOStorageElemBitWidth(valueType.getElementType()); + if (elementBits == 0 || 256 % elementBits != 0) + return fail("slots=1 group_store requires an 8/16/32-bit element " + "type"); + std::optional rowStride = getConstantIndexValue(op.getRowStride()); + FailureOr lanesPerPart = + getDataLanesPerPart(valueType.getElementType()); + if (rowStride && *rowStride == 1 && succeeded(lanesPerPart) && + numGroups <= *lanesPerPart) + return VMIGroupSlotsStoreSupport{ + VMIGroupSlotsStoreSupportKind::Slots1PackedUnitStrideVsts}; + if (rowStride && *rowStride <= 0) + return fail("slots=1 group_store requires positive row_stride when " + "row_stride is constant"); + return VMIGroupSlotsStoreSupport{ + VMIGroupSlotsStoreSupportKind::Slots1PointVsts}; + } + + if (layout.getSlots() == 8) { + std::optional rowStride = 
getConstantIndexValue(op.getRowStride()); + if (!rowStride || *rowStride != 1) + return fail("slots=8 group_store currently requires constant unit " + "row_stride"); + if (*arity != ceilDivNonNegative(numGroups, 8)) + return fail("slots=8 group_store arity must equal ceil(num_groups / " + "8)"); + return VMIGroupSlotsStoreSupport{ + VMIGroupSlotsStoreSupportKind::Slots8UnitStrideVsts}; + } + + return fail("group_slots group_store currently supports only slots=1 or " + "unit-stride slots=8"); +} + +FailureOr getGroupReduceAddSupportImpl( + const VMITargetCapabilityRegistry &capabilities, Operation *op, + VMIVRegType sourceType, VMIMaskType maskType, VMIVRegType resultType, + int64_t numGroups, bool requiresReassoc, VMIReductionKind reductionKind, + std::string *reason) { + auto fail = + [&](const Twine &message) -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + if (requiresReassoc && !op->hasAttr("reassoc")) + return fail("requires reassoc attr for pair-wise floating-point " + "reduction"); + + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr maskLayout = maskType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + if (!sourceLayout || !maskLayout || !resultLayout) + return fail("requires assigned source, mask, and result layouts"); + if (!resultLayout.isGroupSlots() || resultLayout.getNumGroups() != numGroups) + return fail("requires group_slots result layout matching num_groups"); + if (resultLayout.getSlots() != 8 && resultLayout.getSlots() != 1) { + FailureOr groupSize = + getGroupSizeFromNumGroups(sourceType, numGroups, reason); + FailureOr lanesPerPart = + getDataLanesPerPart(sourceType.getElementType()); + int64_t vlaneElems = succeeded(lanesPerPart) && *lanesPerPart % 8 == 0 + ? 
*lanesPerPart / 8 + : -1; + if (succeeded(groupSize) && resultLayout.getSlots() <= 0 && + (*groupSize != vlaneElems && *groupSize != 2 * vlaneElems && + *groupSize != 4 * vlaneElems)) + return fail("stable group_reduce_add slots=8 support group " + "sizes VLaneElems, 2*VLaneElems, or 4*VLaneElems"); + return fail("stable group_reduce_add layout support currently requires " + "result layout slots=8 or slots=1"); + } + + VMICapabilityResult elementCapability = + capabilities.supportsReductionElementType(reductionKind, + sourceType.getElementType()); + if (!elementCapability.isSupported()) + return fail(elementCapability.reason); + if (sourceType.getElementType() != resultType.getElementType()) + return fail("stable group_reduce_add layout support requires matching " + "source/result element types"); + if (resultType.getElementCount() != numGroups) + return fail("requires result lane count to match num_groups"); + + FailureOr groupSize = + getGroupSizeFromNumGroups(sourceType, numGroups, reason); + if (failed(groupSize)) + return failure(); + FailureOr lanesPerPart = + getDataLanesPerPart(sourceType.getElementType()); + if (failed(lanesPerPart) || *lanesPerPart % 8 != 0) + return fail("requires element type with known physical VLane width"); + int64_t vlaneElems = *lanesPerPart / 8; + + FailureOr sourceArity = getVMIPhysicalArity(sourceType); + FailureOr maskArity = getVMIPhysicalArity(maskType); + FailureOr resultArity = getVMIPhysicalArity(resultType); + if (failed(sourceArity) || failed(maskArity) || failed(resultArity)) + return fail("requires computable source/mask/result physical arity"); + if (*sourceArity < 1 || *maskArity != *sourceArity) + return fail("requires matching non-empty source/mask physical arity"); + + if (resultLayout.getSlots() == 1) { + if (failed(lanesPerPart) || *groupSize < *lanesPerPart || + *groupSize % *lanesPerPart != 0) + return fail("stable group_reduce_add slots=1 support group " + "sizes that are multiples of one physical chunk"); + 
if (!sourceLayout.isContiguous() || !maskLayout.isContiguous()) + return fail("slots=1 group_reduce_add requires contiguous source/mask " + "layouts"); + if (*resultArity != numGroups) + return fail("slots=1 group_reduce_add requires one physical result " + "part per group"); + std::string sourceFullReason; + if (failed(checkFullDataPhysicalChunks(sourceType, &sourceFullReason))) + return fail(Twine("slots=1 group_reduce_add requires full source " + "chunks; ") + + sourceFullReason); + return VMIGroupReduceAddFSupport{ + VMIGroupReduceAddFSupportKind::ContiguousVcaddRows}; + } + + if (*groupSize == vlaneElems) { + if (!sourceLayout.isContiguous() || !maskLayout.isContiguous()) + return fail("one-vlane group_reduce_add requires contiguous source/mask " + "layouts"); + std::string sourceFullReason; + if (failed(checkFullDataPhysicalChunks(sourceType, &sourceFullReason))) + return fail(Twine("one-vlane group_reduce_add requires full source " + "chunks; ") + + sourceFullReason); + if (*resultArity != *sourceArity) + return fail("one-vlane group_reduce_add requires source/result physical " + "arity to match"); + return VMIGroupReduceAddFSupport{ + VMIGroupReduceAddFSupportKind::OneVLaneVcgadd}; + } + + if (*groupSize == 2 * vlaneElems) { + if (!sourceLayout.isDeinterleaved() || sourceLayout.getFactor() != 2 || + (sourceLayout.getBlockElems() != 1 && + sourceLayout.getBlockElems() != 8)) + return fail("two-vlane group_reduce_add requires source layout " + "deinterleaved=2 with block_elems=1 or block_elems=8"); + if (!maskLayout.isDeinterleaved() || maskLayout.getFactor() != 2 || + maskLayout.getBlockElems() != sourceLayout.getBlockElems()) + return fail("two-vlane group_reduce_add requires matching mask layout " + "deinterleaved=2 with the same block_elems"); + int64_t expectedResultArity = ceilDivNonNegative(numGroups, 8); + if (*resultArity != expectedResultArity || *sourceArity != *resultArity * 2) + return fail( + "two-vlane group_reduce_add requires two source/mask 
parts per " + "result part"); + return VMIGroupReduceAddFSupport{ + VMIGroupReduceAddFSupportKind::TwoVLaneDeinterleaved2VcgaddVadd}; + } + + if (*groupSize == 4 * vlaneElems) { + if (!sourceLayout.isDeinterleaved() || sourceLayout.getFactor() != 4 || + (sourceLayout.getBlockElems() != 1 && + sourceLayout.getBlockElems() != 8)) + return fail("four-vlane group_reduce_add requires source layout " + "deinterleaved=4 with block_elems=1 or block_elems=8"); + if (!maskLayout.isDeinterleaved() || maskLayout.getFactor() != 4 || + maskLayout.getBlockElems() != sourceLayout.getBlockElems()) + return fail("four-vlane group_reduce_add requires matching mask layout " + "deinterleaved=4 with the same block_elems"); + int64_t expectedResultArity = ceilDivNonNegative(numGroups, 8); + if (*resultArity != expectedResultArity || *sourceArity != *resultArity * 4) + return fail( + "four-vlane group_reduce_add requires four source/mask parts per " + "result part"); + return VMIGroupReduceAddFSupport{ + VMIGroupReduceAddFSupportKind::FourVLaneDeinterleaved4VcgaddTree}; + } + + return fail("stable group_reduce_add slots=8 support group sizes " + "VLaneElems, 2*VLaneElems, or 4*VLaneElems"); +} + +FailureOr +VMILayoutSupport::getGroupReduceAddFSupport( + const VMITargetCapabilityRegistry &capabilities, VMIGroupReduceAddFOp op, + std::string *reason) const { + return getGroupReduceAddSupportImpl( + capabilities, op.getOperation(), + cast(op.getSource().getType()), + cast(op.getMask().getType()), + cast(op.getResult().getType()), + op.getNumGroupsAttr().getInt(), /*requiresReassoc=*/true, + VMIReductionKind::GroupAddF, reason); +} + +FailureOr +VMILayoutSupport::getGroupReduceMaxFSupport( + const VMITargetCapabilityRegistry &capabilities, VMIGroupReduceMaxFOp op, + std::string *reason) const { + return getGroupReduceAddSupportImpl( + capabilities, op.getOperation(), + cast(op.getSource().getType()), + cast(op.getMask().getType()), + cast(op.getResult().getType()), + 
op.getNumGroupsAttr().getInt(), /*requiresReassoc=*/false, + VMIReductionKind::GroupMaxF, reason); +} + +FailureOr +VMILayoutSupport::getGroupReduceAddISupport( + const VMITargetCapabilityRegistry &capabilities, VMIGroupReduceAddIOp op, + std::string *reason) const { + return getGroupReduceAddSupportImpl( + capabilities, op.getOperation(), + cast(op.getSource().getType()), + cast(op.getMask().getType()), + cast(op.getResult().getType()), + op.getNumGroupsAttr().getInt(), /*requiresReassoc=*/false, + VMIReductionKind::GroupAddI, reason); +} + +FailureOr +VMILayoutSupport::getGroupReduceMaxISupport( + const VMITargetCapabilityRegistry &capabilities, VMIGroupReduceMaxIOp op, + std::string *reason) const { + return getGroupReduceAddSupportImpl( + capabilities, op.getOperation(), + cast(op.getSource().getType()), + cast(op.getMask().getType()), + cast(op.getResult().getType()), + op.getNumGroupsAttr().getInt(), /*requiresReassoc=*/false, + VMIReductionKind::GroupMaxI, reason); +} + +FailureOr VMILayoutSupport::getGroupBroadcastSupport( + const VMITargetCapabilityRegistry &capabilities, VMIGroupBroadcastOp op, + std::string *reason) const { + return getGroupBroadcastSupport(capabilities, + cast(op.getSource().getType()), + cast(op.getResult().getType()), + op.getNumGroupsAttr().getInt(), reason); +} + +FailureOr +VMILayoutSupport::getGroupBroadcastLoadSupport( + const VMITargetCapabilityRegistry &capabilities, + VMIGroupBroadcastLoadOp op, std::string *reason) const { + auto fail = + [&](const Twine &message) -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto resultType = cast(op.getResult().getType()); + int64_t numGroups = op.getNumGroupsAttr().getInt(); + if (numGroups <= 0) + return fail("requires positive num_groups"); + if (resultType.getElementCount() % numGroups != 0) + return fail("requires num_groups to evenly divide result lane count"); + if (!capabilities.supportsDirectMemory(op.getSource().getType(), "source") + 
.isSupported()) + return fail("requires supported direct memory source"); + if (!isa(op.getSource().getType())) + return fail("requires !pto.ptr source"); + + unsigned elementBits = + pto::getPTOStorageElemBitWidth(resultType.getElementType()); + VMILayoutAttr layout = resultType.getLayoutAttr(); + if (!layout) + return fail("requires assigned result layout"); + + std::string fullChunkReason; + if (failed(checkFullDataPhysicalChunks(resultType, &fullChunkReason))) + return fail(Twine("requires full result physical chunks; ") + + fullChunkReason); + + FailureOr lanesPerPart = + getDataLanesPerPart(resultType.getElementType()); + if (failed(lanesPerPart)) + return fail("requires known result lanes per physical part"); + + int64_t groupSize = resultType.getElementCount() / numGroups; + std::optional stride = + getConstantIndexValue(op.getSourceGroupStride()); + + bool contiguousPacketLayout = layout.isContiguous(); + bool splitPacketLayout = layout.isDeinterleaved() && layout.getFactor() == 2 && + layout.getBlockElems() == 1; + if ((elementBits == 16 || elementBits == 32) && + *lanesPerPart == static_cast(2048 / elementBits) && + (contiguousPacketLayout || splitPacketLayout) && numGroups % 8 == 0 && + stride && *stride == 1) { + int64_t directGroupSize = 256 / elementBits; + if ((contiguousPacketLayout && groupSize == directGroupSize) || + (splitPacketLayout && groupSize == 2 * directGroupSize)) + return VMIGroupBroadcastLoadSupport{ + VMIGroupBroadcastLoadSupportKind::E2BVlds}; + } + + if (elementBits == 0 || 256 % elementBits != 0) + return fail("fallback lowering requires an 8/16/32-bit element type"); + int64_t alignedStrideElems = 256 / elementBits; + int64_t slots = 0; + if (stride && *stride == 1) + slots = 8; + else if (stride && *stride > 0 && *stride % alignedStrideElems == 0) + slots = 1; + else + return fail(Twine("fallback lowering requires constant unit " + "source_group_stride for packed slots or constant " + "positive source_group_stride divisible by 
") + + Twine(alignedStrideElems) + " elements for lane-0 slots"); + + auto sourceType = VMIVRegType::get( + resultType.getContext(), numGroups, resultType.getElementType(), + VMILayoutAttr::getGroupSlots(resultType.getContext(), numGroups, slots)); + std::string broadcastReason; + if (failed(getGroupBroadcastSupport(capabilities, sourceType, resultType, + numGroups, &broadcastReason))) + return fail(Twine("fallback broadcast is unsupported; ") + + broadcastReason); + return VMIGroupBroadcastLoadSupport{ + VMIGroupBroadcastLoadSupportKind::SlotLoadThenBroadcast}; +} + +FailureOr VMILayoutSupport::getGroupBroadcastSupport( + const VMITargetCapabilityRegistry &capabilities, VMIVRegType sourceType, + VMIVRegType resultType, int64_t numGroups, std::string *reason) const { + (void)capabilities; + auto fail = [&](const Twine &message) -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + if (sourceType.getElementType() != resultType.getElementType()) + return fail("requires source/result element type to match"); + if (sourceType.getElementCount() != numGroups) + return fail("requires source lane count to match num_groups"); + if (resultType.getElementCount() % numGroups != 0) + return fail("requires num_groups to evenly divide result lane count"); + + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + if (!sourceLayout || !resultLayout) + return fail("requires assigned source/result layouts"); + if (!sourceLayout.isGroupSlots() || sourceLayout.getNumGroups() != numGroups) + return fail("requires matching num_groups source layout"); + if (resultLayout.isGroupSlots()) + return fail("requires dense result layout"); + if (sourceLayout.getSlots() > 0 && sourceLayout.getSlots() != 8 && + sourceLayout.getSlots() != 1) + return fail("supports only slots=8 or slots=1 group_broadcast source " + "layouts"); + + std::string fullChunkReason; + if 
(failed(checkFullDataPhysicalChunks(resultType, &fullChunkReason))) + return fail(Twine("requires full result physical chunks; ") + + fullChunkReason); + + FailureOr lanesPerPart = + getDataLanesPerPart(sourceType.getElementType()); + FailureOr resultLanesPerPart = + getDataLanesPerPart(resultType.getElementType()); + if (failed(lanesPerPart) || failed(resultLanesPerPart) || + *lanesPerPart != *resultLanesPerPart) + return fail("requires matching physical lanes per part"); + + FailureOr groupSize = + getGroupSizeFromNumGroups(resultType, numGroups, reason); + if (failed(groupSize)) + return failure(); + if (*lanesPerPart % *groupSize != 0 && *groupSize % *lanesPerPart != 0) + return fail("requires derived group size to divide or be a multiple of " + "physical lanes per part"); + + FailureOr resultFactor = getDataLayoutFactor(resultType); + if (failed(resultFactor)) + return fail("requires known result layout factor"); + if (*resultFactor == 1) + return VMIGroupBroadcastSupport{ + VMIGroupBroadcastSupportKind::GroupSlotsVselr}; + + bool blockFragmentSmallGroup = + resultLayout.isDeinterleaved() && resultLayout.getBlockElems() > 1 && + *groupSize < *lanesPerPart && + *lanesPerPart % resultLayout.getBlockElems() == 0; + if (blockFragmentSmallGroup) + return VMIGroupBroadcastSupport{ + VMIGroupBroadcastSupportKind::GroupSlotsVselr}; + + bool deinterleavedSmallGroup = + resultLayout.isDeinterleaved() && resultLayout.getBlockElems() == 1 && + *groupSize < *lanesPerPart && *groupSize >= *resultFactor && + *groupSize % *resultFactor == 0 && + *lanesPerPart % (*groupSize / *resultFactor) == 0; + if (deinterleavedSmallGroup) + return VMIGroupBroadcastSupport{ + VMIGroupBroadcastSupportKind::GroupSlotsVselr}; + + int64_t logicalSpanPerResultChunk = *lanesPerPart * *resultFactor; + if (*groupSize < *lanesPerPart || *groupSize % logicalSpanPerResultChunk != 0) + return fail("deinterleaved result requires every physical result chunk to " + "stay within one logical group"); + + 
return VMIGroupBroadcastSupport{ + VMIGroupBroadcastSupportKind::GroupSlotsVselr}; +} + +FailureOr +VMILayoutSupport::getTruncFSupport(VMITruncFOp op, std::string *reason) const { + auto fail = [&](const Twine &message) -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto sourceType = cast(op.getSource().getType()); + auto resultType = cast(op.getResult().getType()); + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + FailureOr sourceArity = getVMIPhysicalArity(sourceType); + FailureOr resultArity = getVMIPhysicalArity(resultType); + if (!sourceLayout || !resultLayout || failed(sourceArity) || + failed(resultArity)) + return fail("requires assigned source/result layouts and computable " + "physical arity"); + + if (sourceLayout.isGroupSlots() || resultLayout.isGroupSlots()) { + unsigned resultBits = + pto::getPTOStorageElemBitWidth(resultType.getElementType()); + if (!sourceLayout.isGroupSlots() || !resultLayout.isGroupSlots() || + sourceLayout.getNumGroups() != resultLayout.getNumGroups() || + sourceLayout.getSlots() != 1 || resultLayout.getSlots() != 1 || + !sourceType.getElementType().isF32() || resultBits != 16 || + *sourceArity != *resultArity) + return fail("group-slot truncf requires matching " + "group_slots(num_groups=G, slots=1) source/result layouts, " + "f32 source, f16 result, and matching physical arity"); + return VMITruncFSupport{VMITruncFSupportKind::GroupSlots1F32ToF16}; + } + + if (!sourceType.getElementType().isF32()) + return fail("requires f32 source"); + + FailureOr fact = + getPreferredCastLayoutFact(sourceType, resultType, reason); + if (failed(fact) || (fact->kind != VMICastLayoutKind::Narrow2x && + fact->kind != VMICastLayoutKind::Narrow4x)) + return fail("unsupported deinterleaved truncf factor, arity, or result " + "element width"); + + int64_t sourceFactor = + sourceLayout.isDeinterleaved() ? 
sourceLayout.getFactor() : 1; + if (((sourceLayout.isContiguous() && sourceLayout.getLaneStride() == 1) || + (sourceLayout.isDeinterleaved() && sourceLayout.getBlockElems() == 1 && + sourceLayout.getLaneStride() == 1)) && + resultLayout.isContiguous() && resultLayout.getLaneStride() > 0 && + sourceFactor * resultLayout.getLaneStride() == fact->factor && + *sourceArity == sourceFactor * *resultArity) { + if (fact->kind == VMICastLayoutKind::Narrow2x) + return VMITruncFSupport{ + resultLayout.getLaneStride() == 1 + ? VMITruncFSupportKind::Deinterleaved2F32ToContiguousF16 + : VMITruncFSupportKind::ContiguousF32ToLaneStrideF16}; + if (fact->kind == VMICastLayoutKind::Narrow4x) + return VMITruncFSupport{ + resultLayout.getLaneStride() == 1 + ? VMITruncFSupportKind::Deinterleaved4F32ToContiguousF8 + : VMITruncFSupportKind::ContiguousF32ToLaneStrideF8}; + } + + if (!sourceLayout.isDeinterleaved() || !resultLayout.isContiguous() || + resultLayout.getLaneStride() != 1) + return fail("requires f32 deinterleaved source and contiguous result, or " + "contiguous source and lane_stride narrowing result"); + + if (fact->kind == VMICastLayoutKind::Narrow2x && + sourceLayout.getFactor() == fact->factor && + *sourceArity == fact->factor * *resultArity) + return VMITruncFSupport{ + VMITruncFSupportKind::Deinterleaved2F32ToContiguousF16}; + if (fact->kind == VMICastLayoutKind::Narrow4x && + sourceLayout.getFactor() == fact->factor && + *sourceArity == fact->factor * *resultArity) + return VMITruncFSupport{ + VMITruncFSupportKind::Deinterleaved4F32ToContiguousF8}; + + return fail("unsupported deinterleaved truncf factor, arity, or result " + "element width"); +} + +FailureOr +VMILayoutSupport::getExtFSupport(VMIExtFOp op, std::string *reason) const { + auto fail = [&](const Twine &message) -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto sourceType = cast(op.getSource().getType()); + auto resultType = cast(op.getResult().getType()); + 
VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + FailureOr sourceArity = getVMIPhysicalArity(sourceType); + FailureOr resultArity = getVMIPhysicalArity(resultType); + if (!sourceLayout || !resultLayout || failed(sourceArity) || + failed(resultArity)) + return fail("requires assigned source/result layouts and computable " + "physical arity"); + + FailureOr fact = + getPreferredCastLayoutFact(sourceType, resultType, reason); + if (failed(fact) || (fact->kind != VMICastLayoutKind::Widen2x && + fact->kind != VMICastLayoutKind::Widen4x)) + return fail("unsupported extf source element width, result factor, or " + "physical arity"); + + if (sourceLayout.isContiguous() && + sourceLayout.getLaneStride() == fact->factor && + resultLayout.isContiguous() && resultLayout.getLaneStride() == 1 && + *sourceArity == *resultArity && resultType.getElementType().isF32()) { + if (fact->kind == VMICastLayoutKind::Widen2x) + return VMIExtFSupport{ + VMIExtFSupportKind::ContiguousF16ToDeinterleaved2F32}; + if (fact->kind == VMICastLayoutKind::Widen4x) + return VMIExtFSupport{ + VMIExtFSupportKind::ContiguousF8ToDeinterleaved4F32}; + } + + if (!resultLayout.isDeinterleaved() || resultLayout.getBlockElems() != 1 || + resultLayout.getLaneStride() != 1 || + !(sourceLayout.isContiguous() || + (sourceLayout.isDeinterleaved() && + sourceLayout.getBlockElems() == 1 && + sourceLayout.getLaneStride() == 1)) || + !resultType.getElementType().isF32()) + return fail("requires contiguous or deinterleaved source layout and " + "deinterleaved f32 result layout with block_elems=1"); + + int64_t sourceFactor = + sourceLayout.isDeinterleaved() ? 
sourceLayout.getFactor() : 1; + if (resultLayout.getFactor() != sourceFactor * fact->factor || + *resultArity != fact->factor * *sourceArity) + return fail("unsupported extf source/result layout factor or physical " + "arity"); + + if (fact->kind == VMICastLayoutKind::Widen2x) + return VMIExtFSupport{VMIExtFSupportKind::ContiguousF16ToDeinterleaved2F32}; + if (fact->kind == VMICastLayoutKind::Widen4x) + return VMIExtFSupport{VMIExtFSupportKind::ContiguousF8ToDeinterleaved4F32}; + + return fail("unsupported extf source element width"); +} + +template +static FailureOr getExtISupportImpl(OpT op, + std::string *reason) { + auto fail = [&](const Twine &message) -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto sourceType = cast(op.getSource().getType()); + auto resultType = cast(op.getResult().getType()); + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + FailureOr sourceArity = getVMIPhysicalArity(sourceType); + FailureOr resultArity = getVMIPhysicalArity(resultType); + if (!sourceLayout || !resultLayout || failed(sourceArity) || + failed(resultArity)) + return fail("requires assigned source/result layouts and computable " + "physical arity"); + + FailureOr fact = + VMILayoutSupport().getPreferredCastLayoutFact(sourceType, resultType, + reason); + + if (sourceLayout.isGroupSlots() && resultLayout.isGroupSlots()) { + if (!isa(sourceType.getElementType()) || + !isa(resultType.getElementType())) + return fail("requires integer source/result element types"); + if (sourceLayout.getNumGroups() != resultLayout.getNumGroups() || + sourceLayout.getSlots() != 8 || resultLayout.getSlots() != 8) + return fail("requires matching group_slots(num_groups=G, slots=8) " + "source/result layouts"); + if (*sourceArity != *resultArity) + return fail("group_slots integer extension requires matching physical " + "arity"); + + unsigned sourceBits = pto::getPTOStorageElemBitWidth( + 
sourceType.getElementType()); + unsigned resultBits = pto::getPTOStorageElemBitWidth( + resultType.getElementType()); + if (resultBits != 32) + return fail("group_slots integer extension requires 32-bit result " + "element type"); + if (sourceBits == 16) + return VMIExtISupport{VMIExtISupportKind::GroupSlotsI16ToI32}; + if (sourceBits == 8) + return VMIExtISupport{VMIExtISupportKind::GroupSlotsI8ToI32}; + return fail("group_slots integer extension source must be 8-bit or " + "16-bit"); + } + + if (succeeded(fact) && + (fact->kind == VMICastLayoutKind::Widen2x || + fact->kind == VMICastLayoutKind::Widen4x) && + sourceLayout.isContiguous() && + sourceLayout.getLaneStride() == fact->factor && + resultLayout.isContiguous() && resultLayout.getLaneStride() == 1 && + *sourceArity == *resultArity && + isa(sourceType.getElementType()) && + isa(resultType.getElementType())) { + if (fact->kind == VMICastLayoutKind::Widen2x) + return VMIExtISupport{ + VMIExtISupportKind::ContiguousI16ToDeinterleaved2I32}; + if (fact->kind == VMICastLayoutKind::Widen4x) + return VMIExtISupport{ + VMIExtISupportKind::ContiguousI8ToDeinterleaved4I32}; + } + + if (!resultLayout.isDeinterleaved() || resultLayout.getBlockElems() != 1 || + resultLayout.getLaneStride() != 1 || + !(sourceLayout.isContiguous() || + (sourceLayout.isDeinterleaved() && + sourceLayout.getBlockElems() == 1 && + sourceLayout.getLaneStride() == 1)) || + !isa(sourceType.getElementType()) || + !isa(resultType.getElementType())) + return fail("requires contiguous or deinterleaved integer source layout " + "and deinterleaved integer result layout with block_elems=1"); + + if (failed(fact) || (fact->kind != VMICastLayoutKind::Widen2x && + fact->kind != VMICastLayoutKind::Widen4x)) + return fail("unsupported integer extension source/result element width, " + "result factor, or physical arity"); + + int64_t sourceFactor = + sourceLayout.isDeinterleaved() ? 
sourceLayout.getFactor() : 1; + if (resultLayout.getFactor() != sourceFactor * fact->factor || + *resultArity != fact->factor * *sourceArity) + return fail("unsupported integer extension source/result layout factor or " + "physical arity"); + + if (fact->kind == VMICastLayoutKind::Widen2x) + return VMIExtISupport{VMIExtISupportKind::ContiguousI16ToDeinterleaved2I32}; + if (fact->kind == VMICastLayoutKind::Widen4x) + return VMIExtISupport{VMIExtISupportKind::ContiguousI8ToDeinterleaved4I32}; + + return fail("unsupported integer extension source/result element width"); +} + +FailureOr +VMILayoutSupport::getExtSISupport(VMIExtSIOp op, std::string *reason) const { + return getExtISupportImpl(op, reason); +} + +FailureOr +VMILayoutSupport::getExtUISupport(VMIExtUIOp op, std::string *reason) const { + return getExtISupportImpl(op, reason); +} + +FailureOr +VMILayoutSupport::getTruncISupport(VMITruncIOp op, std::string *reason) const { + auto fail = [&](const Twine &message) -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto sourceType = cast(op.getSource().getType()); + auto resultType = cast(op.getResult().getType()); + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + FailureOr sourceArity = getVMIPhysicalArity(sourceType); + FailureOr resultArity = getVMIPhysicalArity(resultType); + if (!sourceLayout || !resultLayout || failed(sourceArity) || + failed(resultArity)) + return fail("requires assigned source/result layouts and computable " + "physical arity"); + if (!isa(sourceType.getElementType()) || + !isa(resultType.getElementType())) + return fail("requires integer source and result element types"); + + unsigned sourceBits = + pto::getPTOStorageElemBitWidth(sourceType.getElementType()); + unsigned resultBits = + pto::getPTOStorageElemBitWidth(resultType.getElementType()); + + if (sourceLayout.isGroupSlots() || resultLayout.isGroupSlots()) { + if 
(!sourceLayout.isGroupSlots() || !resultLayout.isGroupSlots() || + sourceLayout.getNumGroups() != resultLayout.getNumGroups() || + sourceLayout.getSlots() != resultLayout.getSlots() || + (sourceLayout.getSlots() != 1 && sourceLayout.getSlots() != 8) || + sourceBits != 32 || (resultBits != 16 && resultBits != 8) || + *sourceArity != *resultArity) + return fail("group-slot trunci requires matching " + "group_slots(num_groups=G, slots=1 or 8) source/result layouts, " + "32-bit integer source, 8/16-bit integer result, and " + "matching physical arity"); + return VMITruncISupport{VMITruncISupportKind::GroupSlots1I32ToNarrow}; + } + + FailureOr fact = + getPreferredCastLayoutFact(sourceType, resultType, reason); + if (failed(fact) || (fact->kind != VMICastLayoutKind::Narrow2x && + fact->kind != VMICastLayoutKind::Narrow4x)) + return fail("unsupported deinterleaved trunci factor, arity, result " + "element width, or result signedness; 8-bit integer narrowing " + "requires unsigned i8 result"); + + if (!sourceLayout.isDeinterleaved() || sourceLayout.getBlockElems() != 1 || + !((resultLayout.isContiguous() && resultLayout.getLaneStride() == 1) || + (resultLayout.isDeinterleaved() && + resultLayout.getBlockElems() == 1 && + resultLayout.getLaneStride() == 1))) + if (!(sourceLayout.isContiguous() && sourceLayout.getLaneStride() == 1 && + resultLayout.isContiguous() && + resultLayout.getLaneStride() == fact->factor)) + return fail("requires integer deinterleaved source and contiguous or " + "deinterleaved integer result with block_elems=1, or " + "contiguous source and lane_stride narrowing result"); + + if (sourceLayout.isContiguous() && sourceLayout.getLaneStride() == 1 && + resultLayout.isContiguous() && + resultLayout.getLaneStride() == fact->factor && + *sourceArity == *resultArity) { + if (resultBits == 8 && + !cast(resultType.getElementType()).isUnsigned()) + return fail("8-bit integer narrowing requires unsigned i8 result"); + if (fact->kind == 
VMICastLayoutKind::Narrow2x) + return VMITruncISupport{ + VMITruncISupportKind::ContiguousI32ToLaneStrideI16}; + if (fact->kind == VMICastLayoutKind::Narrow4x) + return VMITruncISupport{ + VMITruncISupportKind::ContiguousI32ToLaneStrideI8}; + } + + int64_t resultFactor = + resultLayout.isDeinterleaved() ? resultLayout.getFactor() : 1; + if (sourceLayout.getFactor() != resultFactor * fact->factor || + *sourceArity != *resultArity * fact->factor) + return fail("unsupported deinterleaved trunci source/result layout factor " + "or physical arity"); + + if (resultBits == 8 && + !cast(resultType.getElementType()).isUnsigned()) + return fail("8-bit integer narrowing requires unsigned i8 result"); + + if (fact->kind == VMICastLayoutKind::Narrow2x) + return VMITruncISupport{ + VMITruncISupportKind::Deinterleaved2I32ToContiguousI16}; + if (fact->kind == VMICastLayoutKind::Narrow4x) + return VMITruncISupport{ + VMITruncISupportKind::Deinterleaved4I32ToContiguousI8}; + + return fail("unsupported deinterleaved trunci factor, arity, result element " + "width, or result signedness; 8-bit integer narrowing " + "requires unsigned i8 result"); +} + +FailureOr +VMILayoutSupport::getBitcastSupport(VMIBitcastOp op, + std::string *reason) const { + auto fail = [&](const Twine &message) -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto sourceType = cast(op.getSource().getType()); + auto resultType = cast(op.getResult().getType()); + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + if (!sourceLayout || !resultLayout) + return fail("requires assigned source and result layouts"); + if (sourceLayout != resultLayout) + return fail("requires matching source and result layouts"); + + FailureOr sourceArity = getVMIPhysicalArity(sourceType); + FailureOr resultArity = getVMIPhysicalArity(resultType); + if (failed(sourceArity) || failed(resultArity)) + return fail("requires computable source 
and result physical arity"); + if (*sourceArity != *resultArity) + return fail("requires source and result to have the same physical arity"); + + FailureOr> sourceBits = + getPhysicalLogicalBitFootprint(sourceType); + FailureOr> resultBits = + getPhysicalLogicalBitFootprint(resultType); + if (failed(sourceBits) || failed(resultBits)) + return fail("requires computable physical logical bit footprints"); + if (sourceBits->size() != resultBits->size()) + return fail("requires source and result physical footprint counts to " + "match"); + for (auto [source, result] : llvm::zip_equal(*sourceBits, *resultBits)) { + if (source != result) + return fail("requires matching logical bit footprint in every physical " + "chunk"); + } + + return VMIBitcastSupport{VMIBitcastSupportKind::PerPartVbitcast}; +} + +template +static FailureOr +getHistogramSupportImpl(OpTy op, std::string *reason) { + auto fail = [&](const Twine &message) -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto accType = cast(op.getAcc().getType()); + auto sourceType = cast(op.getSource().getType()); + auto maskType = cast(op.getMask().getType()); + auto resultType = cast(op.getResult().getType()); + + VMILayoutAttr accLayout = accType.getLayoutAttr(); + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr maskLayout = maskType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + if (!accLayout || !sourceLayout || !maskLayout || !resultLayout) + return fail("requires assigned acc/source/mask/result layouts"); + if (!accLayout.isContiguous() || !sourceLayout.isContiguous() || + !maskLayout.isContiguous() || !resultLayout.isContiguous()) + return fail("requires contiguous acc, source, mask, and result layouts"); + if (maskType.getGranularity() != "b8") + return fail("requires b8 mask granularity"); + if (maskType.getElementCount() != sourceType.getElementCount()) + return fail("requires mask lane count to match source lane 
count"); + + auto accElem = dyn_cast(accType.getElementType()); + auto sourceElem = dyn_cast(sourceType.getElementType()); + if (!accElem || !accElem.isUnsigned() || accElem.getWidth() != 16 || + accType.getElementCount() != 256 || resultType != accType) + return fail("requires contiguous 256xui16 acc/result"); + if (!sourceElem || !sourceElem.isUnsigned() || sourceElem.getWidth() != 8) + return fail("requires unsigned 8-bit source elements"); + + FailureOr accArity = getVMIPhysicalArity(accType); + FailureOr resultArity = getVMIPhysicalArity(resultType); + FailureOr sourceArity = getVMIPhysicalArity(sourceType); + FailureOr maskArity = getVMIPhysicalArity(maskType); + if (failed(accArity) || failed(resultArity) || failed(sourceArity) || + failed(maskArity)) + return fail("requires computable physical arity"); + if (*accArity != 2 || *resultArity != 2) + return fail("requires acc/result to physicalize to two 128xui16 parts"); + if (*sourceArity != *maskArity) + return fail("requires source and mask physical arity to match"); + if (*sourceArity < 1) + return fail("requires at least one source physical chunk"); + + return VMIHistogramSupport{VMIHistogramSupportKind::Full256BinDhist}; +} + +FailureOr +VMILayoutSupport::getDhistSupport(VMIDhistOp op, std::string *reason) const { + return getHistogramSupportImpl(op, reason); +} + +FailureOr +VMILayoutSupport::getChistSupport(VMIChistOp op, std::string *reason) const { + if (reason) + *reason = "CHISTv2 cumulative high-range semantics are not classified"; + return failure(); +} diff --git a/lib/PTO/Transforms/VMILegalizeArithSelect.cpp b/lib/PTO/Transforms/VMILegalizeArithSelect.cpp new file mode 100644 index 0000000000..471215985f --- /dev/null +++ b/lib/PTO/Transforms/VMILegalizeArithSelect.cpp @@ -0,0 +1,88 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +//===- VMILegalizeArithSelect.cpp - Legalize VMI arith.select ------------===// +//===----------------------------------------------------------------------===// + +#include "PTO/IR/PTO.h" +#include "PTO/Transforms/Passes.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" +#include "llvm/ADT/STLExtras.h" + +namespace mlir { +namespace pto { +#define GEN_PASS_DEF_VMILEGALIZEARITHSELECT +#include "PTO/Transforms/Passes.h.inc" +} // namespace pto +} // namespace mlir + +using namespace mlir; +using namespace mlir::pto; + +namespace { + +static bool isVMIValueType(Type type) { + return isa(type); +} + +static bool hasScalarI1Condition(arith::SelectOp select) { + return select.getCondition().getType().isSignlessInteger(1); +} + +static void rewriteSelectToIf(arith::SelectOp select) { + OpBuilder builder(select); + auto ifOp = builder.create( + select.getLoc(), TypeRange{select.getResult().getType()}, + select.getCondition(), /*withElseRegion=*/true); + + { + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); + builder.create(select.getLoc(), select.getTrueValue()); + 
builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); + builder.create(select.getLoc(), select.getFalseValue()); + } + + select.getResult().replaceAllUsesWith(ifOp.getResult(0)); + select.erase(); +} + +struct VMILegalizeArithSelectPass + : public mlir::pto::impl::VMILegalizeArithSelectBase< + VMILegalizeArithSelectPass> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VMILegalizeArithSelectPass) + + void runOnOperation() override { + ModuleOp module = getOperation(); + SmallVector selects; + module.walk([&](arith::SelectOp select) { + if (isVMIValueType(select.getResult().getType()) && + hasScalarI1Condition(select)) + selects.push_back(select); + }); + + for (arith::SelectOp select : llvm::reverse(selects)) { + if (select->getBlock() != nullptr) + rewriteSelectToIf(select); + } + } +}; + +} // namespace + +std::unique_ptr mlir::pto::createVMILegalizeArithSelectPass() { + return std::make_unique(); +} diff --git a/lib/PTO/Transforms/VMIPreAssignmentCombine.cpp b/lib/PTO/Transforms/VMIPreAssignmentCombine.cpp new file mode 100644 index 0000000000..afc5ff04f9 --- /dev/null +++ b/lib/PTO/Transforms/VMIPreAssignmentCombine.cpp @@ -0,0 +1,81 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. 
+ +//===- VMIPreAssignmentCombine.cpp - Pre-assignment VMI combines ---------===// +//===----------------------------------------------------------------------===// + +#include "PTO/IR/PTO.h" +#include "PTO/Transforms/Passes.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Operation.h" +#include "mlir/Pass/Pass.h" +#include "llvm/ADT/SmallVector.h" + +namespace mlir { +namespace pto { +#define GEN_PASS_DEF_VMIPREASSIGNMENTCOMBINE +#include "PTO/Transforms/Passes.h.inc" +} // namespace pto +} // namespace mlir + +using namespace mlir; +using namespace mlir::pto; + +namespace { + +static LogicalResult fuseGroupSlotBroadcastLoads(ModuleOp module) { + SmallVector broadcasts; + module.walk([&](VMIGroupBroadcastOp broadcast) { + auto load = broadcast.getSource().getDefiningOp(); + if (!load || !load.getResult().hasOneUse()) + return; + if (load.getNumGroupsAttr().getInt() != + broadcast.getNumGroupsAttr().getInt()) + return; + + if (!isa(broadcast.getResult().getType())) + return; + broadcasts.push_back(broadcast); + }); + + OpBuilder builder(module.getContext()); + for (VMIGroupBroadcastOp broadcast : broadcasts) { + auto load = broadcast.getSource().getDefiningOp(); + if (!load) + continue; + + builder.setInsertionPoint(broadcast); + auto fused = builder.create( + broadcast.getLoc(), broadcast.getResult().getType(), load.getSource(), + load.getOffset(), load.getSourceGroupStride(), + broadcast.getNumGroupsAttr()); + broadcast.getResult().replaceAllUsesWith(fused.getResult()); + broadcast.erase(); + if (load->use_empty()) + load.erase(); + } + return success(); +} + +struct VMIPreAssignmentCombinePass + : pto::impl::VMIPreAssignmentCombineBase { + void runOnOperation() override { + if (failed(fuseGroupSlotBroadcastLoads(getOperation()))) + signalPassFailure(); + } +}; + +} // namespace + +std::unique_ptr mlir::pto::createVMIPreAssignmentCombinePass() { + return std::make_unique(); +} diff --git 
a/lib/PTO/Transforms/VMIToVPTO.cpp b/lib/PTO/Transforms/VMIToVPTO.cpp new file mode 100644 index 0000000000..59768b33d7 --- /dev/null +++ b/lib/PTO/Transforms/VMIToVPTO.cpp @@ -0,0 +1,10545 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +//===- VMIToVPTO.cpp - Convert VMI to physical VPTO IR -------------------===// +//===----------------------------------------------------------------------===// + +// https://discourse.llvm.org/t/matchandrewrite-hiding-virtual-functions/84933/8 +#pragma GCC diagnostic ignored "-Woverloaded-virtual" + +#include "PTO/IR/PTO.h" +#include "PTO/IR/PTOTypeUtils.h" +#include "PTO/IR/VMIUtils.h" +#include "PTO/Transforms/Passes.h" +#include "PTO/Transforms/VMILayoutSupport.h" +#include "PTO/Transforms/VMITargetCapabilities.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" +#include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Transforms/Patterns.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Operation.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/OneToNTypeConversion.h" +#include 
"llvm/ADT/APFloat.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/Support/raw_ostream.h" +#include +#include + +namespace mlir { +namespace pto { +#define GEN_PASS_DEF_VMITOVPTO +#include "PTO/Transforms/Passes.h.inc" +} // namespace pto +} // namespace mlir + +using namespace mlir; +using namespace mlir::pto; + +namespace { + +std::optional getX2MemoryDistToken(Type elementType, + StringRef prefix); +std::optional getDenseLaneStrideLoadDistToken(VMIVRegType type); +std::optional getDenseLaneStrideStoreDistToken(VMIVRegType type); + +bool isVMIType(Type type) { return isa(type); } + +bool containsVMIType(Type type) { + if (isVMIType(type)) + return true; + + if (auto functionType = dyn_cast(type)) + return llvm::any_of(functionType.getInputs(), + [](Type input) { return containsVMIType(input); }) || + llvm::any_of(functionType.getResults(), + [](Type result) { return containsVMIType(result); }); + + if (auto shapedType = dyn_cast(type)) + return containsVMIType(shapedType.getElementType()); + + return false; +} + +bool hasVMIType(TypeRange types) { + return llvm::any_of(types, [](Type type) { return containsVMIType(type); }); +} + +bool hasVMIType(FunctionType type) { + return hasVMIType(type.getInputs()) || hasVMIType(type.getResults()); +} + +bool hasVMIType(Attribute attr) { + if (!attr) + return false; + + if (auto typeAttr = dyn_cast(attr)) + if (containsVMIType(typeAttr.getValue())) + return true; + + if (auto typedAttr = dyn_cast(attr)) + if (containsVMIType(typedAttr.getType())) + return true; + + if (auto arrayAttr = dyn_cast(attr)) + return llvm::any_of(arrayAttr, + [](Attribute element) { return hasVMIType(element); }); + + if (auto dictAttr = dyn_cast(attr)) + return llvm::any_of(dictAttr, [](NamedAttribute namedAttr) { + return hasVMIType(namedAttr.getValue()); + }); + + return false; +} + +bool hasVMIType(Operation *op) { + if (auto func = dyn_cast(op)) + if (hasVMIType(func.getFunctionType())) + return true; + if (hasVMIType(op->getOperandTypes()) 
|| hasVMIType(op->getResultTypes())) + return true; + for (Region ®ion : op->getRegions()) + for (Block &block : region) + if (hasVMIType(block.getArgumentTypes())) + return true; + for (NamedAttribute attr : op->getAttrs()) + if (hasVMIType(attr.getValue())) + return true; + return false; +} + +bool isVMIOp(Operation *op) { + return op->getName().getStringRef().starts_with("pto.vmi."); +} + +StringRef getTruncFRoundModeForResult(Type resultElementType) { + return pto::isPTOHiFloat8Type(resultElementType) ? "A" : "R"; +} + +StringRef getTruncFRoundMode(VMITruncFOp op, Type resultElementType) { + if (auto roundingAttr = op->getAttrOfType("rounding")) + return roundingAttr.getValue(); + return getTruncFRoundModeForResult(resultElementType); +} + +bool isLayoutAssignedVMIType(Type type) { + if (auto vregType = dyn_cast(type)) + return static_cast(vregType.getLayoutAttr()); + if (auto maskType = dyn_cast(type)) + return maskType.getLayoutAttr() && + VMIMaskType::isConcreteGranularity(maskType.getGranularity()); + return true; +} + +LogicalResult verifyLayoutAssignedVMITypeTree(Operation *op, Type type) { + if (!isLayoutAssignedVMIType(type)) + return op->emitError() << kVMIDiagPassInvariantPrefix + << "vmi-to-vpto requires layout-assigned VMI types"; + + if (auto functionType = dyn_cast(type)) { + for (Type input : functionType.getInputs()) + if (failed(verifyLayoutAssignedVMITypeTree(op, input))) + return failure(); + for (Type result : functionType.getResults()) + if (failed(verifyLayoutAssignedVMITypeTree(op, result))) + return failure(); + } + + if (auto shapedType = dyn_cast(type)) + return verifyLayoutAssignedVMITypeTree(op, shapedType.getElementType()); + + return success(); +} + +LogicalResult verifyVMIToVPTOInputAttribute(Operation *op, Attribute attr) { + if (!attr) + return success(); + + if (auto typeAttr = dyn_cast(attr)) + if (failed(verifyLayoutAssignedVMITypeTree(op, typeAttr.getValue()))) + return failure(); + + if (auto typedAttr = dyn_cast(attr)) + 
if (failed(verifyLayoutAssignedVMITypeTree(op, typedAttr.getType()))) + return failure(); + + if (auto arrayAttr = dyn_cast(attr)) { + for (Attribute element : arrayAttr) + if (failed(verifyVMIToVPTOInputAttribute(op, element))) + return failure(); + } + + if (auto dictAttr = dyn_cast(attr)) { + for (NamedAttribute namedAttr : dictAttr) + if (failed(verifyVMIToVPTOInputAttribute(op, namedAttr.getValue()))) + return failure(); + } + + return success(); +} + +LogicalResult verifyVMIToVPTOInputTypes(Operation *op) { + for (Type type : op->getOperandTypes()) + if (failed(verifyLayoutAssignedVMITypeTree(op, type))) + return failure(); + for (Type type : op->getResultTypes()) + if (failed(verifyLayoutAssignedVMITypeTree(op, type))) + return failure(); + if (auto func = dyn_cast(op)) { + FunctionType functionType = func.getFunctionType(); + for (Type type : functionType.getInputs()) + if (failed(verifyLayoutAssignedVMITypeTree(op, type))) + return failure(); + for (Type type : functionType.getResults()) + if (failed(verifyLayoutAssignedVMITypeTree(op, type))) + return failure(); + } + for (Region ®ion : op->getRegions()) + for (Block &block : region) + for (Type type : block.getArgumentTypes()) + if (failed(verifyLayoutAssignedVMITypeTree(op, type))) + return failure(); + for (NamedAttribute attr : op->getAttrs()) + if (failed(verifyVMIToVPTOInputAttribute(op, attr.getValue()))) + return failure(); + return success(); +} + +LogicalResult verifyVMIToVPTOInputIR(ModuleOp module) { + WalkResult result = module.walk([&](Operation *op) { + if (failed(verifyVMIToVPTOInputTypes(op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + }); + return failure(result.wasInterrupted()); +} + +static std::optional materializeVPTOToVMI(OpBuilder &builder, + Type resultType, + ValueRange inputs, + Location loc) { + if (!isVMIType(resultType)) + return std::nullopt; + return builder.create(loc, resultType, inputs).getResult(); +} + +static std::optional> 
+materializeVMIToVPTO(OpBuilder &builder, TypeRange resultTypes, Value input, + Location loc) { + if (!isVMIType(input.getType())) + return std::nullopt; + auto unpackOp = builder.create(loc, resultTypes, input); + return SmallVector(unpackOp->getResults()); +} + +static FailureOr getVMIVRegPhysicalElementType(VMIVRegType type) { + Type elementType = type.getElementType(); + VMILayoutAttr layout = type.getLayoutAttr(); + if (!layout || !layout.hasGroupSlotLaneStride()) + return elementType; + + auto integerType = dyn_cast(elementType); + if (!integerType || !integerType.isUnsigned()) + return failure(); + unsigned elementBits = pto::getPTOStorageElemBitWidth(elementType); + int64_t laneStride = layout.getLaneStride(); + if (elementBits == 0 || laneStride <= 1) + return failure(); + int64_t physicalBits = static_cast(elementBits) * laneStride; + if (physicalBits != 16 && physicalBits != 32) + return failure(); + return IntegerType::get(type.getContext(), physicalBits); +} + +class VMIToVPTOTypeConverter final : public OneToNTypeConverter { +public: + VMIToVPTOTypeConverter() { + addConversion([](Type type) { return type; }); + addConversion( + [](VMIVRegType type, SmallVectorImpl &results) -> LogicalResult { + FailureOr arity = getVMIPhysicalArity(type); + FailureOr physicalElementType = + getVMIVRegPhysicalElementType(type); + if (failed(arity) || failed(physicalElementType)) + return failure(); + FailureOr lanesPerPart = + getDataLanesPerPart(*physicalElementType); + if (failed(lanesPerPart)) + return failure(); + for (int64_t i = 0; i < *arity; ++i) + results.push_back(VRegType::get(type.getContext(), *lanesPerPart, + *physicalElementType)); + return success(); + }); + addConversion( + [](VMIMaskType type, SmallVectorImpl &results) -> LogicalResult { + FailureOr arity = getVMIPhysicalArity(type); + if (failed(arity)) + return failure(); + for (int64_t i = 0; i < *arity; ++i) + results.push_back( + MaskType::get(type.getContext(), type.getGranularity())); + return 
success(); + }); + TypeConverter::addSourceMaterialization(materializeVPTOToVMI); + TypeConverter::addArgumentMaterialization(materializeVPTOToVMI); + OneToNTypeConverter::addTargetMaterialization(materializeVMIToVPTO); + } +}; + +FailureOr createAllTrueMaskForVReg(Location loc, VRegType vregType, + PatternRewriter &rewriter) { + MLIRContext *ctx = rewriter.getContext(); + unsigned elementBits = + pto::getPTOStorageElemBitWidth(vregType.getElementType()); + if (elementBits == 8) + return rewriter + .create(loc, MaskType::get(ctx, "b8"), + rewriter.getStringAttr("PAT_ALL")) + .getResult(); + if (elementBits == 16) + return rewriter + .create(loc, MaskType::get(ctx, "b16"), + rewriter.getStringAttr("PAT_ALL")) + .getResult(); + if (elementBits == 32) + return rewriter + .create(loc, MaskType::get(ctx, "b32"), + rewriter.getStringAttr("PAT_ALL")) + .getResult(); + return failure(); +} + +FailureOr getMaskTypeForVReg(VRegType vregType, MLIRContext *ctx) { + unsigned elementBits = + pto::getPTOStorageElemBitWidth(vregType.getElementType()); + if (elementBits == 8) + return MaskType::get(ctx, "b8"); + if (elementBits == 16) + return MaskType::get(ctx, "b16"); + if (elementBits == 32) + return MaskType::get(ctx, "b32"); + return failure(); +} + +FailureOr createAllTrueMask(Location loc, MaskType maskType, + PatternRewriter &rewriter) { + StringAttr pattern = rewriter.getStringAttr("PAT_ALL"); + MLIRContext *ctx = rewriter.getContext(); + if (maskType.isB8()) + return rewriter.create(loc, MaskType::get(ctx, "b8"), pattern) + .getResult(); + if (maskType.isB16()) + return rewriter.create(loc, MaskType::get(ctx, "b16"), pattern) + .getResult(); + if (maskType.isB32()) + return rewriter.create(loc, MaskType::get(ctx, "b32"), pattern) + .getResult(); + return failure(); +} + +FailureOr createPatternMask(Location loc, MaskType maskType, + StringRef pattern, + PatternRewriter &rewriter) { + StringAttr patternAttr = rewriter.getStringAttr(pattern); + MLIRContext *ctx = 
rewriter.getContext(); + if (maskType.isB8()) + return rewriter.create(loc, MaskType::get(ctx, "b8"), patternAttr) + .getResult(); + if (maskType.isB16()) + return rewriter + .create(loc, MaskType::get(ctx, "b16"), patternAttr) + .getResult(); + if (maskType.isB32()) + return rewriter + .create(loc, MaskType::get(ctx, "b32"), patternAttr) + .getResult(); + return failure(); +} + +FailureOr createPrefixMask(Location loc, MaskType maskType, + StringRef pattern, + PatternRewriter &rewriter) { + StringAttr patternAttr = rewriter.getStringAttr(pattern); + MLIRContext *ctx = rewriter.getContext(); + if (maskType.isB8()) + return rewriter.create(loc, MaskType::get(ctx, "b8"), patternAttr) + .getResult(); + if (maskType.isB16()) + return rewriter + .create(loc, MaskType::get(ctx, "b16"), patternAttr) + .getResult(); + if (maskType.isB32()) + return rewriter + .create(loc, MaskType::get(ctx, "b32"), patternAttr) + .getResult(); + return failure(); +} + +FailureOr> +createRuntimePrefixMask(Location loc, MaskType maskType, Value activeLanes, + PatternRewriter &rewriter) { + MLIRContext *ctx = rewriter.getContext(); + Type scalarType = activeLanes.getType(); + if (maskType.isB8()) { + auto op = rewriter.create(loc, MaskType::get(ctx, "b8"), + scalarType, activeLanes); + return std::make_pair(Value(op.getMask()), Value(op.getScalarOut())); + } + if (maskType.isB16()) { + auto op = rewriter.create(loc, MaskType::get(ctx, "b16"), + scalarType, activeLanes); + return std::make_pair(Value(op.getMask()), Value(op.getScalarOut())); + } + if (maskType.isB32()) { + auto op = rewriter.create(loc, MaskType::get(ctx, "b32"), + scalarType, activeLanes); + return std::make_pair(Value(op.getMask()), Value(op.getScalarOut())); + } + return failure(); +} + +LogicalResult +checkSupportedMaskableVReg(const VMITargetCapabilityRegistry &capabilities, + VMIVRegType type, std::string *reason = nullptr) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = 
message.str(); + return failure(); + }; + + VMICapabilityResult elementCapability = capabilities.supportsElementType( + type.getElementType(), VMIElementPurpose::PredicateMask); + if (!elementCapability.isSupported()) + return fail(elementCapability.reason); + + FailureOr lanesPerPart = getDataLanesPerPart(type.getElementType()); + FailureOr arity = getVMIPhysicalArity(type); + if (failed(lanesPerPart) || failed(arity) || *arity < 1) + return fail("requires computable non-empty physical vreg parts"); + + return success(); +} + +LogicalResult +checkSupportedTargetElementVReg(const VMITargetCapabilityRegistry &capabilities, + VMIVRegType type, VMIElementPurpose purpose, + StringRef elementContract, + std::string *reason = nullptr) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + if (failed(checkSupportedMaskableVReg(capabilities, type, reason))) + return failure(); + + VMICapabilityResult elementCapability = + capabilities.supportsElementType(type.getElementType(), purpose); + if (!elementCapability.isSupported()) + return fail(elementCapability.reason); + + return success(); +} + +Value createI32Constant(Location loc, int64_t value, + PatternRewriter &rewriter) { + return rewriter.create(loc, value, 32); +} + +Value createI16Constant(Location loc, int64_t value, + PatternRewriter &rewriter) { + return rewriter.create(loc, value, 16); +} + +FailureOr createPrefixMaskForActiveLanes(Location loc, MaskType maskType, + int64_t activeLanes, + PatternRewriter &rewriter) { + if (activeLanes <= 0) + return createPrefixMask(loc, maskType, "PAT_ALLF", rewriter); + + switch (activeLanes) { + case 1: + case 2: + case 3: + case 4: + case 8: + case 16: + case 32: + case 64: + case 128: + return createPrefixMask( + loc, maskType, (Twine("PAT_VL") + Twine(activeLanes)).str(), rewriter); + default: { + FailureOr> dynamicMask = createRuntimePrefixMask( + loc, maskType, createI32Constant(loc, activeLanes, 
rewriter), rewriter); + if (failed(dynamicMask)) + return failure(); + return dynamicMask->first; + } + } +} + +Value clampDynamicActiveLanes(Location loc, Value activeLanes, + int64_t maxActiveLanes, + PatternRewriter &rewriter) { + Value activeI32 = rewriter.create( + loc, rewriter.getI32Type(), activeLanes); + Value zeroI32 = createI32Constant(loc, 0, rewriter); + Value nonNegative = rewriter.create(loc, activeI32, zeroI32); + Value maxI32 = createI32Constant(loc, maxActiveLanes, rewriter); + return rewriter.create(loc, nonNegative, maxI32); +} + +Value createPartitionActiveLanes(Location loc, Value activeLanesI32, + int64_t factor, int64_t part, + PatternRewriter &rewriter) { + if (factor == 1) + return activeLanesI32; + int64_t bias = factor - 1 - part; + Value biased = activeLanesI32; + if (bias != 0) + biased = rewriter.create( + loc, biased, createI32Constant(loc, bias, rewriter)); + return rewriter.create( + loc, biased, createI32Constant(loc, factor, rewriter)); +} + +std::optional getPowerOfTwoLog2(int64_t value) { + if (value <= 0 || (value & (value - 1)) != 0) + return std::nullopt; + int64_t log2 = 0; + while (value > 1) { + value >>= 1; + ++log2; + } + return log2; +} + +std::optional getPrefixPattern(int64_t activeLanes, + int64_t lanesPerPart) { + if (activeLanes <= 0) + return std::string("PAT_ALLF"); + if (activeLanes >= lanesPerPart) + return std::string("PAT_ALL"); + switch (activeLanes) { + case 1: + case 2: + case 3: + case 4: + case 8: + case 16: + case 32: + case 64: + case 128: + return std::string("PAT_VL") + std::to_string(activeLanes); + default: + return std::nullopt; + } +} + +FailureOr getSingleValue(Operation *op, ValueRange values, + StringRef description, + PatternRewriter &rewriter) { + if (values.size() != 1) { + (void)rewriter.notifyMatchFailure(op, description); + return failure(); + } + return values.front(); +} + +static int64_t ceilDivNonNegative(int64_t lhs, int64_t rhs) { + return (lhs + rhs - 1) / rhs; +} + +FailureOr 
getDataLayoutFactor(VMIVRegType type) { + VMILayoutAttr layout = type.getLayoutAttr(); + if (!layout) + return failure(); + return layout.isDeinterleaved() ? layout.getFactor() : 1; +} + +FailureOr getDataChunksInPart(VMIVRegType type, int64_t part) { + FailureOr factor = getDataLayoutFactor(type); + FailureOr lanesPerPart = getDataLanesPerPart(type.getElementType()); + if (failed(factor) || failed(lanesPerPart) || part < 0 || part >= *factor) + return failure(); + + int64_t logicalLanesInPart = + (type.getElementCount() + *factor - 1 - part) / *factor; + return ceilDivNonNegative(logicalLanesInPart, *lanesPerPart); +} + +FailureOr getDataFlatPartIndex(VMIVRegType type, int64_t part, + int64_t chunk) { + FailureOr factor = getDataLayoutFactor(type); + if (failed(factor) || part < 0 || part >= *factor || chunk < 0) + return failure(); + + int64_t flatIndex = 0; + for (int64_t currentPart = 0; currentPart < part; ++currentPart) { + FailureOr chunks = getDataChunksInPart(type, currentPart); + if (failed(chunks)) + return failure(); + flatIndex += *chunks; + } + + FailureOr chunks = getDataChunksInPart(type, part); + if (failed(chunks) || chunk >= *chunks) + return failure(); + return flatIndex + chunk; +} + +FailureOr checkFullDataPhysicalChunks(VMIVRegType type, + std::string *reason) { + auto fail = [&](const Twine &message) -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + FailureOr lanesPerPart = getDataLanesPerPart(type.getElementType()); + if (failed(lanesPerPart)) + return fail("requires known physical lanes per part"); + + FailureOr factor = getDataLayoutFactor(type); + if (failed(factor)) + return fail("requires assigned layout"); + + for (int64_t part = 0; part < *factor; ++part) { + FailureOr chunks = getDataChunksInPart(type, part); + if (failed(chunks)) + return fail("requires known physical chunks"); + for (int64_t chunk = 0; chunk < *chunks; ++chunk) { + for (int64_t lane = 0; lane < *lanesPerPart; ++lane) { + 
FailureOr padding = isPaddingLane(type, part, chunk, lane); + if (failed(padding)) + return fail("failed to map physical padding lane"); + if (*padding) + return fail("found padding lane in physical chunk"); + } + } + } + + return *lanesPerPart; +} + +FailureOr getVMITypeLayoutFactor(Type type) { + Attribute layout; + if (auto vregType = dyn_cast(type)) + layout = vregType.getLayout(); + else if (auto maskType = dyn_cast(type)) + layout = maskType.getLayout(); + else + return failure(); + + auto layoutAttr = dyn_cast_or_null(layout); + if (!layoutAttr) + return failure(); + return layoutAttr.isDeinterleaved() ? layoutAttr.getFactor() : 1; +} + +FailureOr getVMITypeElementCount(Type type) { + if (auto vregType = dyn_cast(type)) + return vregType.getElementCount(); + if (auto maskType = dyn_cast(type)) + return maskType.getElementCount(); + return failure(); +} + +FailureOr getVMITypeLanesPerPart(Type type) { + if (auto vregType = dyn_cast(type)) { + FailureOr physicalElementType = + getVMIVRegPhysicalElementType(vregType); + if (failed(physicalElementType)) + return failure(); + return getDataLanesPerPart(*physicalElementType); + } + if (auto maskType = dyn_cast(type)) + return getMaskLanesPerPart(maskType.getGranularity()); + return failure(); +} + +FailureOr getVMITypeChunksInPart(Type type, int64_t part) { + FailureOr elementCount = getVMITypeElementCount(type); + FailureOr factor = getVMITypeLayoutFactor(type); + FailureOr lanesPerPart = getVMITypeLanesPerPart(type); + if (failed(elementCount) || failed(factor) || failed(lanesPerPart) || + part < 0 || part >= *factor) + return failure(); + + VMILayoutAttr layout; + if (auto vregType = dyn_cast(type)) + layout = vregType.getLayoutAttr(); + else if (auto maskType = dyn_cast(type)) + layout = maskType.getLayoutAttr(); + if (!layout) + return failure(); + + int64_t logicalLanesInPart = (*elementCount + *factor - 1 - part) / *factor; + int64_t laneStride = layout.isDense() ? 
layout.getLaneStride() : 1; + int64_t physicalLanes = + logicalLanesInPart == 0 ? 0 : (logicalLanesInPart - 1) * laneStride + 1; + return ceilDivNonNegative(physicalLanes, *lanesPerPart); +} + +LogicalResult checkFullVMIPhysicalChunks(Type type, std::string *reason) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + FailureOr factor = getVMITypeLayoutFactor(type); + FailureOr lanesPerPart = getVMITypeLanesPerPart(type); + if (failed(factor) || failed(lanesPerPart)) + return fail("requires assigned layout with known physical lanes per part"); + + for (int64_t part = 0; part < *factor; ++part) { + FailureOr chunks = getVMITypeChunksInPart(type, part); + if (failed(chunks)) + return fail("requires known physical chunks"); + for (int64_t chunk = 0; chunk < *chunks; ++chunk) { + for (int64_t lane = 0; lane < *lanesPerPart; ++lane) { + FailureOr padding = isPaddingLane(type, part, chunk, lane); + if (failed(padding)) + return fail("failed to map physical padding lane"); + if (*padding) + return fail("found padding lane in physical chunk"); + } + } + } + + return success(); +} + +FailureOr getContiguousMaterializationPartCount(Type type, + std::string *reason); + +LogicalResult checkSupportedLayoutMaterialization( + const VMITargetCapabilityRegistry &capabilities, Type sourceType, + Type resultType, VMILayoutAttr sourceLayout, VMILayoutAttr resultLayout, + std::string *reason) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + VMICapabilityResult layoutCapability = + capabilities.supportsLayoutConversion(sourceLayout, resultLayout, Type{}); + if (!layoutCapability.isSupported()) + return fail(layoutCapability.reason); + + FailureOr sourceArity = getVMIPhysicalArity(sourceType); + FailureOr resultArity = getVMIPhysicalArity(resultType); + if (failed(sourceArity) || failed(resultArity)) + return fail("requires 
computable source/result physical arity"); + if (*sourceArity != *resultArity) + return fail("requires source and result to have the same physical arity"); + + if (sourceLayout == resultLayout) + return success(); + + std::string sourceReason; + std::string resultReason; + LogicalResult sourceFull = + checkFullVMIPhysicalChunks(sourceType, &sourceReason); + LogicalResult resultFull = + checkFullVMIPhysicalChunks(resultType, &resultReason); + if (succeeded(sourceFull) && succeeded(resultFull)) + return success(); + + std::string sourceMaterializationReason; + FailureOr sourceMaterializedParts = + getContiguousMaterializationPartCount(sourceType, + &sourceMaterializationReason); + std::string resultMaterializationReason; + FailureOr resultMaterializedParts = + getContiguousMaterializationPartCount(resultType, + &resultMaterializationReason); + if (succeeded(sourceMaterializedParts) && + succeeded(resultMaterializedParts) && + *sourceMaterializedParts == *sourceArity && + *resultMaterializedParts == *resultArity) + return success(); + + if (failed(sourceFull)) + return fail(Twine("source ") + sourceReason + "; source materialization " + + sourceMaterializationReason); + return fail(Twine("result ") + resultReason + "; result materialization " + + resultMaterializationReason); +} + +FailureOr getContiguousMaterializationPartCount(Type type, + std::string *reason) { + auto fail = [&](const Twine &message) -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + FailureOr arity = getVMIPhysicalArity(type); + FailureOr factor = getVMITypeLayoutFactor(type); + if (failed(arity) || failed(factor)) + return fail("requires computable physical arity and assigned layout"); + + Attribute layoutAttr; + if (auto vregType = dyn_cast(type)) + layoutAttr = vregType.getLayout(); + else if (auto maskType = dyn_cast(type)) + layoutAttr = maskType.getLayout(); + else + return fail("requires VMI data or mask type"); + + auto layout = 
dyn_cast_or_null(layoutAttr); + if (!layout) + return fail("requires assigned layout"); + if (layout.isContiguous() && layout.getLaneStride() == 1) + return *arity; + if (!layout.isDeinterleaved() || + (layout.getFactor() != 2 && layout.getFactor() != 4)) + return fail("requires contiguous or deinterleaved=2/4 layout"); + + FailureOr chunksPerGroup = getVMITypeChunksInPart(type, 0); + if (failed(chunksPerGroup)) + return fail("requires known physical chunks per part"); + if (*chunksPerGroup == 0) + return fail("requires at least one physical chunk per part"); + + for (int64_t part = 1; part < *factor; ++part) { + FailureOr chunks = getVMITypeChunksInPart(type, part); + if (failed(chunks)) + return fail("requires known physical chunks per part"); + if (*chunks != *chunksPerGroup) + return fail("requires every deinterleaved part to have the same " + "physical chunk count"); + } + + return *arity; +} + +LogicalResult checkCanMaterializeToContiguous(Type type, std::string *reason) { + return succeeded(getContiguousMaterializationPartCount(type, reason)) + ? 
success() + : failure(); +} + +std::optional getConstantIndexValue(Value value) { + if (auto constant = value.getDefiningOp()) + return constant.value(); + if (auto constant = value.getDefiningOp()) { + if (auto integerAttr = dyn_cast(constant.getValue())) + return integerAttr.getInt(); + } + return std::nullopt; +} + +bool isKnownIndexMultipleOf(Value value, int64_t multiple, int depth = 0) { + if (multiple <= 1) + return true; + if (depth > 6) + return false; + if (std::optional constant = getConstantIndexValue(value)) + return *constant % multiple == 0; + + if (auto add = value.getDefiningOp()) + return isKnownIndexMultipleOf(add.getLhs(), multiple, depth + 1) && + isKnownIndexMultipleOf(add.getRhs(), multiple, depth + 1); + if (auto sub = value.getDefiningOp()) + return isKnownIndexMultipleOf(sub.getLhs(), multiple, depth + 1) && + isKnownIndexMultipleOf(sub.getRhs(), multiple, depth + 1); + if (auto mul = value.getDefiningOp()) + return isKnownIndexMultipleOf(mul.getLhs(), multiple, depth + 1) || + isKnownIndexMultipleOf(mul.getRhs(), multiple, depth + 1); + + return false; +} + +FailureOr getStaticMemRefElementCount(Type type) { + auto memrefType = dyn_cast(type); + if (!memrefType || !memrefType.hasStaticShape()) + return failure(); + + int64_t elements = 1; + for (int64_t dim : memrefType.getShape()) + elements *= dim; + return elements; +} + +static Type getMemoryElementType(Type type) { + if (auto ptrType = dyn_cast(type)) + return ptrType.getElementType(); + if (auto memrefType = dyn_cast(type)) + return memrefType.getElementType(); + return {}; +} + +static bool isPackedByteGroupStore(Type destinationType, VRegType valueType) { + Type destinationElementType = getMemoryElementType(destinationType); + auto destinationIntegerType = + dyn_cast_or_null(destinationElementType); + auto valueIntegerType = dyn_cast(valueType.getElementType()); + return destinationIntegerType && valueIntegerType && + pto::getPTOStorageElemBitWidth(destinationIntegerType) == 8 && 
+ pto::getPTOStorageElemBitWidth(valueIntegerType) == 32; +} + +enum class VMIMemoryValidMaskKind { + AllTrue, + ExplicitMask, +}; + +enum class VMIMemoryWriteMaskKind { + AllTrue, + ExplicitMask, +}; + +enum class VMIMemoryPermutationKind { + Identity, +}; + +enum class VMIMemoryFallbackDecisionKind { + NotRequired, + RequiredUnavailable, +}; + +struct VMIMemoryLogicalShape { + int64_t elementCount = 0; +}; + +struct VMIMemoryLaneAddressMap { + VMIMemoryPermutationKind permutation = VMIMemoryPermutationKind::Identity; + int64_t baseElementOffset = 0; + int64_t elementStride = 1; + int64_t physicalLaneFootprint = 0; + + int64_t getExclusiveEndElement() const { + return baseElementOffset + physicalLaneFootprint * elementStride; + } +}; + +struct VMIMemoryFallbackDecision { + VMIMemoryFallbackDecisionKind kind = + VMIMemoryFallbackDecisionKind::NotRequired; + std::string reason = "not required"; + + static VMIMemoryFallbackDecision notRequired() { return {}; } + + static VMIMemoryFallbackDecision requiredUnavailable(const Twine &reason) { + VMIMemoryFallbackDecision decision; + decision.kind = VMIMemoryFallbackDecisionKind::RequiredUnavailable; + decision.reason = reason.str(); + return decision; + } +}; + +struct VMIMemorySafeReadProof { + bool proven = false; + std::string reason; + std::optional constantOffset; + std::optional staticElementCount; + std::optional laneAddressMap; + int64_t physicalFootprint = 0; +}; + +struct VMIMemoryAccessPlan { + Type baseType; + VMIVRegType valueType; + std::optional constantOffset; + VMIMemoryLogicalShape logicalShape; + VMIMemoryValidMaskKind validMask = VMIMemoryValidMaskKind::AllTrue; + VMIMemoryPermutationKind permutation = VMIMemoryPermutationKind::Identity; + std::optional laneAddressMap; + Attribute paddingValue; + VMIMemoryWriteMaskKind writeMask = VMIMemoryWriteMaskKind::AllTrue; + VMIMemorySafeReadProof safeReadProof; + VMICapabilityResult targetCapability; + VMICapabilityResult trueMaskedLoadCapability; + 
VMICapabilityResult scratchFallbackCapability; + VMICapabilityResult guardedFallbackCapability; + VMIMemoryFallbackDecision fallbackDecision; +}; + +FailureOr +buildContiguousIdentityLaneAddressMap(int64_t constantOffset, + VMIVRegType resultType, + std::string *reason = nullptr) { + auto fail = [&](const Twine &message) -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + FailureOr lanesPerPart = + getDataLanesPerPart(resultType.getElementType()); + FailureOr arity = getVMIPhysicalArity(resultType); + if (failed(lanesPerPart) || failed(arity)) + return fail("requires computable physical read footprint"); + + VMIMemoryLaneAddressMap map; + map.baseElementOffset = constantOffset; + map.physicalLaneFootprint = *arity * *lanesPerPart; + return map; +} + +VMICapabilityResult requireIdentityMemRefLayout(Type memoryType, StringRef role, + Value memoryValue = {}) { + auto memrefType = dyn_cast(memoryType); + if (!memrefType || memrefType.getLayout().isIdentity()) + return VMICapabilityResult::supported(); + std::string reason = + (Twine(role) + + " memref layout is non-identity; current VMI memory access plan " + "supports only contiguous identity lane-to-address maps") + .str(); + if (memoryValue && memoryValue.getDefiningOp()) + reason += "; memref.subview requires normalized base/offset/stride " + "lane-to-address planning"; + return VMICapabilityResult::missingCapability(reason); +} + +VMIMemorySafeReadProof +computeSafeFullReadProof(Type sourceType, std::optional constantOffset, + VMIVRegType resultType) { + VMIMemorySafeReadProof proof; + proof.constantOffset = constantOffset; + + auto fail = [&](const Twine &message) { + proof.proven = false; + proof.reason = message.str(); + return proof; + }; + + if (!constantOffset) + return fail("requires constant index offset"); + + FailureOr staticElements = getStaticMemRefElementCount(sourceType); + if (failed(staticElements)) + return fail("requires statically shaped memref source"); + 
int64_t elements = *staticElements; + proof.staticElementCount = elements; + + if (*constantOffset < 0) + return fail("requires non-negative offset"); + + std::string addressMapReason; + FailureOr addressMap = + buildContiguousIdentityLaneAddressMap(*constantOffset, resultType, + &addressMapReason); + if (failed(addressMap)) + return fail(addressMapReason); + proof.laneAddressMap = *addressMap; + + proof.physicalFootprint = addressMap->physicalLaneFootprint; + if (addressMap->getExclusiveEndElement() > elements) + return fail(Twine("full physical read footprint [") + + Twine(addressMap->baseElementOffset) + ", " + + Twine(addressMap->getExclusiveEndElement()) + + ") exceeds static memref element count " + Twine(elements)); + + proof.proven = true; + return proof; +} + +VMIMemoryAccessPlan +buildReadAccessPlan(const VMITargetCapabilityRegistry &capabilities, + Value source, Type sourceType, VMIVRegType resultType, + std::optional constantOffset, + VMIMemoryValidMaskKind validMask) { + VMIMemoryAccessPlan plan; + plan.baseType = sourceType; + plan.valueType = resultType; + plan.constantOffset = constantOffset; + plan.logicalShape.elementCount = resultType.getElementCount(); + plan.validMask = validMask; + plan.permutation = VMIMemoryPermutationKind::Identity; + plan.writeMask = VMIMemoryWriteMaskKind::AllTrue; + plan.safeReadProof = + computeSafeFullReadProof(sourceType, constantOffset, resultType); + plan.laneAddressMap = plan.safeReadProof.laneAddressMap; + plan.targetCapability = + capabilities.supportsDirectMemory(sourceType, "source"); + if (plan.targetCapability.isSupported()) + plan.targetCapability = + requireIdentityMemRefLayout(sourceType, "source", source); + if (validMask == VMIMemoryValidMaskKind::ExplicitMask) + plan.trueMaskedLoadCapability = + capabilities.supportsTrueMaskedLoad(sourceType, resultType, Type{}); + plan.scratchFallbackCapability = capabilities.supportsFallbackResource( + VMIFallbackResourceKind::ScratchMemory); + 
plan.guardedFallbackCapability = capabilities.supportsFallbackResource( + VMIFallbackResourceKind::GuardedControlFlow); + return plan; +} + +VMIMemoryAccessPlan +buildWriteAccessPlan(const VMITargetCapabilityRegistry &capabilities, + Value destination, Type destinationType, + VMIVRegType valueType, VMIMemoryWriteMaskKind writeMask) { + VMIMemoryAccessPlan plan; + plan.baseType = destinationType; + plan.valueType = valueType; + plan.logicalShape.elementCount = valueType.getElementCount(); + plan.validMask = VMIMemoryValidMaskKind::AllTrue; + plan.permutation = VMIMemoryPermutationKind::Identity; + plan.writeMask = writeMask; + plan.targetCapability = + capabilities.supportsDirectMemory(destinationType, "destination"); + if (plan.targetCapability.isSupported()) + plan.targetCapability = requireIdentityMemRefLayout( + destinationType, "destination", destination); + return plan; +} + +void requireUnavailableReadFallback(VMIMemoryAccessPlan &plan) { + std::string maskedLoadReason; + if (plan.validMask == VMIMemoryValidMaskKind::ExplicitMask && + !plan.trueMaskedLoadCapability.isSupported()) + maskedLoadReason = + (Twine("; ") + plan.trueMaskedLoadCapability.reason).str(); + std::string scratchReason; + if (!plan.scratchFallbackCapability.isSupported()) + scratchReason = (Twine("; ") + plan.scratchFallbackCapability.reason).str(); + std::string guardedReason; + if (!plan.guardedFallbackCapability.isSupported()) + guardedReason = (Twine("; ") + plan.guardedFallbackCapability.reason).str(); + plan.fallbackDecision = VMIMemoryFallbackDecision::requiredUnavailable( + Twine("partial/tail read needs a scratch, guarded, or true " + "masked/non-faulting load fallback, but no such fallback resource " + "plan is implemented") + + maskedLoadReason + scratchReason + guardedReason); +} + +FailureOr verifyFullOrSafeReadVRegChunks(Operation *op, + VMIVRegType type, + Type sourceType, Value offset, + PatternRewriter &rewriter) { + std::string fullChunkReason; + FailureOr lanesPerPart = 
+ checkFullDataPhysicalChunks(type, &fullChunkReason); + if (succeeded(lanesPerPart)) + return *lanesPerPart; + + VMIMemorySafeReadProof safeReadProof = + computeSafeFullReadProof(sourceType, getConstantIndexValue(offset), type); + if (safeReadProof.proven) { + lanesPerPart = getDataLanesPerPart(type.getElementType()); + if (succeeded(lanesPerPart)) + return *lanesPerPart; + } + + lanesPerPart = getDataLanesPerPart(type.getElementType()); + if (succeeded(lanesPerPart)) + return *lanesPerPart; + + (void)rewriter.notifyMatchFailure( + op, Twine("memory lowering ") + fullChunkReason + + "; safe full-read proof failed: " + safeReadProof.reason); + return failure(); +} + +LogicalResult +checkSupportedLoadShape(const VMITargetCapabilityRegistry &capabilities, + VMIVRegType type, Value source, Type sourceType, + std::optional constantOffset, + std::string *reason) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + VMIMemoryAccessPlan accessPlan = + buildReadAccessPlan(capabilities, source, sourceType, type, + constantOffset, VMIMemoryValidMaskKind::AllTrue); + if (!accessPlan.targetCapability.isSupported()) + return fail(accessPlan.targetCapability.reason); + + if (getDenseLaneStrideLoadDistToken(type)) + return success(); + + if (failed(getDataLanesPerPart(type.getElementType()))) + return fail("requires element type with known physical lane width"); + return success(); +} + +LogicalResult checkSupportedDeinterleaveLoadShape( + const VMITargetCapabilityRegistry &capabilities, + VMIDeinterleaveLoadOp op, std::string *reason) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto lowType = cast(op.getLow().getType()); + auto highType = cast(op.getHigh().getType()); + VMILayoutAttr lowLayout = lowType.getLayoutAttr(); + VMILayoutAttr highLayout = highType.getLayoutAttr(); + if (!lowLayout || !highLayout || 
!lowLayout.isContiguous() || + !highLayout.isContiguous()) + return fail("requires assigned contiguous low/high result layouts"); + if (lowType.getElementCount() != highType.getElementCount() || + lowType.getElementType() != highType.getElementType()) + return fail("requires matching low/high result shape and element type"); + if (!getX2MemoryDistToken(lowType.getElementType(), "DINTLV")) + return fail("requires 8/16/32-bit element type for vldsx2 DINTLV"); + + VMIMemoryAccessPlan accessPlan = buildReadAccessPlan( + capabilities, op.getSource(), op.getSource().getType(), lowType, + getConstantIndexValue(op.getOffset()), VMIMemoryValidMaskKind::AllTrue); + if (!accessPlan.targetCapability.isSupported()) + return fail(accessPlan.targetCapability.reason); + + std::string fullChunkReason; + if (failed(checkFullDataPhysicalChunks(lowType, &fullChunkReason))) + return fail(Twine("requires full physical chunks; ") + fullChunkReason); + return success(); +} + +LogicalResult +checkSupportedStoreShape(const VMITargetCapabilityRegistry &capabilities, + VMIVRegType type, Value destination, + Type destinationType, std::string *reason) { + VMIMemoryAccessPlan accessPlan = + buildWriteAccessPlan(capabilities, destination, destinationType, type, + VMIMemoryWriteMaskKind::AllTrue); + if (!accessPlan.targetCapability.isSupported()) { + if (reason) + *reason = accessPlan.targetCapability.reason; + return failure(); + } + + if (failed(checkSupportedMaskableVReg(capabilities, type, reason))) + return failure(); + + if (getDenseLaneStrideStoreDistToken(type)) + return success(); + + std::string fullChunkReason; + if (succeeded(checkFullDataPhysicalChunks(type, &fullChunkReason))) + return success(); + + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + VMILayoutAttr layout = type.getLayoutAttr(); + if (!layout) + return fail("requires assigned layout"); + if (failed(getDataLanesPerPart(type.getElementType()))) 
+ return fail("requires known physical lanes per part"); + if (layout.isContiguous() && layout.getLaneStride() == 1) + return success(); + + std::string materializationReason; + if (succeeded(checkCanMaterializeToContiguous(type, &materializationReason))) + return success(); + return fail(Twine("partial/tail store requires contiguous layout or " + "deinterleaved layout that can materialize to contiguous; " + "value ") + + fullChunkReason + ", materialization " + materializationReason); +} + +LogicalResult checkSupportedInterleaveStoreShape( + const VMITargetCapabilityRegistry &capabilities, + VMIInterleaveStoreOp op, std::string *reason) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto lowType = cast(op.getLow().getType()); + auto highType = cast(op.getHigh().getType()); + VMILayoutAttr lowLayout = lowType.getLayoutAttr(); + VMILayoutAttr highLayout = highType.getLayoutAttr(); + if (!lowLayout || !highLayout || !lowLayout.isContiguous() || + !highLayout.isContiguous()) + return fail("requires assigned contiguous low/high input layouts"); + if (lowType.getElementCount() != highType.getElementCount() || + lowType.getElementType() != highType.getElementType()) + return fail("requires matching low/high input shape and element type"); + if (!getX2MemoryDistToken(lowType.getElementType(), "INTLV")) + return fail("requires 8/16/32-bit element type for vstsx2 INTLV"); + + VMIMemoryAccessPlan accessPlan = + buildWriteAccessPlan(capabilities, op.getDestination(), + op.getDestination().getType(), lowType, + VMIMemoryWriteMaskKind::AllTrue); + if (!accessPlan.targetCapability.isSupported()) + return fail(accessPlan.targetCapability.reason); + if (failed(checkSupportedMaskableVReg(capabilities, lowType, reason))) + return failure(); + + std::string fullChunkReason; + if (failed(checkFullDataPhysicalChunks(lowType, &fullChunkReason))) + return fail(Twine("requires full physical chunks; ") + 
fullChunkReason); + return success(); +} + +FailureOr getGroupSizeFromNumGroups(VMIVRegType type, + int64_t numGroups, + std::string *reason = nullptr) { + auto fail = [&](const Twine &message) -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + if (numGroups <= 0) + return fail("requires num_groups to be positive"); + if (type.getElementCount() % numGroups != 0) + return fail("requires num_groups to evenly divide logical lane count"); + return type.getElementCount() / numGroups; +} + +LogicalResult checkSupportedGroupChunkShape(VMIVRegType type, int64_t groupSize, + std::string *reason) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + VMILayoutAttr layout = type.getLayoutAttr(); + if (!layout || !layout.isContiguous()) + return fail("requires assigned contiguous layout"); + std::string fullChunkReason; + if (failed(checkFullDataPhysicalChunks(type, &fullChunkReason))) + return fail(Twine("requires full physical chunks; ") + fullChunkReason); + FailureOr lanesPerPart = getDataLanesPerPart(type.getElementType()); + if (failed(lanesPerPart)) + return fail("requires known physical lanes per part"); + if (groupSize <= 0 || type.getElementCount() % groupSize != 0) + return fail("requires derived group size to evenly divide logical lane " + "count"); + if (groupSize % *lanesPerPart != 0) + return fail("currently requires group size to be a multiple of physical " + "lanes per part"); + return success(); +} + +LogicalResult +checkSupportedGroupLoadShape(const VMITargetCapabilityRegistry &capabilities, + VMIGroupLoadOp op, std::string *reason) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto resultType = cast(op.getResult().getType()); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + if (!resultLayout) + return fail("requires assigned result layout"); + FailureOr 
groupSize = getGroupSizeFromNumGroups( + resultType, op.getNumGroupsAttr().getInt(), reason); + if (failed(groupSize)) + return failure(); + + if (resultLayout.isContiguous()) { + if (failed(checkSupportedLoadShape(capabilities, resultType, op.getSource(), + op.getSource().getType(), std::nullopt, + reason))) + return failure(); + return checkSupportedGroupChunkShape(resultType, *groupSize, reason); + } + + if (resultLayout.isDeinterleaved() && resultLayout.getBlockElems() == 8 && + resultType.getElementType().isF32()) { + VMILayoutSupport supports; + if (failed(supports.getGroupLoadSupport(capabilities, op, reason))) + return failure(); + return success(); + } + + return fail("requires contiguous layout or deinterleaved block8 f32 layout"); +} + +LogicalResult checkSupportedGroupSlotLoadShape( + const VMITargetCapabilityRegistry &capabilities, VMIGroupSlotLoadOp op, + std::string *reason) { + VMILayoutSupport supports; + if (failed(supports.getGroupSlotLoadSupport(capabilities, op, reason))) + return failure(); + return success(); +} + +LogicalResult checkSupportedGroupBroadcastLoadShape( + const VMITargetCapabilityRegistry &capabilities, VMIGroupBroadcastLoadOp op, + std::string *reason) { + VMILayoutSupport supports; + if (failed(supports.getGroupBroadcastLoadSupport(capabilities, op, reason))) + return failure(); + return success(); +} + +LogicalResult +checkSupportedGroupStoreShape(const VMITargetCapabilityRegistry &capabilities, + VMIGroupStoreOp op, std::string *reason) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto valueType = cast(op.getValue().getType()); + VMILayoutAttr layout = valueType.getLayoutAttr(); + if (layout && layout.isGroupSlots()) { + int64_t numGroups = op.getNumGroupsAttr().getInt(); + if (layout.getNumGroups() != numGroups) + return fail("group_slots group_store requires layout num_groups to " + "match op num_groups"); + + VMIMemoryAccessPlan 
accessPlan = buildWriteAccessPlan( + capabilities, op.getDestination(), op.getDestination().getType(), + valueType, VMIMemoryWriteMaskKind::AllTrue); + if (!accessPlan.targetCapability.isSupported()) + return fail(accessPlan.targetCapability.reason); + + VMILayoutSupport supports; + if (failed(supports.getGroupSlotsStoreSupport(capabilities, op, reason))) + return failure(); + return success(); + } + + FailureOr groupSize = getGroupSizeFromNumGroups( + valueType, op.getNumGroupsAttr().getInt(), reason); + if (failed(groupSize)) + return failure(); + if (failed(checkSupportedStoreShape(capabilities, valueType, + op.getDestination(), + op.getDestination().getType(), reason))) + return failure(); + return checkSupportedGroupChunkShape(valueType, *groupSize, reason); +} + +LogicalResult +checkSupportedMaskedLoadShape(const VMITargetCapabilityRegistry &capabilities, + VMIMaskedLoadOp op, std::string *reason) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto resultType = cast(op.getResult().getType()); + auto passthruType = cast(op.getPassthru().getType()); + auto maskType = cast(op.getMask().getType()); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + VMILayoutAttr passthruLayout = passthruType.getLayoutAttr(); + VMILayoutAttr maskLayout = maskType.getLayoutAttr(); + VMIMemoryAccessPlan accessPlan = buildReadAccessPlan( + capabilities, op.getSource(), op.getSource().getType(), resultType, + getConstantIndexValue(op.getOffset()), + VMIMemoryValidMaskKind::ExplicitMask); + if (!accessPlan.targetCapability.isSupported()) + return fail(accessPlan.targetCapability.reason); + if (!resultLayout || !passthruLayout || !maskLayout) + return fail("requires assigned result, passthru, and mask layouts"); + if (!resultLayout.isContiguous() || !passthruLayout.isContiguous() || + !maskLayout.isContiguous()) + return fail("requires contiguous result, passthru, and mask layouts"); + + 
std::string fullChunkReason; + if (succeeded(checkFullDataPhysicalChunks(resultType, &fullChunkReason))) + return success(); + + if (accessPlan.safeReadProof.proven) + return success(); + requireUnavailableReadFallback(accessPlan); + return fail(Twine("partial/tail masked_load requires statically safe " + "full-read footprint; value ") + + fullChunkReason + ", safe-read proof " + + accessPlan.safeReadProof.reason + + "; fallback decision: " + accessPlan.fallbackDecision.reason); +} + +LogicalResult +checkSupportedGatherShape(const VMITargetCapabilityRegistry &capabilities, + VMIGatherOp op, std::string *reason) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto resultType = cast(op.getResult().getType()); + auto indicesType = cast(op.getIndices().getType()); + auto passthruType = cast(op.getPassthru().getType()); + auto maskType = cast(op.getMask().getType()); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + VMILayoutAttr indicesLayout = indicesType.getLayoutAttr(); + VMILayoutAttr passthruLayout = passthruType.getLayoutAttr(); + VMILayoutAttr maskLayout = maskType.getLayoutAttr(); + if (!resultLayout || !indicesLayout || !passthruLayout || !maskLayout) + return fail("requires assigned result, indices, passthru, and mask " + "layouts"); + if (!resultLayout.isContiguous() || !indicesLayout.isContiguous() || + !passthruLayout.isContiguous() || !maskLayout.isContiguous()) + return fail("requires contiguous result, indices, passthru, and mask " + "layouts"); + + VMICapabilityResult sourceCapability = capabilities.supportsUBPointerMemory( + op.getSource().getType(), "source", "pto.vgather2_bc", + "pto.vgather2_bc reads only UB"); + if (!sourceCapability.isSupported()) + return fail(sourceCapability.reason); + + unsigned resultBits = + pto::getPTOStorageElemBitWidth(resultType.getElementType()); + auto indexElementType = dyn_cast(indicesType.getElementType()); + if 
(!indexElementType || indexElementType.isSigned()) + return fail("requires signless or unsigned integer indices"); + bool isU16Gather = resultBits == 16 && indexElementType.isUnsigned() && + indexElementType.getWidth() == 16 && + maskType.getGranularity() == "b16"; + bool isB32Gather = resultBits == 32 && indexElementType.getWidth() == 32 && + maskType.getGranularity() == "b32"; + if (!isU16Gather && !isB32Gather) + return fail("requires either 32-bit results with 32-bit indices and b32 " + "mask, or ui16 results with ui16 indices and b16 mask"); + + FailureOr resultArity = getVMIPhysicalArity(resultType); + FailureOr indicesArity = getVMIPhysicalArity(indicesType); + FailureOr passthruArity = getVMIPhysicalArity(passthruType); + FailureOr maskArity = getVMIPhysicalArity(maskType); + if (failed(resultArity) || failed(indicesArity) || failed(passthruArity) || + failed(maskArity)) + return fail("requires computable physical arity"); + if (*resultArity != *indicesArity || *resultArity != *passthruArity || + *resultArity != *maskArity) + return fail("requires result, indices, passthru, and mask to have the " + "same physical arity"); + + if (isB32Gather) { + std::string resultReason; + std::string indicesReason; + std::string passthruReason; + std::string maskReason; + if (failed(checkFullDataPhysicalChunks(resultType, &resultReason))) + return fail(Twine("result requires full physical chunks; ") + + resultReason); + if (failed(checkFullDataPhysicalChunks(indicesType, &indicesReason))) + return fail(Twine("indices require full physical chunks; ") + + indicesReason); + if (failed(checkFullDataPhysicalChunks(passthruType, &passthruReason))) + return fail(Twine("passthru requires full physical chunks; ") + + passthruReason); + if (failed(checkFullVMIPhysicalChunks(maskType, &maskReason))) + return fail(Twine("mask requires full physical chunks; ") + maskReason); + } else if (*resultArity != 1) { + return fail("ui16 gather currently supports one physical chunk"); + } + + 
return success(); +} + +LogicalResult +checkSupportedScatterShape(const VMITargetCapabilityRegistry &capabilities, + VMIScatterOp op, std::string *reason) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto valueType = cast(op.getValue().getType()); + auto indicesType = cast(op.getIndices().getType()); + auto maskType = cast(op.getMask().getType()); + VMILayoutAttr valueLayout = valueType.getLayoutAttr(); + VMILayoutAttr indicesLayout = indicesType.getLayoutAttr(); + VMILayoutAttr maskLayout = maskType.getLayoutAttr(); + if (!valueLayout || !indicesLayout || !maskLayout) + return fail("requires assigned value, indices, and mask layouts"); + if (!valueLayout.isContiguous() || !indicesLayout.isContiguous() || + !maskLayout.isContiguous()) + return fail("requires contiguous value, indices, and mask layouts"); + + VMICapabilityResult destinationCapability = + capabilities.supportsUBPointerMemory(op.getDestination().getType(), + "destination", "pto.vscatter", + "pto.vscatter writes only UB"); + if (!destinationCapability.isSupported()) + return fail(destinationCapability.reason); + + if (pto::getPTOStorageElemBitWidth(valueType.getElementType()) != 32) + return fail("currently requires 32-bit value element type so physical " + "index and value lane counts match pto.vscatter"); + auto indexElementType = dyn_cast(indicesType.getElementType()); + if (!indexElementType || indexElementType.getWidth() != 32 || + indexElementType.isSigned()) + return fail("requires signless or unsigned 32-bit indices"); + if (maskType.getGranularity() != "b32") + return fail("requires b32 mask granularity"); + + FailureOr valueArity = getVMIPhysicalArity(valueType); + FailureOr indicesArity = getVMIPhysicalArity(indicesType); + FailureOr maskArity = getVMIPhysicalArity(maskType); + if (failed(valueArity) || failed(indicesArity) || failed(maskArity)) + return fail("requires computable physical arity"); + if 
(*valueArity != *indicesArity || *valueArity != *maskArity) + return fail("requires value, indices, and mask to have the same physical " + "arity"); + + std::string valueReason; + std::string indicesReason; + std::string maskReason; + if (failed(checkFullDataPhysicalChunks(valueType, &valueReason))) + return fail(Twine("value requires full physical chunks; ") + valueReason); + if (failed(checkFullDataPhysicalChunks(indicesType, &indicesReason))) + return fail(Twine("indices require full physical chunks; ") + + indicesReason); + if (failed(checkFullVMIPhysicalChunks(maskType, &maskReason))) + return fail(Twine("mask requires full physical chunks; ") + maskReason); + + return success(); +} + +LogicalResult +checkSupportedStrideStoreShape(const VMITargetCapabilityRegistry &capabilities, + VMIStrideStoreOp op, std::string *reason) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto valueType = cast(op.getValue().getType()); + auto maskType = cast(op.getMask().getType()); + VMILayoutAttr valueLayout = valueType.getLayoutAttr(); + VMILayoutAttr maskLayout = maskType.getLayoutAttr(); + if (!valueLayout || !maskLayout) + return fail("requires assigned value and mask layouts"); + if (!valueLayout.isContiguous() || !maskLayout.isContiguous()) + return fail("requires contiguous value and mask layouts"); + + VMICapabilityResult destinationCapability = + capabilities.supportsUBPointerMemory(op.getDestination().getType(), + "destination", "pto.vsstb", + "pto.vsstb writes only UB"); + if (!destinationCapability.isSupported()) + return fail(destinationCapability.reason); + if (failed(checkSupportedStoreShape(capabilities, valueType, + op.getDestination(), + op.getDestination().getType(), reason))) + return failure(); + + FailureOr valueArity = getVMIPhysicalArity(valueType); + FailureOr maskArity = getVMIPhysicalArity(maskType); + if (failed(valueArity) || failed(maskArity)) + return fail("requires 
computable physical arity"); + if (*valueArity != 1 || *maskArity != 1) + return fail("currently supports one physical value/mask chunk"); + return success(); +} + +LogicalResult +checkSupportedStrideLoadShape(const VMITargetCapabilityRegistry &capabilities, + VMIStrideLoadOp op, std::string *reason) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto resultType = cast(op.getResult().getType()); + auto maskType = cast(op.getMask().getType()); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + VMILayoutAttr maskLayout = maskType.getLayoutAttr(); + if (!resultLayout || !maskLayout) + return fail("requires assigned result and mask layouts"); + if (!resultLayout.isContiguous() || !maskLayout.isContiguous()) + return fail("requires contiguous result and mask layouts"); + + VMICapabilityResult sourceCapability = capabilities.supportsUBPointerMemory( + op.getSource().getType(), "source", "pto.vsldb", + "pto.vsldb reads only UB"); + if (!sourceCapability.isSupported()) + return fail(sourceCapability.reason); + + FailureOr resultArity = getVMIPhysicalArity(resultType); + FailureOr maskArity = getVMIPhysicalArity(maskType); + if (failed(resultArity) || failed(maskArity)) + return fail("requires computable physical arity"); + if (*resultArity != 1 || *maskArity != 1) + return fail("currently supports one physical result/mask chunk"); + return success(); +} + +Value stripMaskMaterialization(Value value) { + while (true) { + if (auto ensure = value.getDefiningOp()) { + value = ensure.getSource(); + continue; + } + if (auto ensure = value.getDefiningOp()) { + value = ensure.getSource(); + continue; + } + return value; + } +} + +bool isStaticAllActiveMask(Value mask, int64_t expectedLanes, + std::string *reason = nullptr) { + mask = stripMaskMaterialization(mask); + auto fail = [&](const Twine &message) { + if (reason) + *reason = message.str(); + return false; + }; + + if (auto 
createMask = mask.getDefiningOp()) { + auto activeConstant = + createMask.getActiveLanes().getDefiningOp(); + if (!activeConstant) + return fail("create_mask active_lanes is dynamic"); + auto activeAttr = dyn_cast(activeConstant.getValue()); + if (!activeAttr) + return fail("create_mask active_lanes is not an integer constant"); + return activeAttr.getInt() >= expectedLanes + ? true + : fail("create_mask active_lanes is smaller than the logical " + "lane count"); + } + + if (auto constantMask = mask.getDefiningOp()) { + auto denseAttr = dyn_cast(constantMask.getValue()); + if (!denseAttr) + return fail("constant_mask is not a dense integer mask"); + if (denseAttr.getNumElements() != expectedLanes) + return fail("constant_mask element count does not match the logical " + "lane count"); + auto values = denseAttr.getValues(); + for (bool value : values) + if (!value) + return fail("constant_mask contains an inactive lane"); + return true; + } + + return fail("mask is not a static all-active create_mask or constant_mask"); +} + +LogicalResult +checkSupportedExpandLoadShape(const VMITargetCapabilityRegistry &capabilities, + VMIExpandLoadOp op, std::string *reason) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto resultType = cast(op.getResult().getType()); + auto passthruType = cast(op.getPassthru().getType()); + auto maskType = cast(op.getMask().getType()); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + VMILayoutAttr passthruLayout = passthruType.getLayoutAttr(); + VMILayoutAttr maskLayout = maskType.getLayoutAttr(); + VMIMemoryAccessPlan accessPlan = buildReadAccessPlan( + capabilities, op.getSource(), op.getSource().getType(), resultType, + getConstantIndexValue(op.getOffset()), + VMIMemoryValidMaskKind::ExplicitMask); + if (!accessPlan.targetCapability.isSupported()) + return fail(accessPlan.targetCapability.reason); + if (!resultLayout || !passthruLayout || 
!maskLayout) + return fail("requires assigned result, passthru, and mask layouts"); + if (!resultLayout.isContiguous() || !passthruLayout.isContiguous() || + !maskLayout.isContiguous()) + return fail("requires contiguous result, passthru, and mask layouts"); + + std::string maskReason; + bool staticAllActive = isStaticAllActiveMask( + op.getMask(), resultType.getElementCount(), &maskReason); + + std::string fullChunkReason; + if (staticAllActive && + succeeded(checkFullDataPhysicalChunks(resultType, &fullChunkReason))) + return success(); + + if (staticAllActive && accessPlan.safeReadProof.proven) + return success(); + + std::string allActivePathReason; + if (!staticAllActive) { + allActivePathReason = + maskReason.empty() ? "requires static all-active mask" : maskReason; + } else { + requireUnavailableReadFallback(accessPlan); + allActivePathReason = + (Twine("requires full physical chunks or statically safe full-read " + "footprint; value ") + + fullChunkReason + ", safe-read proof " + + accessPlan.safeReadProof.reason + + "; fallback decision: " + accessPlan.fallbackDecision.reason) + .str(); + } + + VMICapabilityResult sourceCapability = capabilities.supportsUBPointerMemory( + op.getSource().getType(), "source", "pto.vgather2_bc", + "pto.vgather2_bc reads only UB"); + if (!sourceCapability.isSupported()) { + if (!isa(op.getSource().getType())) + return fail(Twine("runtime-mask path ") + sourceCapability.reason + + "; all-active path " + allActivePathReason); + return fail(Twine("runtime-mask path ") + sourceCapability.reason); + } + if (pto::getPTOStorageElemBitWidth(resultType.getElementType()) != 32) + return fail("runtime-mask path currently requires 32-bit result element " + "type so prefix indices and gather result lane counts match"); + if (maskType.getGranularity() != "b32") + return fail("runtime-mask path requires b32 mask granularity"); + + FailureOr resultArity = getVMIPhysicalArity(resultType); + FailureOr passthruArity = 
getVMIPhysicalArity(passthruType); + FailureOr maskArity = getVMIPhysicalArity(maskType); + if (failed(resultArity) || failed(passthruArity) || failed(maskArity)) + return fail("runtime-mask path requires computable physical arity"); + if (*resultArity != 1 || *passthruArity != 1 || *maskArity != 1) + return fail("runtime-mask path currently supports only one physical " + "chunk because prefix indices must not reset across chunks"); + + std::string passthruReason; + std::string maskFullReason; + if (failed(checkFullDataPhysicalChunks(resultType, &fullChunkReason))) + return fail(Twine("runtime-mask result requires full physical chunks; ") + + fullChunkReason); + if (failed(checkFullDataPhysicalChunks(passthruType, &passthruReason))) + return fail(Twine("runtime-mask passthru requires full physical chunks; ") + + passthruReason); + if (failed(checkFullVMIPhysicalChunks(maskType, &maskFullReason))) + return fail(Twine("runtime-mask mask requires full physical chunks; ") + + maskFullReason); + + return success(); +} + +LogicalResult +checkSupportedMaskedStoreShape(const VMITargetCapabilityRegistry &capabilities, + VMIVRegType valueType, VMIMaskType maskType, + Value destination, Type destinationType, + std::string *reason) { + VMIMemoryAccessPlan accessPlan = + buildWriteAccessPlan(capabilities, destination, destinationType, + valueType, VMIMemoryWriteMaskKind::ExplicitMask); + if (!accessPlan.targetCapability.isSupported()) { + if (reason) + *reason = accessPlan.targetCapability.reason; + return failure(); + } + + std::string valueReason; + std::string maskReason; + if (succeeded(checkFullDataPhysicalChunks(valueType, &valueReason)) && + succeeded(checkFullVMIPhysicalChunks(maskType, &maskReason))) + return success(); + + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + VMILayoutAttr valueLayout = valueType.getLayoutAttr(); + VMILayoutAttr maskLayout = maskType.getLayoutAttr(); + if 
(!valueLayout || !maskLayout) + return fail("requires assigned value and mask layouts"); + + FailureOr valueArity = getVMIPhysicalArity(valueType); + FailureOr maskArity = getVMIPhysicalArity(maskType); + if (failed(valueArity) || failed(maskArity) || *valueArity != *maskArity) + return fail("requires matching value/mask physical arity"); + + if (valueLayout.hasDenseLaneStride()) { + VMILayoutSupport supports; + auto contiguousValueType = + VMIVRegType::get(valueType.getContext(), valueType.getElementCount(), + valueType.getElementType(), + VMILayoutAttr::getContiguous(valueType.getContext())); + auto contiguousMaskType = + VMIMaskType::get(maskType.getContext(), maskType.getElementCount(), + maskType.getGranularity(), + VMILayoutAttr::getContiguous(maskType.getContext())); + if (succeeded(supports.canFoldContiguousMaskedStoreMaterialization( + valueType, maskType, contiguousValueType, contiguousMaskType, + reason))) + return success(); + } + + std::string valueMaterializationReason; + FailureOr valueParts = getContiguousMaterializationPartCount( + valueType, &valueMaterializationReason); + if (failed(valueParts)) + return fail(Twine("value cannot materialize to contiguous; value ") + + valueReason + ", materialization " + + valueMaterializationReason); + + std::string maskMaterializationReason; + FailureOr maskParts = getContiguousMaterializationPartCount( + maskType, &maskMaterializationReason); + if (failed(maskParts)) + return fail(Twine("mask cannot materialize to contiguous; mask ") + + maskReason + ", materialization " + maskMaterializationReason); + if (*valueParts != *maskParts) + return fail( + "requires value/mask contiguous materialization arity to match"); + return success(); +} + +FailureOr getContiguousActiveDataLanes(VMIVRegType vmiType, + int64_t chunk) { + FailureOr lanesPerPart = + getDataLanesPerPart(vmiType.getElementType()); + if (failed(lanesPerPart)) + return failure(); + + int64_t remaining = vmiType.getElementCount() - chunk * 
*lanesPerPart; + return std::clamp(remaining, 0, *lanesPerPart); +} + +FailureOr getActiveDataLanesInPhysicalChunk(VMIVRegType vmiType, + int64_t chunk) { + FailureOr lanesPerPart = + getDataLanesPerPart(vmiType.getElementType()); + if (failed(lanesPerPart)) + return failure(); + + int64_t active = 0; + for (int64_t lane = 0; lane < *lanesPerPart; ++lane) { + FailureOr padding = isPaddingLane(vmiType, /*part=*/0, chunk, lane); + if (failed(padding)) + return failure(); + if (!*padding) + ++active; + } + return active; +} + +FailureOr createContiguousStoreMask(Location loc, VMIVRegType vmiType, + int64_t chunk, VRegType vregType, + PatternRewriter &rewriter) { + FailureOr lanesPerPart = + getDataLanesPerPart(vmiType.getElementType()); + if (failed(lanesPerPart)) + return failure(); + + FailureOr activeLanes = getContiguousActiveDataLanes(vmiType, chunk); + if (failed(activeLanes)) + return failure(); + if (*activeLanes == *lanesPerPart) + return createAllTrueMaskForVReg(loc, vregType, rewriter); + + FailureOr maskType = + getMaskTypeForVReg(vregType, rewriter.getContext()); + if (failed(maskType)) + return failure(); + FailureOr> maskAndRemaining = createRuntimePrefixMask( + loc, *maskType, createI32Constant(loc, *activeLanes, rewriter), rewriter); + if (failed(maskAndRemaining)) + return failure(); + return maskAndRemaining->first; +} + +FailureOr createMaskedStorePredicate(Location loc, VMIVRegType vmiType, + int64_t chunk, Value userMask, + VRegType vregType, + PatternRewriter &rewriter) { + FailureOr lanesPerPart = + getDataLanesPerPart(vmiType.getElementType()); + if (failed(lanesPerPart)) + return failure(); + + FailureOr activeLanes = getContiguousActiveDataLanes(vmiType, chunk); + if (failed(activeLanes)) + return failure(); + if (*activeLanes == *lanesPerPart) + return userMask; + + auto maskType = dyn_cast(userMask.getType()); + if (!maskType) + return failure(); + FailureOr tailMask = + createContiguousStoreMask(loc, vmiType, chunk, vregType, rewriter); + 
FailureOr allTrue = createAllTrueMask(loc, maskType, rewriter); + if (failed(tailMask) || failed(allTrue)) + return failure(); + return rewriter.create(loc, maskType, userMask, *tailMask, *allTrue) + .getResult(); +} + +FailureOr createDenseLaneStrideStorePredicate( + Location loc, VMIVRegType vmiType, int64_t chunk, Value userMask, + StringRef targetGranularity, PatternRewriter &rewriter) { + auto sourceMaskType = dyn_cast(userMask.getType()); + if (!sourceMaskType) + return failure(); + auto targetMaskType = MaskType::get(rewriter.getContext(), targetGranularity); + Value compactMask = userMask; + VMILayoutAttr layout = vmiType.getLayoutAttr(); + if (!layout) + return failure(); + + auto lower = rewriter.getStringAttr("LOWER"); + StringRef sourceGranularity = sourceMaskType.getGranularity(); + if (layout.getLaneStride() == 2) { + compactMask = + rewriter.create(loc, targetMaskType, compactMask, lower) + .getResult(); + } else if (layout.getLaneStride() == 4 && sourceGranularity == "b8" && + targetGranularity == "b32") { + auto b16MaskType = MaskType::get(rewriter.getContext(), "b16"); + compactMask = + rewriter.create(loc, b16MaskType, compactMask, lower) + .getResult(); + compactMask = + rewriter.create(loc, targetMaskType, compactMask, lower) + .getResult(); + } else { + return failure(); + } + + FailureOr activeLanes = + getActiveDataLanesInPhysicalChunk(vmiType, chunk); + FailureOr maskLanes = getMaskLanesPerPart(targetGranularity); + if (failed(activeLanes) || failed(maskLanes)) + return failure(); + if (*activeLanes == *maskLanes) + return compactMask; + + FailureOr tailMask = createPrefixMaskForActiveLanes( + loc, targetMaskType, *activeLanes, rewriter); + FailureOr allTrue = createAllTrueMask(loc, targetMaskType, rewriter); + if (failed(tailMask) || failed(allTrue)) + return failure(); + return rewriter + .create(loc, targetMaskType, compactMask, *tailMask, *allTrue) + .getResult(); +} + +FailureOr> +computeShuffleForwardingSourceParts(VMIShuffleOp op, 
std::string *reason) { + auto fail = [&](const Twine &message) -> FailureOr> { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto sourceType = cast(op.getSource().getType()); + auto resultType = cast(op.getResult().getType()); + FailureOr lanesPerPart = + getDataLanesPerPart(sourceType.getElementType()); + if (failed(lanesPerPart)) + return fail("requires known lanes per physical part"); + + ArrayRef indices = op.getIndices(); + if (indices.empty()) + return fail("requires non-empty indices"); + + FailureOr resultFactor = getDataLayoutFactor(resultType); + if (failed(resultFactor)) + return fail("requires assigned result layout"); + + SmallVector sourceFlatIndices; + for (int64_t resultPart = 0; resultPart < *resultFactor; ++resultPart) { + FailureOr resultChunks = + getDataChunksInPart(resultType, resultPart); + if (failed(resultChunks)) + return fail("requires known result physical chunks"); + + for (int64_t resultChunk = 0; resultChunk < *resultChunks; ++resultChunk) { + std::optional sourcePart; + std::optional sourceChunk; + for (int64_t lane = 0; lane < *lanesPerPart; ++lane) { + FailureOr padding = + isPaddingLane(resultType, resultPart, resultChunk, lane); + if (failed(padding)) + return fail("failed to classify result padding lanes"); + if (*padding) + continue; + + FailureOr resultLogicalLane = + mapPhysicalLaneToLogical(resultType, resultPart, resultChunk, lane); + if (failed(resultLogicalLane) || + *resultLogicalLane >= static_cast(indices.size())) + return fail("failed to map result lane"); + + FailureOr sourcePhysical = + mapLogicalLaneToPhysical(sourceType, indices[*resultLogicalLane]); + if (failed(sourcePhysical)) + return fail("failed to map source lane"); + if (sourcePhysical->lane != lane) + return fail("requires same-lane physical chunks"); + + if (!sourcePart) { + sourcePart = sourcePhysical->part; + sourceChunk = sourcePhysical->chunk; + continue; + } + if (*sourcePart != sourcePhysical->part || + *sourceChunk != 
sourcePhysical->chunk) + return fail("requires one source chunk per result chunk"); + } + + if (!sourcePart || !sourceChunk) + return fail("requires at least one logical lane per result chunk"); + FailureOr sourceFlatIndex = + getDataFlatPartIndex(sourceType, *sourcePart, *sourceChunk); + if (failed(sourceFlatIndex)) + return fail("source part range is out of bounds"); + sourceFlatIndices.push_back(*sourceFlatIndex); + } + } + + return sourceFlatIndices; +} + +struct ShuffleVselrPlan { + int64_t sourceFlatIndex = 0; + int64_t baseLane = 0; + bool descending = false; +}; + +FailureOr computeShuffleLane0SplatSourcePart(VMIShuffleOp op, + std::string *reason) { + auto fail = [&](const Twine &message) -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + ArrayRef indices = op.getIndices(); + if (indices.empty()) + return fail("requires non-empty indices"); + if (!llvm::all_of(indices, [](int64_t index) { return index == 0; })) + return fail("requires every result lane to select source lane 0"); + + auto sourceType = cast(op.getSource().getType()); + FailureOr sourceLane = + mapLogicalLaneToPhysical(sourceType, 0); + if (failed(sourceLane)) + return fail("failed to map source lane 0"); + FailureOr sourceFlatIndex = + getDataFlatPartIndex(sourceType, sourceLane->part, sourceLane->chunk); + if (failed(sourceFlatIndex)) + return fail("source lane 0 part range is out of bounds"); + return *sourceFlatIndex; +} + +FailureOr> +computeShuffleVselrPlans(VMIShuffleOp op, std::string *reason) { + auto fail = + [&](const Twine &message) -> FailureOr> { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto sourceType = cast(op.getSource().getType()); + auto resultType = cast(op.getResult().getType()); + FailureOr lanesPerPart = + getDataLanesPerPart(sourceType.getElementType()); + if (failed(lanesPerPart)) + return fail("requires known lanes per physical part"); + + ArrayRef indices = op.getIndices(); + if (indices.empty()) + 
return fail("requires non-empty indices"); + + FailureOr resultFactor = getDataLayoutFactor(resultType); + if (failed(resultFactor)) + return fail("requires assigned result layout"); + + SmallVector plans; + for (int64_t resultPart = 0; resultPart < *resultFactor; ++resultPart) { + FailureOr resultChunks = + getDataChunksInPart(resultType, resultPart); + if (failed(resultChunks)) + return fail("requires known result physical chunks"); + + for (int64_t resultChunk = 0; resultChunk < *resultChunks; ++resultChunk) { + std::optional sourcePart; + std::optional sourceChunk; + std::optional baseLane; + std::optional descending; + for (int64_t lane = 0; lane < *lanesPerPart; ++lane) { + FailureOr padding = + isPaddingLane(resultType, resultPart, resultChunk, lane); + if (failed(padding) || *padding) + return fail("requires full physical result chunks"); + + FailureOr resultLogicalLane = + mapPhysicalLaneToLogical(resultType, resultPart, resultChunk, lane); + if (failed(resultLogicalLane) || + *resultLogicalLane >= static_cast(indices.size())) + return fail("failed to map result lane"); + + FailureOr sourcePhysical = + mapLogicalLaneToPhysical(sourceType, indices[*resultLogicalLane]); + if (failed(sourcePhysical)) + return fail("failed to map source lane"); + + if (!sourcePart) { + sourcePart = sourcePhysical->part; + sourceChunk = sourcePhysical->chunk; + baseLane = sourcePhysical->lane; + continue; + } + + if (*sourcePart != sourcePhysical->part || + *sourceChunk != sourcePhysical->chunk) + return fail("requires one source chunk per result chunk"); + + int64_t ascExpected = *baseLane + lane; + int64_t descExpected = *baseLane - lane; + bool asc = sourcePhysical->lane == ascExpected; + bool desc = sourcePhysical->lane == descExpected; + if (!asc && !desc) + return fail("requires ASC or DESC affine source lane indices"); + + bool laneDescending = desc && !asc; + if (!descending) { + descending = laneDescending; + continue; + } + if (*descending != laneDescending) + return 
fail("requires one index order per result chunk"); + } + + FailureOr sourceFlatIndex = + getDataFlatPartIndex(sourceType, *sourcePart, *sourceChunk); + if (failed(sourceFlatIndex)) + return fail("source part range is out of bounds"); + plans.push_back(ShuffleVselrPlan{*sourceFlatIndex, *baseLane, + descending.value_or(false)}); + } + } + + return plans; +} + +struct ConstantMaskChunkMaterialization { + SmallVector activeLanes; +}; + +FailureOr> +computeConstantMaskMaterialization(VMIConstantMaskOp op, std::string *reason) { + auto fail = [&](const Twine &message) + -> FailureOr> { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto denseAttr = dyn_cast(op.getValue()); + if (!denseAttr) + return fail("only dense integer mask constants are supported"); + + auto resultVMIType = cast(op.getResult().getType()); + VMILayoutAttr layout = resultVMIType.getLayoutAttr(); + if (!layout || + !VMIMaskType::isConcreteGranularity(resultVMIType.getGranularity())) + return fail("requires concrete layout and granularity"); + + FailureOr lanesPerPart = + getMaskLanesPerPart(resultVMIType.getGranularity()); + if (failed(lanesPerPart)) + return fail("requires known physical mask lanes per part"); + + auto boolValues = denseAttr.getValues(); + int64_t factor = layout.isDeinterleaved() ? 
layout.getFactor() : 1; + SmallVector materializations; + for (int64_t part = 0; part < factor; ++part) { + for (int64_t chunk = 0;; ++chunk) { + bool anyLane = false; + ConstantMaskChunkMaterialization materialization; + materialization.activeLanes.reserve(*lanesPerPart); + for (int64_t lane = 0; lane < *lanesPerPart; ++lane) { + FailureOr padding = + isPaddingLane(resultVMIType, part, chunk, lane); + if (failed(padding)) + return fail("failed to map physical padding lane"); + if (*padding) { + materialization.activeLanes.push_back(0); + continue; + } + anyLane = true; + + FailureOr logicalLane = + mapPhysicalLaneToLogical(resultVMIType, part, chunk, lane); + if (failed(logicalLane)) + return fail("failed to map physical lane"); + materialization.activeLanes.push_back(boolValues[*logicalLane] ? 1 : 0); + } + if (!anyLane) + break; + materializations.push_back(std::move(materialization)); + } + } + + return materializations; +} + +FailureOr> +computeGroupMaskMaterializationForType(VMICreateGroupMaskOp op, + VMIMaskType resultVMIType, + std::string *reason) { + auto fail = [&](const Twine &message) + -> FailureOr> { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto activeConstant = + op.getActiveElemsPerGroup().getDefiningOp(); + if (!activeConstant) + return fail("requires constant active_elems_per_group"); + auto activeAttr = dyn_cast(activeConstant.getValue()); + if (!activeAttr) + return fail("active_elems_per_group must be an integer constant"); + + VMILayoutAttr layout = resultVMIType.getLayoutAttr(); + if (!layout || + !VMIMaskType::isConcreteGranularity(resultVMIType.getGranularity())) + return fail("requires concrete layout and granularity"); + + FailureOr lanesPerPart = + getMaskLanesPerPart(resultVMIType.getGranularity()); + if (failed(lanesPerPart)) + return fail("requires known physical mask lanes per part"); + + int64_t numGroups = op.getNumGroupsAttr().getInt(); + int64_t groupSize = op.getGroupSizeAttr().getInt(); + if 
(numGroups <= 0 || groupSize <= 0 || + resultVMIType.getElementCount() != numGroups * groupSize) + return fail("requires result lane count to match num_groups * group_size"); + + int64_t activeElems = activeAttr.getInt(); + if (activeElems < 0) + activeElems = 0; + if (activeElems > groupSize) + activeElems = groupSize; + + int64_t factor = layout.isDeinterleaved() ? layout.getFactor() : 1; + SmallVector materializations; + for (int64_t part = 0; part < factor; ++part) { + for (int64_t chunk = 0;; ++chunk) { + bool anyLane = false; + ConstantMaskChunkMaterialization materialization; + materialization.activeLanes.reserve(*lanesPerPart); + for (int64_t lane = 0; lane < *lanesPerPart; ++lane) { + FailureOr padding = + isPaddingLane(resultVMIType, part, chunk, lane); + if (failed(padding)) + return fail("failed to map physical padding lane"); + if (*padding) { + materialization.activeLanes.push_back(0); + continue; + } + anyLane = true; + + FailureOr logicalLane = + mapPhysicalLaneToLogical(resultVMIType, part, chunk, lane); + if (failed(logicalLane)) + return fail("failed to map physical lane"); + int64_t laneInGroup = *logicalLane % groupSize; + materialization.activeLanes.push_back(laneInGroup < activeElems ? 
1 + : 0); + } + if (!anyLane) + break; + materializations.push_back(std::move(materialization)); + } + } + + return materializations; +} + +FailureOr> +computeGroupMaskMaterialization(VMICreateGroupMaskOp op, std::string *reason) { + return computeGroupMaskMaterializationForType( + op, cast(op.getResult().getType()), reason); +} + +FailureOr> materializeDynamicContiguousGroupMask( + VMICreateGroupMaskOp op, Value activeElemsPerGroup, + VMIMaskType contiguousVMIType, TypeRange resultTypes, + PatternRewriter &rewriter) { + auto fail = [&](const Twine &message) -> FailureOr> { + (void)rewriter.notifyMatchFailure(op, message); + return failure(); + }; + + VMILayoutAttr layout = contiguousVMIType.getLayoutAttr(); + if (!layout || !layout.isContiguous()) + return fail("dynamic create_group_mask requires contiguous seed layout"); + if (contiguousVMIType.getGranularity() != "b32") + return fail("dynamic create_group_mask currently requires b32 " + "granularity"); + + int64_t numGroups = op.getNumGroupsAttr().getInt(); + int64_t groupSize = op.getGroupSizeAttr().getInt(); + if (numGroups <= 0 || groupSize <= 0 || + contiguousVMIType.getElementCount() != numGroups * groupSize) + return fail("dynamic create_group_mask requires result lane count to " + "match num_groups * group_size"); + + FailureOr lanesPerPart = + getMaskLanesPerPart(contiguousVMIType.getGranularity()); + FailureOr arity = getVMIPhysicalArity(contiguousVMIType); + if (failed(lanesPerPart) || failed(arity) || *arity < 1) + return fail("dynamic create_group_mask requires computable physical " + "mask chunks"); + if (static_cast(resultTypes.size()) != *arity) + return fail("dynamic create_group_mask physical result count mismatch"); + if (groupSize > *lanesPerPart || (*lanesPerPart % groupSize) != 0) + return fail("dynamic create_group_mask currently requires group_size to " + "divide one physical b32 predicate chunk"); + + std::optional shift = getPowerOfTwoLog2(groupSize); + if (!shift) + return fail("dynamic 
create_group_mask currently requires power-of-two " + "group_size"); + + Location loc = op.getLoc(); + MLIRContext *ctx = rewriter.getContext(); + Type i32 = rewriter.getI32Type(); + auto indexVectorType = VRegType::get(ctx, *lanesPerPart, i32); + Value activeI32 = + clampDynamicActiveLanes(loc, activeElemsPerGroup, groupSize, rewriter); + + SmallVector results; + results.reserve(resultTypes.size()); + for (Type resultType : resultTypes) { + auto maskType = dyn_cast(resultType); + if (!maskType || !maskType.isB32()) + return fail("dynamic create_group_mask result must be b32 mask"); + + FailureOr allMask = createAllTrueMask(loc, maskType, rewriter); + if (failed(allMask)) + return fail("failed to create dynamic create_group_mask all mask"); + + Value zero = createI32Constant(loc, 0, rewriter); + Value lane = + rewriter.create(loc, indexVectorType, zero, StringAttr{}) + .getResult(); + + Value col = lane; + if (groupSize != *lanesPerPart) { + Value shiftScalar = createI16Constant(loc, *shift, rewriter); + Value group = rewriter + .create(loc, indexVectorType, lane, + shiftScalar, *allMask) + .getResult(); + Value groupBase = rewriter + .create(loc, indexVectorType, group, + shiftScalar, *allMask) + .getResult(); + col = rewriter + .create(loc, indexVectorType, lane, groupBase, *allMask) + .getResult(); + } + + results.push_back(rewriter + .create(loc, maskType, col, activeI32, + *allMask, + rewriter.getStringAttr("lt")) + .getResult()); + } + + return results; +} + +std::optional getPrefixActiveLaneCount(ArrayRef activeLanes) { + bool seenInactive = false; + int64_t activeCount = 0; + for (int8_t active : activeLanes) { + if (active) { + if (seenInactive) + return std::nullopt; + ++activeCount; + continue; + } + seenInactive = true; + } + return activeCount; +} + +FailureOr materializePrefixMask(Location loc, MaskType maskType, + int64_t activeLanes, + int64_t lanesPerPart, + PatternRewriter &rewriter) { + std::optional pattern = + getPrefixPattern(activeLanes, 
lanesPerPart); + if (pattern) + return createPatternMask(loc, maskType, *pattern, rewriter); + + FailureOr> maskAndRemaining = createRuntimePrefixMask( + loc, maskType, createI32Constant(loc, activeLanes, rewriter), rewriter); + if (failed(maskAndRemaining)) + return failure(); + return maskAndRemaining->first; +} + +FailureOr materializeConstantMaskChunk(Location loc, MaskType maskType, + ArrayRef activeLanes, + PatternRewriter &rewriter) { + FailureOr lanesPerPart = + getMaskLanesPerPart(maskType.getGranularity()); + if (failed(lanesPerPart) || + static_cast(activeLanes.size()) != *lanesPerPart) + return failure(); + + if (std::optional prefixCount = + getPrefixActiveLaneCount(activeLanes)) + return materializePrefixMask(loc, maskType, *prefixCount, *lanesPerPart, + rewriter); + + FailureOr allTrue = createAllTrueMask(loc, maskType, rewriter); + if (failed(allTrue)) + return failure(); + + Value result; + int64_t lane = 0; + while (lane < *lanesPerPart) { + while (lane < *lanesPerPart && !activeLanes[lane]) + ++lane; + if (lane >= *lanesPerPart) + break; + + int64_t runBegin = lane; + while (lane < *lanesPerPart && activeLanes[lane]) + ++lane; + int64_t runEnd = lane; + + FailureOr prefixEnd = + materializePrefixMask(loc, maskType, runEnd, *lanesPerPart, rewriter); + if (failed(prefixEnd)) + return failure(); + + Value runMask = *prefixEnd; + if (runBegin != 0) { + FailureOr prefixBegin = materializePrefixMask( + loc, maskType, runBegin, *lanesPerPart, rewriter); + if (failed(prefixBegin)) + return failure(); + Value notPrefixBegin = + rewriter.create(loc, maskType, *prefixBegin, *allTrue) + .getResult(); + runMask = rewriter + .create(loc, maskType, *prefixEnd, notPrefixBegin, + *allTrue) + .getResult(); + } + + if (!result) { + result = runMask; + continue; + } + result = rewriter.create(loc, maskType, result, runMask, *allTrue) + .getResult(); + } + + if (result) + return result; + return materializePrefixMask(loc, maskType, 0, *lanesPerPart, rewriter); +} + 
+FailureOr createScalarOffsetConstant(Location loc, Type type, + int64_t value, + PatternRewriter &rewriter); + +Value createChunkOffset(Location loc, Value baseOffset, int64_t laneOffset, + PatternRewriter &rewriter) { + if (laneOffset == 0) + return baseOffset; + Value delta = rewriter.create(loc, laneOffset); + return rewriter.create(loc, baseOffset, delta).getResult(); +} + +Value createGroupChunkOffset(Location loc, Value baseOffset, Value rowStride, + int64_t group, int64_t inGroupLaneOffset, + PatternRewriter &rewriter) { + Value offset = baseOffset; + if (group != 0) { + Value groupIndex = rewriter.create(loc, group); + Value rowOffset = + rewriter.create(loc, rowStride, groupIndex).getResult(); + offset = rewriter.create(loc, offset, rowOffset).getResult(); + } + return createChunkOffset(loc, offset, inGroupLaneOffset, rewriter); +} + +LogicalResult checkContiguousFullGroupChunks( + Operation *op, VMIVRegType type, int64_t groupSize, int64_t *lanesPerPart, + int64_t *groupCount, int64_t *chunksPerGroup, PatternRewriter &rewriter) { + auto fail = [&](const Twine &message) { + return rewriter.notifyMatchFailure(op, message); + }; + + VMILayoutAttr layout = type.getLayoutAttr(); + if (!layout || !layout.isContiguous()) + return fail("group op requires contiguous VMI layout"); + if (failed(checkFullDataPhysicalChunks(type, nullptr))) + return fail("group op requires full physical chunks"); + FailureOr lanes = getDataLanesPerPart(type.getElementType()); + if (failed(lanes)) + return fail("group op requires known physical lanes per part"); + if (groupSize <= 0 || type.getElementCount() % groupSize != 0) + return fail("group op requires derived group size to evenly divide lane " + "count"); + if (groupSize % *lanes != 0) + return fail("group op currently requires group size to be a multiple of " + "physical lanes per part"); + + *lanesPerPart = *lanes; + *groupCount = type.getElementCount() / groupSize; + *chunksPerGroup = groupSize / *lanes; + return success(); 
+} + +LogicalResult checkFullGroupSlotSourceShape( + Operation *op, VMIVRegType type, int64_t groupSize, int64_t numGroups, + int64_t *lanesPerPart, int64_t *groupCount, PatternRewriter &rewriter) { + auto fail = [&](const Twine &message) { + return rewriter.notifyMatchFailure(op, message); + }; + + VMILayoutAttr layout = type.getLayoutAttr(); + if (!layout || !layout.isGroupSlots() || layout.getNumGroups() != numGroups) + return fail("group slot op requires matching num_groups VMI layout"); + if (type.getElementCount() != numGroups) + return fail("group slot op requires one logical lane per group"); + FailureOr lanes = getDataLanesPerPart(type.getElementType()); + if (failed(lanes)) + return fail("group slot op requires known physical lanes per part"); + if (groupSize <= 0) + return fail("group slot op requires positive derived group size"); + if (*lanes % groupSize != 0 && groupSize % *lanes != 0) + return fail("group slot op requires group size to divide or be a " + "multiple of physical lanes per part"); + + *lanesPerPart = *lanes; + *groupCount = numGroups; + return success(); +} + +LogicalResult checkFullGroupBroadcastResultShape( + Operation *op, VMIVRegType type, int64_t groupSize, int64_t lanesPerPart, + int64_t *layoutFactor, int64_t *groupCount, PatternRewriter &rewriter) { + auto fail = [&](const Twine &message) { + return rewriter.notifyMatchFailure(op, message); + }; + + VMILayoutAttr layout = type.getLayoutAttr(); + if (!layout) + return fail("group_broadcast result requires assigned VMI layout"); + if (layout.isGroupSlots()) + return fail("group_broadcast result requires a dense VMI layout"); + if (failed(checkFullDataPhysicalChunks(type, nullptr))) + return fail("group_broadcast result requires full physical chunks"); + FailureOr resultLanes = getDataLanesPerPart(type.getElementType()); + if (failed(resultLanes) || *resultLanes != lanesPerPart) + return fail("group_broadcast result requires matching physical lanes"); + if (groupSize <= 0 || 
type.getElementCount() % groupSize != 0) + return fail("group_broadcast result requires derived group size to evenly " + "divide lane count"); + FailureOr factor = getDataLayoutFactor(type); + if (failed(factor)) + return fail("group_broadcast result requires known layout factor"); + + if (*factor == 1) { + if (lanesPerPart % groupSize != 0 && groupSize % lanesPerPart != 0) + return fail("group_broadcast contiguous result requires group size to " + "divide or be a multiple of physical lanes per part"); + } else { + bool blockFragmentSmallGroup = + layout.isDeinterleaved() && layout.getBlockElems() > 1 && + groupSize < lanesPerPart && lanesPerPart % layout.getBlockElems() == 0; + bool deinterleavedSmallGroup = + layout.isDeinterleaved() && layout.getBlockElems() == 1 && + groupSize < lanesPerPart && groupSize >= *factor && + groupSize % *factor == 0 && lanesPerPart % (groupSize / *factor) == 0; + int64_t logicalSpanPerResultChunk = lanesPerPart * *factor; + if (!blockFragmentSmallGroup && !deinterleavedSmallGroup && + (groupSize < lanesPerPart || + groupSize % logicalSpanPerResultChunk != 0)) + return fail("group_broadcast deinterleaved result requires every " + "physical result chunk to stay within one logical group"); + } + + *layoutFactor = *factor; + *groupCount = type.getElementCount() / groupSize; + return success(); +} + +FailureOr createZeroVector(Location loc, VRegType type, + PatternRewriter &rewriter) { + FailureOr zero = + createScalarOffsetConstant(loc, type.getElementType(), 0, rewriter); + FailureOr mask = createAllTrueMaskForVReg(loc, type, rewriter); + if (failed(zero) || failed(mask)) + return failure(); + return rewriter + .create(loc, type, *zero, *mask, + /*position=*/nullptr) + .getResult(); +} + +FailureOr createLaneRangeMask(Location loc, MaskType maskType, + int64_t begin, int64_t end, + PatternRewriter &rewriter) { + FailureOr lanesPerPart = + getMaskLanesPerPart(maskType.getGranularity()); + if (failed(lanesPerPart) || begin < 0 || begin > 
end || end > *lanesPerPart) + return failure(); + SmallVector active(*lanesPerPart, 0); + for (int64_t lane = begin; lane < end; ++lane) + active[lane] = 1; + return materializeConstantMaskChunk(loc, maskType, active, rewriter); +} + +FailureOr createGroupSlotIndexVector(Location loc, VRegType indexType, + int64_t groupSize, + int64_t baseGroupSlot, + PatternRewriter &rewriter) { + int64_t lanesPerPart = indexType.getElementCount(); + FailureOr baseScalar = createScalarOffsetConstant( + loc, indexType.getElementType(), baseGroupSlot, rewriter); + FailureOr maskType = + getMaskTypeForVReg(indexType, rewriter.getContext()); + FailureOr allMask = createAllTrueMaskForVReg(loc, indexType, rewriter); + if (failed(baseScalar) || failed(maskType) || failed(allMask)) + return failure(); + Value result = rewriter + .create(loc, indexType, *baseScalar, *allMask, + /*position=*/nullptr) + .getResult(); + if (groupSize >= lanesPerPart) + return result; + if (lanesPerPart % groupSize != 0) + return failure(); + + int64_t groupsPerChunk = lanesPerPart / groupSize; + for (int64_t localGroup = 1; localGroup < groupsPerChunk; ++localGroup) { + FailureOr groupScalar = createScalarOffsetConstant( + loc, indexType.getElementType(), baseGroupSlot + localGroup, rewriter); + FailureOr laneMask = + createLaneRangeMask(loc, *maskType, localGroup * groupSize, + (localGroup + 1) * groupSize, rewriter); + if (failed(groupScalar) || failed(laneMask)) + return failure(); + Value splat = rewriter + .create(loc, indexType, *groupScalar, *allMask, + /*position=*/nullptr) + .getResult(); + result = rewriter.create(loc, indexType, splat, result, *laneMask) + .getResult(); + } + return result; +} + +FailureOr createMappedGroupSlotIndexVector( + Location loc, VMIVRegType resultVMIType, int64_t part, int64_t chunk, + VRegType indexType, int64_t groupSize, int64_t slots, int64_t &sourceChunk, + PatternRewriter &rewriter) { + if (groupSize <= 0 || slots <= 0) + return failure(); + + int64_t lanesPerPart = 
indexType.getElementCount(); + SmallVector slotByLane; + slotByLane.reserve(lanesPerPart); + std::optional resolvedSourceChunk; + for (int64_t lane = 0; lane < lanesPerPart; ++lane) { + FailureOr logicalLane = + mapPhysicalLaneToLogical(resultVMIType, part, chunk, lane); + if (failed(logicalLane)) + return failure(); + int64_t group = *logicalLane / groupSize; + int64_t candidateSourceChunk = group / slots; + if (resolvedSourceChunk && *resolvedSourceChunk != candidateSourceChunk) + return failure(); + resolvedSourceChunk = candidateSourceChunk; + slotByLane.push_back(group % slots); + } + if (!resolvedSourceChunk) + return failure(); + sourceChunk = *resolvedSourceChunk; + + FailureOr baseScalar = createScalarOffsetConstant( + loc, indexType.getElementType(), slotByLane.front(), rewriter); + FailureOr maskType = + getMaskTypeForVReg(indexType, rewriter.getContext()); + FailureOr allMask = createAllTrueMaskForVReg(loc, indexType, rewriter); + if (failed(baseScalar) || failed(maskType) || failed(allMask)) + return failure(); + + Value result = rewriter + .create(loc, indexType, *baseScalar, *allMask, + /*position=*/nullptr) + .getResult(); + int64_t rangeBegin = 0; + while (rangeBegin < lanesPerPart) { + int64_t slot = slotByLane[rangeBegin]; + int64_t rangeEnd = rangeBegin + 1; + while (rangeEnd < lanesPerPart && slotByLane[rangeEnd] == slot) + ++rangeEnd; + if (rangeBegin != 0 || slot != slotByLane.front()) { + FailureOr slotScalar = createScalarOffsetConstant( + loc, indexType.getElementType(), slot, rewriter); + FailureOr laneMask = + createLaneRangeMask(loc, *maskType, rangeBegin, rangeEnd, rewriter); + if (failed(slotScalar) || failed(laneMask)) + return failure(); + Value splat = rewriter + .create(loc, indexType, *slotScalar, *allMask, + /*position=*/nullptr) + .getResult(); + result = rewriter.create(loc, indexType, splat, result, *laneMask) + .getResult(); + } + rangeBegin = rangeEnd; + } + return result; +} + +template +FailureOr 
reduceVcgSlotsToLane0(Location loc, Value reduced, + VRegType resultType, Value firstLaneMask, + PatternRewriter &rewriter) { + unsigned indexBits = + pto::getPTOStorageElemBitWidth(resultType.getElementType()); + if (indexBits != 8 && indexBits != 16 && indexBits != 32) + return failure(); + + auto indexElementType = IntegerType::get(rewriter.getContext(), indexBits); + auto indexType = VRegType::get( + rewriter.getContext(), resultType.getElementCount(), indexElementType); + Value accumulator = reduced; + for (int64_t slot = 1; slot < 8; ++slot) { + FailureOr slotIndex = createGroupSlotIndexVector( + loc, indexType, resultType.getElementCount(), slot, rewriter); + if (failed(slotIndex)) + return failure(); + Value selected = + rewriter.create(loc, resultType, reduced, *slotIndex) + .getResult(); + accumulator = rewriter + .create(loc, resultType, selected, + accumulator, firstLaneMask) + .getResult(); + } + return accumulator; +} + +std::optional getX2MemoryDistToken(Type elementType, + StringRef prefix) { + unsigned elementBits = pto::getPTOStorageElemBitWidth(elementType); + if (elementBits != 8 && elementBits != 16 && elementBits != 32) + return std::nullopt; + return (Twine(prefix) + "_B" + Twine(elementBits)).str(); +} + +std::optional getDenseLaneStrideLoadDistToken(VMIVRegType type) { + VMILayoutAttr layout = type.getLayoutAttr(); + if (!layout || !layout.isContiguous()) + return std::nullopt; + unsigned elementBits = pto::getPTOStorageElemBitWidth(type.getElementType()); + if (layout.getLaneStride() == 2 && + (elementBits == 8 || elementBits == 16 || elementBits == 32)) + return (Twine("UNPK_B") + Twine(elementBits)).str(); + if (layout.getLaneStride() == 4 && elementBits == 8) + return std::string("UNPK4"); + return std::nullopt; +} + +std::optional getDenseLaneStrideStoreDistToken(VMIVRegType type) { + VMILayoutAttr layout = type.getLayoutAttr(); + if (!layout || !layout.isContiguous()) + return std::nullopt; + unsigned elementBits = 
pto::getPTOStorageElemBitWidth(type.getElementType()); + if (layout.getLaneStride() == 2 && elementBits == 8) + return std::string("PK_B16"); + if (layout.getLaneStride() == 2 && elementBits == 16) + return std::string("PK_B32"); + if (layout.getLaneStride() == 2 && elementBits == 32) + return std::string("PK_B64"); + if (layout.getLaneStride() == 4 && elementBits == 8) + return std::string("PK4_B32"); + return std::nullopt; +} + +std::optional getDenseLaneStrideStoreMaskGranularity( + VMIVRegType type) { + VMILayoutAttr layout = type.getLayoutAttr(); + if (!layout || !layout.isContiguous()) + return std::nullopt; + unsigned elementBits = pto::getPTOStorageElemBitWidth(type.getElementType()); + if (layout.getLaneStride() == 2 && elementBits == 8) + return StringRef("b16"); + if (layout.getLaneStride() == 2 && elementBits == 16) + return StringRef("b32"); + if (layout.getLaneStride() == 2 && elementBits == 32) + return StringRef("b32"); + if (layout.getLaneStride() == 4 && elementBits == 8) + return StringRef("b32"); + return std::nullopt; +} + +std::optional getDenseLaneStrideMaskedStoreMaskGranularity( + VMIVRegType type) { + VMILayoutAttr layout = type.getLayoutAttr(); + if (!layout || !layout.isContiguous()) + return std::nullopt; + unsigned elementBits = pto::getPTOStorageElemBitWidth(type.getElementType()); + if (layout.getLaneStride() == 2 && elementBits == 8) + return StringRef("b16"); + if (layout.getLaneStride() == 2 && elementBits == 16) + return StringRef("b32"); + if (layout.getLaneStride() == 4 && elementBits == 8) + return StringRef("b32"); + return std::nullopt; +} + +std::optional getPointStoreDistToken(Type elementType) { + unsigned elementBits = pto::getPTOStorageElemBitWidth(elementType); + if (elementBits != 8 && elementBits != 16 && elementBits != 32) + return std::nullopt; + return (Twine("1PT_B") + Twine(elementBits)).str(); +} + +std::optional getVPTOCmpMode(StringRef predicate) { + if (predicate == "eq" || predicate == "ne" || predicate == 
"lt" || + predicate == "le" || predicate == "gt" || predicate == "ge") + return predicate; + if (predicate == "oeq") + return StringRef("eq"); + if (predicate == "one") + return StringRef("ne"); + if (predicate == "olt") + return StringRef("lt"); + if (predicate == "ole") + return StringRef("le"); + if (predicate == "ogt") + return StringRef("gt"); + if (predicate == "oge") + return StringRef("ge"); + if (predicate == "slt") + return StringRef("lt"); + if (predicate == "sle") + return StringRef("le"); + if (predicate == "sgt") + return StringRef("gt"); + if (predicate == "sge") + return StringRef("ge"); + return std::nullopt; +} + +LogicalResult checkSupportedComparePredicate(Operation *op, + StringRef predicate) { + if (getVPTOCmpMode(predicate)) + return success(); + return op->emitError() + << kVMIDiagUnsupportedPrefix << "compare predicate " << predicate + << " cannot be lowered to pto.vcmp; supported predicates are " + "eq/ne/lt/le/gt/ge, ordered FP forms oeq/one/olt/ole/ogt/oge, " + "and signed integer forms slt/sle/sgt/sge"; +} + +struct OneToNVMIUnpackOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIUnpackOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange sourceParts = adaptor.getSource(); + if (sourceParts.size() != op->getNumResults()) + return rewriter.notifyMatchFailure( + op, "converted source part count must match unpack results"); + rewriter.replaceOp(op, sourceParts, adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMIPackOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIPackOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + FailureOr arity = getVMIPhysicalArity(op.getResult().getType()); + if (failed(arity) || + static_cast(adaptor.getFlatOperands().size()) != *arity) + return 
rewriter.notifyMatchFailure( + op, "pack part count must match converted VMI result arity"); + rewriter.replaceOp(op, adaptor.getFlatOperands(), + adaptor.getResultMapping()); + return success(); + } +}; + +LogicalResult verifyIdentityPartForwarding(Operation *op, + ValueRange sourceParts, + TypeRange resultTypes, + PatternRewriter &rewriter) { + if (sourceParts.size() != resultTypes.size()) + return rewriter.notifyMatchFailure( + op, "source and result physical arity mismatch"); + for (auto [part, resultType] : llvm::zip_equal(sourceParts, resultTypes)) { + if (part.getType() != resultType) + return rewriter.notifyMatchFailure( + op, "helper requires non-identity physical materialization"); + } + return success(); +} + +FailureOr getUnsignedCarrierVRegType(MLIRContext *ctx, + unsigned elementBits) { + if (elementBits != 8 && elementBits != 16 && elementBits != 32) + return failure(); + auto elementType = + IntegerType::get(ctx, elementBits, + IntegerType::SignednessSemantics::Unsigned); + return VRegType::get(ctx, 2048 / elementBits, elementType); +} + +FailureOr bitcastVReg(Location loc, Value value, Type resultType, + PatternRewriter &rewriter) { + if (value.getType() == resultType) + return value; + auto inputType = dyn_cast(value.getType()); + auto outputType = dyn_cast(resultType); + if (!inputType || !outputType) + return failure(); + return rewriter.create(loc, outputType, value).getResult(); +} + +FailureOr unpackToNextCarrier(Location loc, Value source, + unsigned sourceBits, int64_t partIndex, + PatternRewriter &rewriter) { + FailureOr resultType = + getUnsignedCarrierVRegType(rewriter.getContext(), sourceBits * 2); + if (failed(resultType)) + return failure(); + Value part = rewriter.create(loc, partIndex); + return rewriter.create(loc, *resultType, source, part).getResult(); +} + +FailureOr packToPreviousCarrier(Location loc, Value source, + unsigned resultBits, + PatternRewriter &rewriter) { + FailureOr resultType = + 
getUnsignedCarrierVRegType(rewriter.getContext(), resultBits); + if (failed(resultType)) + return failure(); + return rewriter + .create(loc, *resultType, source, + rewriter.getStringAttr("LOWER")) + .getResult(); +} + +FailureOr> materializeContiguousToLaneStride( + Operation *op, ValueRange sourceParts, TypeRange resultTypes, + Type elementType, int64_t laneStride, PatternRewriter &rewriter) { + if (sourceParts.size() != resultTypes.size()) { + (void)rewriter.notifyMatchFailure( + op, "dense lane_stride unpack materialization requires matching " + "source/result physical arity"); + return failure(); + } + + unsigned elementBits = pto::getPTOStorageElemBitWidth(elementType); + if ((laneStride != 2 && laneStride != 4) || + (laneStride == 4 && elementBits != 8) || + (elementBits != 8 && elementBits != 16)) { + (void)rewriter.notifyMatchFailure( + op, "unsupported dense lane_stride unpack carrier shape"); + return failure(); + } + + MLIRContext *ctx = rewriter.getContext(); + FailureOr inputCarrier = + getUnsignedCarrierVRegType(ctx, elementBits); + if (failed(inputCarrier)) + return failure(); + + SmallVector results; + results.reserve(resultTypes.size()); + for (auto [resultIndex, resultType] : llvm::enumerate(resultTypes)) { + int64_t sourceIndex = resultIndex / laneStride; + if (sourceIndex >= static_cast(sourceParts.size())) + return failure(); + Value source = sourceParts[sourceIndex]; + FailureOr current = + bitcastVReg(op->getLoc(), source, *inputCarrier, rewriter); + if (failed(current)) + return failure(); + int64_t part = resultIndex % laneStride; + FailureOr unpacked = + unpackToNextCarrier(op->getLoc(), *current, elementBits, + laneStride == 4 ? 
part / 2 : part, rewriter); + if (failed(unpacked)) + return failure(); + current = *unpacked; + if (laneStride == 4) { + unpacked = + unpackToNextCarrier(op->getLoc(), *current, elementBits * 2, + part % 2, rewriter); + if (failed(unpacked)) + return failure(); + current = *unpacked; + } + FailureOr result = + bitcastVReg(op->getLoc(), *current, resultType, rewriter); + if (failed(result)) + return failure(); + results.push_back(*result); + } + return results; +} + +FailureOr> materializeLaneStrideToContiguous( + Operation *op, ValueRange sourceParts, TypeRange resultTypes, + Type elementType, int64_t laneStride, PatternRewriter &rewriter) { + if (sourceParts.size() != resultTypes.size()) { + (void)rewriter.notifyMatchFailure( + op, "dense lane_stride pack materialization requires matching " + "source/result physical arity"); + return failure(); + } + + unsigned elementBits = pto::getPTOStorageElemBitWidth(elementType); + if ((laneStride != 2 && laneStride != 4) || + (laneStride == 4 && elementBits != 8) || + (elementBits != 8 && elementBits != 16)) { + (void)rewriter.notifyMatchFailure( + op, "unsupported dense lane_stride pack carrier shape"); + return failure(); + } + + unsigned carrierBits = + static_cast(elementBits * static_cast(laneStride)); + FailureOr sourceCarrier = + getUnsignedCarrierVRegType(rewriter.getContext(), carrierBits); + if (failed(sourceCarrier)) + return failure(); + + SmallVector results; + results.reserve(sourceParts.size()); + for (auto [source, resultType] : llvm::zip_equal(sourceParts, resultTypes)) { + FailureOr current = + bitcastVReg(op->getLoc(), source, *sourceCarrier, rewriter); + if (failed(current)) + return failure(); + FailureOr packed = + packToPreviousCarrier(op->getLoc(), *current, carrierBits / 2, rewriter); + if (failed(packed)) + return failure(); + current = *packed; + if (laneStride == 4) { + packed = + packToPreviousCarrier(op->getLoc(), *current, elementBits, rewriter); + if (failed(packed)) + return failure(); + 
current = *packed; + } + FailureOr result = + bitcastVReg(op->getLoc(), *current, resultType, rewriter); + if (failed(result)) + return failure(); + results.push_back(*result); + } + return results; +} + +FailureOr> materializeDataLayoutConversion( + Operation *op, ValueRange sourceParts, TypeRange resultTypes, + VMILayoutAttr sourceLayout, VMILayoutAttr resultLayout, + PatternRewriter &rewriter) { + if (!sourceLayout || !resultLayout) { + (void)rewriter.notifyMatchFailure( + op, "layout materialization requires assigned source/result layouts"); + return failure(); + } + + if (sourceLayout == resultLayout) { + if (failed(verifyIdentityPartForwarding(op, sourceParts, resultTypes, + rewriter))) + return failure(); + return SmallVector(sourceParts.begin(), sourceParts.end()); + } + + bool deint2ToContiguous = sourceLayout.isDeinterleaved() && + sourceLayout.getFactor() == 2 && + sourceLayout.getLaneStride() == 1 && + resultLayout.isContiguous() && + resultLayout.getLaneStride() == 1; + bool contiguousToDeint2 = sourceLayout.isContiguous() && + sourceLayout.getLaneStride() == 1 && + resultLayout.isDeinterleaved() && + resultLayout.getFactor() == 2 && + resultLayout.getLaneStride() == 1; + if (deint2ToContiguous || contiguousToDeint2) { + if (sourceParts.size() != resultTypes.size() || sourceParts.empty() || + sourceParts.size() % 2 != 0) { + (void)rewriter.notifyMatchFailure( + op, "deinterleaved=2 layout materialization requires 2*N parts"); + return failure(); + } + if (failed(verifyIdentityPartForwarding(op, sourceParts, resultTypes, + rewriter))) + return failure(); + + int64_t groups = sourceParts.size() / 2; + SmallVector results; + results.reserve(sourceParts.size()); + if (deint2ToContiguous) { + for (int64_t i = 0; i < groups; ++i) { + auto materialize = rewriter.create( + op->getLoc(), resultTypes[2 * i], resultTypes[2 * i + 1], + sourceParts[i], sourceParts[groups + i]); + results.append({materialize.getLow(), materialize.getHigh()}); + } + } else { + 
SmallVector part0; + SmallVector part1; + part0.reserve(groups); + part1.reserve(groups); + for (int64_t i = 0; i < groups; ++i) { + auto materialize = rewriter.create( + op->getLoc(), resultTypes[i], resultTypes[groups + i], + sourceParts[2 * i], sourceParts[2 * i + 1]); + part0.push_back(materialize.getLow()); + part1.push_back(materialize.getHigh()); + } + results.append(part0); + results.append(part1); + } + return results; + } + + bool deint4ToContiguous = sourceLayout.isDeinterleaved() && + sourceLayout.getFactor() == 4 && + sourceLayout.getLaneStride() == 1 && + resultLayout.isContiguous() && + resultLayout.getLaneStride() == 1; + bool contiguousToDeint4 = sourceLayout.isContiguous() && + sourceLayout.getLaneStride() == 1 && + resultLayout.isDeinterleaved() && + resultLayout.getFactor() == 4 && + resultLayout.getLaneStride() == 1; + if (deint4ToContiguous || contiguousToDeint4) { + if (sourceParts.size() != resultTypes.size() || sourceParts.empty() || + sourceParts.size() % 4 != 0) { + (void)rewriter.notifyMatchFailure( + op, "deinterleaved=4 layout materialization requires 4*N parts"); + return failure(); + } + if (failed(verifyIdentityPartForwarding(op, sourceParts, resultTypes, + rewriter))) + return failure(); + + SmallVector results; + results.reserve(sourceParts.size()); + int64_t groups = sourceParts.size() / 4; + if (deint4ToContiguous) { + for (int64_t i = 0; i < groups; ++i) { + Value p0 = sourceParts[i]; + Value p1 = sourceParts[groups + i]; + Value p2 = sourceParts[2 * groups + i]; + Value p3 = sourceParts[3 * groups + i]; + auto even = rewriter.create(op->getLoc(), resultTypes[4 * i], + resultTypes[4 * i + 1], p0, p2); + auto odd = rewriter.create(op->getLoc(), resultTypes[4 * i], + resultTypes[4 * i + 1], p1, p3); + auto low = rewriter.create(op->getLoc(), resultTypes[4 * i], + resultTypes[4 * i + 1], + even.getLow(), odd.getLow()); + auto high = rewriter.create( + op->getLoc(), resultTypes[4 * i + 2], resultTypes[4 * i + 3], + even.getHigh(), 
odd.getHigh()); + results.append( + {low.getLow(), low.getHigh(), high.getLow(), high.getHigh()}); + } + } else { + SmallVector part0; + SmallVector part1; + SmallVector part2; + SmallVector part3; + part0.reserve(groups); + part1.reserve(groups); + part2.reserve(groups); + part3.reserve(groups); + for (int64_t i = 0; i < groups; ++i) { + auto low = rewriter.create( + op->getLoc(), resultTypes[i], resultTypes[groups + i], + sourceParts[4 * i], sourceParts[4 * i + 1]); + auto high = rewriter.create( + op->getLoc(), resultTypes[2 * groups + i], + resultTypes[3 * groups + i], sourceParts[4 * i + 2], + sourceParts[4 * i + 3]); + auto even = rewriter.create(op->getLoc(), resultTypes[i], + resultTypes[2 * groups + i], + low.getLow(), high.getLow()); + auto odd = rewriter.create( + op->getLoc(), resultTypes[groups + i], resultTypes[3 * groups + i], + low.getHigh(), high.getHigh()); + part0.push_back(even.getLow()); + part1.push_back(odd.getLow()); + part2.push_back(even.getHigh()); + part3.push_back(odd.getHigh()); + } + results.append(part0); + results.append(part1); + results.append(part2); + results.append(part3); + } + return results; + } + + if (sourceLayout.isContiguous() && sourceLayout.getLaneStride() == 1 && + resultLayout.isContiguous() && resultLayout.getLaneStride() != 1) { + auto ensure = dyn_cast(op); + if (!ensure) + return failure(); + auto sourceType = cast(ensure.getSource().getType()); + return materializeContiguousToLaneStride( + op, sourceParts, resultTypes, sourceType.getElementType(), + resultLayout.getLaneStride(), rewriter); + } + + if (sourceLayout.isContiguous() && sourceLayout.getLaneStride() != 1 && + resultLayout.isContiguous() && resultLayout.getLaneStride() == 1) { + auto ensure = dyn_cast(op); + if (!ensure) + return failure(); + auto sourceType = cast(ensure.getSource().getType()); + return materializeLaneStrideToContiguous( + op, sourceParts, resultTypes, sourceType.getElementType(), + sourceLayout.getLaneStride(), rewriter); + } + + if 
(sourceLayout.isDeinterleaved() && resultLayout.isDeinterleaved() && + (sourceLayout.getFactor() == 2 || sourceLayout.getFactor() == 4) && + (resultLayout.getFactor() == 2 || resultLayout.getFactor() == 4)) { + VMILayoutAttr contiguous = + VMILayoutAttr::getContiguous(rewriter.getContext()); + FailureOr> dense = materializeDataLayoutConversion( + op, sourceParts, resultTypes, sourceLayout, contiguous, rewriter); + if (failed(dense)) + return failure(); + return materializeDataLayoutConversion(op, *dense, resultTypes, contiguous, + resultLayout, rewriter); + } + + (void)rewriter.notifyMatchFailure( + op, "unsupported VMI data layout materialization"); + return failure(); +} + +FailureOr> +createPredicateDintlv(Location loc, Type lowType, Type highType, Value lhs, + Value rhs, PatternRewriter &rewriter) { + auto maskType = dyn_cast(lowType); + if (!maskType || highType != lowType) + return failure(); + if (maskType.isB8()) { + auto op = rewriter.create(loc, lowType, highType, lhs, rhs); + return std::make_pair(op.getLow(), op.getHigh()); + } + if (maskType.isB16()) { + auto op = rewriter.create(loc, lowType, highType, lhs, rhs); + return std::make_pair(op.getLow(), op.getHigh()); + } + if (maskType.isB32()) { + auto op = rewriter.create(loc, lowType, highType, lhs, rhs); + return std::make_pair(op.getLow(), op.getHigh()); + } + return failure(); +} + +FailureOr> +createPredicateIntlv(Location loc, Type lowType, Type highType, Value lhs, + Value rhs, PatternRewriter &rewriter) { + auto maskType = dyn_cast(lowType); + if (!maskType || highType != lowType) + return failure(); + if (maskType.isB8()) { + auto op = rewriter.create(loc, lowType, highType, lhs, rhs); + return std::make_pair(op.getLow(), op.getHigh()); + } + if (maskType.isB16()) { + auto op = rewriter.create(loc, lowType, highType, lhs, rhs); + return std::make_pair(op.getLow(), op.getHigh()); + } + if (maskType.isB32()) { + auto op = rewriter.create(loc, lowType, highType, lhs, rhs); + return 
std::make_pair(op.getLow(), op.getHigh()); + } + return failure(); +} + +FailureOr> materializeMaskLayoutConversion( + Operation *op, ValueRange sourceParts, TypeRange resultTypes, + VMILayoutAttr sourceLayout, VMILayoutAttr resultLayout, + PatternRewriter &rewriter) { + if (!sourceLayout || !resultLayout) { + (void)rewriter.notifyMatchFailure( + op, "mask layout materialization requires assigned source/result " + "layouts"); + return failure(); + } + + if (sourceLayout == resultLayout) { + if (failed(verifyIdentityPartForwarding(op, sourceParts, resultTypes, + rewriter))) + return failure(); + return SmallVector(sourceParts.begin(), sourceParts.end()); + } + + bool deint2ToContiguous = sourceLayout.isDeinterleaved() && + sourceLayout.getFactor() == 2 && + sourceLayout.getLaneStride() == 1 && + resultLayout.isContiguous() && + resultLayout.getLaneStride() == 1; + bool contiguousToDeint2 = sourceLayout.isContiguous() && + sourceLayout.getLaneStride() == 1 && + resultLayout.isDeinterleaved() && + resultLayout.getFactor() == 2 && + resultLayout.getLaneStride() == 1; + if (deint2ToContiguous || contiguousToDeint2) { + if (sourceParts.size() != resultTypes.size() || sourceParts.empty() || + sourceParts.size() % 2 != 0) { + (void)rewriter.notifyMatchFailure( + op, "deinterleaved=2 mask layout materialization requires 2*N " + "parts"); + return failure(); + } + if (failed(verifyIdentityPartForwarding(op, sourceParts, resultTypes, + rewriter))) + return failure(); + + int64_t groups = sourceParts.size() / 2; + SmallVector results; + results.reserve(sourceParts.size()); + if (deint2ToContiguous) { + for (int64_t i = 0; i < groups; ++i) { + FailureOr> materialize = createPredicateIntlv( + op->getLoc(), resultTypes[2 * i], resultTypes[2 * i + 1], + sourceParts[i], sourceParts[groups + i], rewriter); + if (failed(materialize)) + return rewriter.notifyMatchFailure( + op, "unsupported predicate intlv mask type"); + results.append({materialize->first, materialize->second}); + } + 
} else { + SmallVector part0; + SmallVector part1; + part0.reserve(groups); + part1.reserve(groups); + for (int64_t i = 0; i < groups; ++i) { + FailureOr> materialize = createPredicateDintlv( + op->getLoc(), resultTypes[i], resultTypes[groups + i], + sourceParts[2 * i], sourceParts[2 * i + 1], rewriter); + if (failed(materialize)) + return rewriter.notifyMatchFailure( + op, "unsupported predicate dintlv mask type"); + part0.push_back(materialize->first); + part1.push_back(materialize->second); + } + results.append(part0); + results.append(part1); + } + return results; + } + + bool deint4ToContiguous = sourceLayout.isDeinterleaved() && + sourceLayout.getFactor() == 4 && + sourceLayout.getLaneStride() == 1 && + resultLayout.isContiguous() && + resultLayout.getLaneStride() == 1; + bool contiguousToDeint4 = sourceLayout.isContiguous() && + sourceLayout.getLaneStride() == 1 && + resultLayout.isDeinterleaved() && + resultLayout.getFactor() == 4 && + resultLayout.getLaneStride() == 1; + if (deint4ToContiguous || contiguousToDeint4) { + if (sourceParts.size() != resultTypes.size() || sourceParts.empty() || + sourceParts.size() % 4 != 0) { + (void)rewriter.notifyMatchFailure( + op, "deinterleaved=4 mask layout materialization requires 4*N " + "parts"); + return failure(); + } + if (failed(verifyIdentityPartForwarding(op, sourceParts, resultTypes, + rewriter))) + return failure(); + + SmallVector results; + results.reserve(sourceParts.size()); + int64_t groups = sourceParts.size() / 4; + if (deint4ToContiguous) { + for (int64_t i = 0; i < groups; ++i) { + Value p0 = sourceParts[i]; + Value p1 = sourceParts[groups + i]; + Value p2 = sourceParts[2 * groups + i]; + Value p3 = sourceParts[3 * groups + i]; + FailureOr> even = + createPredicateIntlv(op->getLoc(), resultTypes[4 * i], + resultTypes[4 * i + 1], p0, p2, rewriter); + FailureOr> odd = + createPredicateIntlv(op->getLoc(), resultTypes[4 * i], + resultTypes[4 * i + 1], p1, p3, rewriter); + if (failed(even) || failed(odd)) 
+ return rewriter.notifyMatchFailure( + op, "unsupported predicate intlv mask type"); + FailureOr> low = createPredicateIntlv( + op->getLoc(), resultTypes[4 * i], resultTypes[4 * i + 1], + even->first, odd->first, rewriter); + FailureOr> high = createPredicateIntlv( + op->getLoc(), resultTypes[4 * i + 2], resultTypes[4 * i + 3], + even->second, odd->second, rewriter); + if (failed(low) || failed(high)) + return rewriter.notifyMatchFailure( + op, "unsupported predicate intlv mask type"); + results.append({low->first, low->second, high->first, high->second}); + } + } else { + SmallVector part0; + SmallVector part1; + SmallVector part2; + SmallVector part3; + part0.reserve(groups); + part1.reserve(groups); + part2.reserve(groups); + part3.reserve(groups); + for (int64_t i = 0; i < groups; ++i) { + FailureOr> low = createPredicateDintlv( + op->getLoc(), resultTypes[i], resultTypes[groups + i], + sourceParts[4 * i], sourceParts[4 * i + 1], rewriter); + FailureOr> high = createPredicateDintlv( + op->getLoc(), resultTypes[2 * groups + i], + resultTypes[3 * groups + i], sourceParts[4 * i + 2], + sourceParts[4 * i + 3], rewriter); + if (failed(low) || failed(high)) + return rewriter.notifyMatchFailure( + op, "unsupported predicate dintlv mask type"); + FailureOr> even = createPredicateDintlv( + op->getLoc(), resultTypes[i], resultTypes[2 * groups + i], + low->first, high->first, rewriter); + FailureOr> odd = createPredicateDintlv( + op->getLoc(), resultTypes[groups + i], resultTypes[3 * groups + i], + low->second, high->second, rewriter); + if (failed(even) || failed(odd)) + return rewriter.notifyMatchFailure( + op, "unsupported predicate dintlv mask type"); + part0.push_back(even->first); + part1.push_back(odd->first); + part2.push_back(even->second); + part3.push_back(odd->second); + } + results.append(part0); + results.append(part1); + results.append(part2); + results.append(part3); + } + return results; + } + + (void)rewriter.notifyMatchFailure( + op, "unsupported VMI 
mask layout materialization"); + return failure(); +} + +int getMaskGranularityRank(StringRef granularity) { + if (granularity == "b8") + return 0; + if (granularity == "b16") + return 1; + if (granularity == "b32") + return 2; + return -1; +} + +StringRef getMaskGranularityForRank(int rank) { + switch (rank) { + case 0: + return "b8"; + case 1: + return "b16"; + case 2: + return "b32"; + default: + return ""; + } +} + +LogicalResult checkSupportedMaskGranularityMaterialization( + const VMITargetCapabilityRegistry &capabilities, VMIMaskType sourceType, + VMIMaskType resultType, std::string *reason) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + if (sourceType.getElementCount() != resultType.getElementCount()) + return fail("requires source and result mask lane counts to match"); + if (sourceType.getLayoutAttr() != resultType.getLayoutAttr()) + return fail("requires source and result mask layouts to match"); + + VMICapabilityResult granularityCapability = + capabilities.supportsMaskGranularityConversion( + sourceType.getGranularity(), resultType.getGranularity()); + if (!granularityCapability.isSupported()) + return fail(granularityCapability.reason); + + FailureOr sourceArity = getVMIPhysicalArity(sourceType); + FailureOr resultArity = getVMIPhysicalArity(resultType); + if (failed(sourceArity) || failed(resultArity)) + return fail("requires computable source/result physical arity"); + if (*sourceArity < 1 || *resultArity < 1) + return fail("requires non-empty source/result physical arity"); + + return success(); +} + +FailureOr> materializeAdjacentMaskGranularityConversion( + Operation *op, VMIMaskType sourceType, VMIMaskType resultType, + ValueRange sourceParts, PatternRewriter &rewriter) { + auto fail = [&](const Twine &message) -> FailureOr> { + (void)rewriter.notifyMatchFailure(op, message); + return failure(); + }; + + int sourceRank = 
getMaskGranularityRank(sourceType.getGranularity()); + int resultRank = getMaskGranularityRank(resultType.getGranularity()); + if (std::abs(sourceRank - resultRank) != 1) + return fail("mask granularity conversion must be adjacent"); + + FailureOr sourceArity = getVMIPhysicalArity(sourceType); + FailureOr factor = getVMITypeLayoutFactor(sourceType); + if (failed(sourceArity) || failed(factor) || + static_cast(sourceParts.size()) != *sourceArity) + return fail("source mask part count does not match source VMI type"); + + MLIRContext *ctx = op->getContext(); + auto partAttr = [&](StringRef part) { return StringAttr::get(ctx, part); }; + auto resultMaskType = MaskType::get(ctx, resultType.getGranularity()); + SmallVector results; + + int64_t sourceOffset = 0; + for (int64_t part = 0; part < *factor; ++part) { + FailureOr sourceChunks = getVMITypeChunksInPart(sourceType, part); + FailureOr resultChunks = getVMITypeChunksInPart(resultType, part); + if (failed(sourceChunks) || failed(resultChunks)) + return fail("requires computable source/result chunks per layout part"); + + if (resultRank > sourceRank) { + int64_t produced = 0; + for (int64_t chunk = 0; chunk < *sourceChunks && produced < *resultChunks; + ++chunk) { + Value source = sourceParts[sourceOffset + chunk]; + results.push_back(rewriter + .create(op->getLoc(), resultMaskType, + source, partAttr("LOWER")) + .getResult()); + ++produced; + if (produced >= *resultChunks) + break; + results.push_back(rewriter + .create(op->getLoc(), resultMaskType, + source, partAttr("HIGHER")) + .getResult()); + ++produced; + } + if (produced != *resultChunks) + return fail("widening mask granularity conversion produced the wrong " + "number of result chunks"); + } else { + Value allTrue; + int64_t consumed = 0; + for (int64_t chunk = 0; chunk < *resultChunks; ++chunk) { + if (consumed >= *sourceChunks) + return fail("narrowing mask granularity conversion ran out of " + "source chunks"); + Value lowerSource = 
sourceParts[sourceOffset + consumed++]; + Value packed = rewriter + .create(op->getLoc(), resultMaskType, + lowerSource, partAttr("LOWER")) + .getResult(); + if (consumed < *sourceChunks) { + Value higherSource = sourceParts[sourceOffset + consumed++]; + Value higher = rewriter + .create(op->getLoc(), resultMaskType, + higherSource, partAttr("HIGHER")) + .getResult(); + if (!allTrue) { + FailureOr mask = + createAllTrueMask(op->getLoc(), resultMaskType, rewriter); + if (failed(mask)) + return fail("failed to create all-true mask for ppack merge"); + allTrue = *mask; + } + packed = rewriter + .create(op->getLoc(), resultMaskType, packed, + higher, allTrue) + .getResult(); + } + results.push_back(packed); + } + if (consumed != *sourceChunks) + return fail("narrowing mask granularity conversion left unused source " + "chunks"); + } + + sourceOffset += *sourceChunks; + } + + FailureOr resultArity = getVMIPhysicalArity(resultType); + if (failed(resultArity) || + static_cast(results.size()) != *resultArity) + return fail("mask granularity conversion result count mismatch"); + return results; +} + +FailureOr> materializeMaskGranularityConversion( + Operation *op, const VMITargetCapabilityRegistry &capabilities, + VMIMaskType sourceType, VMIMaskType resultType, ValueRange sourceParts, + PatternRewriter &rewriter) { + std::string reason; + if (failed(checkSupportedMaskGranularityMaterialization( + capabilities, sourceType, resultType, &reason))) { + (void)rewriter.notifyMatchFailure(op, reason); + return failure(); + } + + int currentRank = getMaskGranularityRank(sourceType.getGranularity()); + int resultRank = getMaskGranularityRank(resultType.getGranularity()); + VMIMaskType currentType = sourceType; + SmallVector currentParts(sourceParts.begin(), sourceParts.end()); + + while (currentRank != resultRank) { + currentRank += currentRank < resultRank ? 
1 : -1; + StringRef nextGranularity = getMaskGranularityForRank(currentRank); + if (nextGranularity.empty()) { + (void)rewriter.notifyMatchFailure(op, + "invalid target mask granularity rank"); + return failure(); + } + VMIMaskType nextType = + VMIMaskType::get(op->getContext(), currentType.getElementCount(), + nextGranularity, currentType.getLayoutAttr()); + FailureOr> nextParts = + materializeAdjacentMaskGranularityConversion(op, currentType, nextType, + currentParts, rewriter); + if (failed(nextParts)) + return failure(); + currentType = nextType; + currentParts = std::move(*nextParts); + } + + return currentParts; +} + +struct OneToNVMIEnsureLayoutOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIEnsureLayoutOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto sourceType = cast(op.getSource().getType()); + auto resultType = cast(op.getResult().getType()); + VMILayoutSupport supports; + std::string supportReason; + if (failed(supports.canMaterializeDataLayout(sourceType, resultType, + &supportReason))) + return rewriter.notifyMatchFailure( + op, + Twine("ensure_layout has no registered materialization support: ") + + supportReason); + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + if (!sourceLayout || !resultLayout) + return rewriter.notifyMatchFailure( + op, "ensure_layout requires assigned source/result layouts"); + + ValueRange sourceParts = adaptor.getSource(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + FailureOr> results = materializeDataLayoutConversion( + op, sourceParts, resultTypes, sourceLayout, resultLayout, rewriter); + if (failed(results)) + return failure(); + rewriter.replaceOp(op, *results, adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMIEnsureMaskLayoutOpPattern + : OneToNOpConversionPattern 
{ + using OneToNOpConversionPattern< + VMIEnsureMaskLayoutOp>::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIEnsureMaskLayoutOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto sourceType = cast(op.getSource().getType()); + auto resultType = cast(op.getResult().getType()); + VMILayoutSupport supports; + std::string supportReason; + if (failed(supports.canMaterializeMaskLayout(sourceType, resultType, + &supportReason))) + return rewriter.notifyMatchFailure( + op, Twine("ensure_mask_layout has no registered materialization " + "support: ") + + supportReason); + if (sourceType.getGranularity() != resultType.getGranularity()) + return rewriter.notifyMatchFailure( + op, "mask layout helper cannot also change granularity"); + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + + ValueRange sourceParts = adaptor.getSource(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + FailureOr> results = materializeMaskLayoutConversion( + op, sourceParts, resultTypes, sourceLayout, resultLayout, rewriter); + if (failed(results)) + return failure(); + rewriter.replaceOp(op, *results, adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMIEnsureMaskGranularityOpPattern + : OneToNOpConversionPattern { + OneToNVMIEnsureMaskGranularityOpPattern( + TypeConverter &typeConverter, MLIRContext *context, + const VMITargetCapabilityRegistry &capabilities) + : OneToNOpConversionPattern(typeConverter, + context), + capabilities(capabilities) {} + + LogicalResult + matchAndRewrite(VMIEnsureMaskGranularityOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto sourceType = cast(op.getSource().getType()); + auto resultType = cast(op.getResult().getType()); + VMILayoutSupport supports; + std::string supportReason; + if (failed(supports.canMaterializeMaskGranularity(sourceType, resultType, + 
&supportReason))) + return rewriter.notifyMatchFailure( + op, Twine("ensure_mask_granularity has no registered materialization " + "support: ") + + supportReason); + if (sourceType.getLayout() != resultType.getLayout()) + return rewriter.notifyMatchFailure( + op, "mask granularity helper cannot also change layout"); + + ValueRange sourceParts = adaptor.getSource(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (sourceType.getGranularity() != resultType.getGranularity()) { + FailureOr> results = + materializeMaskGranularityConversion( + op, capabilities, sourceType, resultType, sourceParts, rewriter); + if (failed(results)) + return failure(); + if (results->size() != resultTypes.size()) + return rewriter.notifyMatchFailure( + op, "mask granularity result arity mismatch"); + for (auto [result, type] : llvm::zip_equal(*results, resultTypes)) + if (result.getType() != type) + return rewriter.notifyMatchFailure( + op, "mask granularity result type mismatch"); + rewriter.replaceOp(op, *results, adaptor.getResultMapping()); + return success(); + } + + if (failed(verifyIdentityPartForwarding(op, sourceParts, resultTypes, + rewriter))) + return failure(); + rewriter.replaceOp(op, sourceParts, adaptor.getResultMapping()); + return success(); + } + +private: + const VMITargetCapabilityRegistry &capabilities; +}; + +struct OneToNVMIBroadcastOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIBroadcastOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange inputParts = adaptor.getValue(); + if (inputParts.size() != 1) + return rewriter.notifyMatchFailure( + op, "broadcast input must convert to one value"); + bool inputIsVReg = isa(op.getValue().getType()); + + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + SmallVector results; + results.reserve(resultTypes.size()); + for (Type resultType : 
resultTypes) { + auto vregType = dyn_cast(resultType); + if (!vregType) + return rewriter.notifyMatchFailure(op, "broadcast result must be vreg"); + FailureOr mask = + createAllTrueMaskForVReg(op.getLoc(), vregType, rewriter); + if (failed(mask)) + return rewriter.notifyMatchFailure( + op, "unsupported element type for broadcast mask"); + StringAttr position = + inputIsVReg ? rewriter.getStringAttr("LOWEST") : StringAttr{}; + results.push_back(rewriter + .create(op.getLoc(), resultType, + inputParts.front(), *mask, position) + .getResult()); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + +FailureOr createScalarOffsetConstant(Location loc, Type type, + int64_t value, + PatternRewriter &rewriter) { + if (auto intType = dyn_cast(type)) { + return rewriter + .create(loc, IntegerAttr::get(intType, value)) + .getResult(); + } + if (auto floatType = dyn_cast(type)) { + return rewriter + .create( + loc, rewriter.getFloatAttr(floatType, static_cast(value))) + .getResult(); + } + return failure(); +} + +FailureOr createIotaChunkBase(Location loc, Value base, + int64_t laneOffset, StringRef order, + PatternRewriter &rewriter) { + if (laneOffset == 0) + return base; + + FailureOr offset = + createScalarOffsetConstant(loc, base.getType(), laneOffset, rewriter); + if (failed(offset)) + return failure(); + + if (isa(base.getType())) { + if (order == "DESC") + return rewriter.create(loc, base, *offset).getResult(); + return rewriter.create(loc, base, *offset).getResult(); + } + if (isa(base.getType())) { + if (order == "DESC") + return rewriter.create(loc, base, *offset).getResult(); + return rewriter.create(loc, base, *offset).getResult(); + } + + return failure(); +} + +FailureOr createIotaContiguousChunk(Location loc, Type resultType, + Value base, int64_t laneOffset, + StringAttr orderAttr, + PatternRewriter &rewriter) { + StringRef order = orderAttr ? 
orderAttr.getValue() : StringRef("ASC"); + FailureOr chunkBase = + createIotaChunkBase(loc, base, laneOffset, order, rewriter); + if (failed(chunkBase)) + return failure(); + return rewriter.create(loc, resultType, *chunkBase, orderAttr) + .getResult(); +} + +FailureOr createIotaDeinterleavedChunk(Location loc, Type resultType, + Value base, int64_t factor, + int64_t part, int64_t chunk, + int64_t lanesPerPart, + StringAttr orderAttr, + PatternRewriter &rewriter) { + auto vregType = dyn_cast(resultType); + if (!vregType) + return failure(); + + FailureOr mask = createAllTrueMaskForVReg(loc, vregType, rewriter); + FailureOr zero = + createScalarOffsetConstant(loc, base.getType(), 0, rewriter); + FailureOr factorScalar = + createScalarOffsetConstant(loc, base.getType(), factor, rewriter); + if (failed(mask) || failed(zero) || failed(factorScalar)) + return failure(); + + Value local = + rewriter.create(loc, resultType, *zero, StringAttr{}).getResult(); + Value scaled = + rewriter.create(loc, resultType, local, *factorScalar, *mask) + .getResult(); + + StringRef order = orderAttr ? 
orderAttr.getValue() : StringRef("ASC"); + int64_t partOffset = part + factor * chunk * lanesPerPart; + FailureOr biasedBase = + createIotaChunkBase(loc, base, partOffset, order, rewriter); + if (failed(biasedBase)) + return failure(); + + if (order == "DESC") { + Value baseVector = rewriter + .create(loc, resultType, *biasedBase, *mask, + /*position=*/nullptr) + .getResult(); + return rewriter.create(loc, resultType, baseVector, scaled, *mask) + .getResult(); + } + + return rewriter.create(loc, resultType, scaled, *biasedBase, *mask) + .getResult(); +} + +struct OneToNVMIIotaOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIIotaOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto resultVMIType = cast(op.getResult().getType()); + VMILayoutAttr layout = resultVMIType.getLayoutAttr(); + if (!layout) + return rewriter.notifyMatchFailure(op, "iota requires assigned layout"); + + FailureOr lanesPerPart = + getDataLanesPerPart(resultVMIType.getElementType()); + if (failed(lanesPerPart)) + return rewriter.notifyMatchFailure( + op, "iota requires known physical lanes per part"); + + FailureOr base = getSingleValue( + op, adaptor.getBase(), "iota base must convert to one value", rewriter); + if (failed(base)) + return failure(); + + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + SmallVector results; + results.reserve(resultTypes.size()); + + if (layout.isContiguous()) { + for (auto [index, resultType] : llvm::enumerate(resultTypes)) { + if (!isa(resultType)) + return rewriter.notifyMatchFailure(op, "iota result must be vreg"); + FailureOr result = createIotaContiguousChunk( + op.getLoc(), resultType, *base, + static_cast(index) * *lanesPerPart, op.getOrderAttr(), + rewriter); + if (failed(result)) + return rewriter.notifyMatchFailure( + op, "failed to materialize contiguous iota chunk"); + results.push_back(*result); + } + 
rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } + + int64_t factor = layout.getFactor(); + if (resultTypes.size() % factor != 0) + return rewriter.notifyMatchFailure( + op, "deinterleaved iota physical result count does not match " + "layout factor"); + int64_t chunksPerPart = resultTypes.size() / factor; + for (int64_t part = 0; part < factor; ++part) { + for (int64_t chunk = 0; chunk < chunksPerPart; ++chunk) { + Type resultType = resultTypes[part * chunksPerPart + chunk]; + FailureOr result = createIotaDeinterleavedChunk( + op.getLoc(), resultType, *base, factor, part, chunk, *lanesPerPart, + op.getOrderAttr(), rewriter); + if (failed(result)) + return rewriter.notifyMatchFailure( + op, "failed to materialize deinterleaved iota chunk"); + results.push_back(*result); + } + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMIConstantOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIConstantOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto denseAttr = dyn_cast(op.getValue()); + if (!denseAttr || !denseAttr.isSplat()) + return rewriter.notifyMatchFailure( + op, "only splat dense data constants are supported"); + auto splatAttr = dyn_cast(denseAttr.getSplatValue()); + if (!splatAttr) + return rewriter.notifyMatchFailure(op, "splat constant must be typed"); + + Value scalar = + rewriter.create(op.getLoc(), splatAttr).getResult(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + SmallVector results; + results.reserve(resultTypes.size()); + for (Type resultType : resultTypes) { + auto vregType = dyn_cast(resultType); + if (!vregType) + return rewriter.notifyMatchFailure(op, "constant result must be vreg"); + FailureOr mask = + createAllTrueMaskForVReg(op.getLoc(), vregType, rewriter); + if (failed(mask)) + return 
rewriter.notifyMatchFailure( + op, "unsupported element type for constant mask"); + results.push_back(rewriter + .create(op.getLoc(), resultType, scalar, + *mask, + /*position=*/nullptr) + .getResult()); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMIConstantMaskOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIConstantMaskOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + std::string reason; + FailureOr> materializations = + computeConstantMaskMaterialization(op, &reason); + if (failed(materializations)) + return rewriter.notifyMatchFailure(op, Twine("constant_mask ") + reason); + + SmallVector results; + results.reserve(resultTypes.size()); + for (const ConstantMaskChunkMaterialization &materialization : + *materializations) { + if (results.size() >= resultTypes.size()) + return rewriter.notifyMatchFailure( + op, "constant_mask produced too many physical masks"); + auto maskType = dyn_cast(resultTypes[results.size()]); + if (!maskType) + return rewriter.notifyMatchFailure(op, + "constant_mask result must be mask"); + FailureOr mask = materializeConstantMaskChunk( + op.getLoc(), maskType, materialization.activeLanes, rewriter); + if (failed(mask)) + return rewriter.notifyMatchFailure( + op, "failed to materialize constant_mask physical chunk"); + results.push_back(*mask); + } + + if (results.size() != resultTypes.size()) + return rewriter.notifyMatchFailure( + op, "constant_mask physical result count mismatch"); + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMICreateMaskOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMICreateMaskOp op, OpAdaptor 
adaptor, + OneToNPatternRewriter &rewriter) const override { + auto activeConstant = + op.getActiveLanes().getDefiningOp(); + auto resultVMIType = cast(op.getResult().getType()); + VMILayoutAttr layout = resultVMIType.getLayoutAttr(); + if (!layout || + !VMIMaskType::isConcreteGranularity(resultVMIType.getGranularity())) + return rewriter.notifyMatchFailure( + op, "create_mask requires concrete layout and granularity"); + FailureOr lanesPerPart = + getMaskLanesPerPart(resultVMIType.getGranularity()); + if (failed(lanesPerPart)) + return rewriter.notifyMatchFailure( + op, "create_mask requires known physical mask lanes per part"); + + if (!activeConstant) { + FailureOr active = getSingleValue( + op, adaptor.getActiveLanes(), + "create_mask active_lanes must convert to one value", rewriter); + if (failed(active)) + return failure(); + + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + int64_t factor = layout.isDeinterleaved() ? layout.getFactor() : 1; + if (resultTypes.size() % factor != 0) + return rewriter.notifyMatchFailure( + op, "dynamic create_mask physical result count does not match " + "layout factor"); + int64_t chunksPerPart = resultTypes.size() / factor; + Value activeI32 = clampDynamicActiveLanes( + op.getLoc(), *active, resultVMIType.getElementCount(), rewriter); + + SmallVector results; + results.reserve(resultTypes.size()); + for (int64_t part = 0; part < factor; ++part) { + Value remaining = createPartitionActiveLanes(op.getLoc(), activeI32, + factor, part, rewriter); + for (int64_t chunk = 0; chunk < chunksPerPart; ++chunk) { + Type resultType = resultTypes[part * chunksPerPart + chunk]; + auto maskType = dyn_cast(resultType); + if (!maskType) + return rewriter.notifyMatchFailure( + op, "create_mask result must be mask"); + FailureOr> maskAndRemaining = + createRuntimePrefixMask(op.getLoc(), maskType, remaining, + rewriter); + if (failed(maskAndRemaining)) + return rewriter.notifyMatchFailure( + op, "unsupported mask type 
for dynamic create_mask"); + results.push_back(maskAndRemaining->first); + remaining = maskAndRemaining->second; + } + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } + + auto activeAttr = dyn_cast(activeConstant.getValue()); + if (!activeAttr) + return rewriter.notifyMatchFailure( + op, "create_mask active_lanes must be an integer constant"); + + int64_t activeLanes = activeAttr.getInt(); + if (activeLanes < 0) + activeLanes = 0; + if (activeLanes > resultVMIType.getElementCount()) + activeLanes = resultVMIType.getElementCount(); + + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + int64_t factor = layout.isDeinterleaved() ? layout.getFactor() : 1; + SmallVector results; + results.reserve(resultTypes.size()); + + for (int64_t part = 0; part < factor; ++part) { + for (int64_t chunk = 0;; ++chunk) { + bool anyLane = false; + int64_t activeInChunk = 0; + for (int64_t lane = 0; lane < *lanesPerPart; ++lane) { + FailureOr padding = + isPaddingLane(resultVMIType, part, chunk, lane); + if (failed(padding)) + return rewriter.notifyMatchFailure( + op, "failed to map create_mask physical padding lane"); + if (*padding) + continue; + anyLane = true; + FailureOr logicalLane = + mapPhysicalLaneToLogical(resultVMIType, part, chunk, lane); + if (failed(logicalLane)) + return rewriter.notifyMatchFailure( + op, "failed to map create_mask physical lane"); + if (*logicalLane < activeLanes) + ++activeInChunk; + } + if (!anyLane) + break; + + if (results.size() >= resultTypes.size()) + return rewriter.notifyMatchFailure( + op, "create_mask produced too many physical masks"); + auto maskType = dyn_cast(resultTypes[results.size()]); + if (!maskType) + return rewriter.notifyMatchFailure(op, + "create_mask result must be mask"); + std::optional pattern = + getPrefixPattern(activeInChunk, *lanesPerPart); + if (pattern) { + FailureOr mask = + createPrefixMask(op.getLoc(), maskType, *pattern, rewriter); + if (failed(mask)) 
+ return rewriter.notifyMatchFailure( + op, "unsupported mask type for create_mask"); + results.push_back(*mask); + continue; + } + + FailureOr> maskAndRemaining = + createRuntimePrefixMask( + op.getLoc(), maskType, + createI32Constant(op.getLoc(), activeInChunk, rewriter), + rewriter); + if (failed(maskAndRemaining)) + return rewriter.notifyMatchFailure( + op, "unsupported mask type for create_mask plt fallback"); + results.push_back(maskAndRemaining->first); + } + } + + if (results.size() != resultTypes.size()) + return rewriter.notifyMatchFailure( + op, "create_mask physical result count mismatch"); + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMICreateGroupMaskOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern< + VMICreateGroupMaskOp>::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMICreateGroupMaskOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + auto resultVMIType = cast(op.getResult().getType()); + VMILayoutAttr resultLayout = resultVMIType.getLayoutAttr(); + if (resultLayout && resultLayout.isDeinterleaved() && + resultLayout.getFactor() == 4 && resultLayout.getBlockElems() == 8) { + VMILayoutAttr contiguousLayout = + VMILayoutAttr::getContiguous(op.getContext()); + auto contiguousType = + VMIMaskType::get(op.getContext(), resultVMIType.getElementCount(), + resultVMIType.getGranularity(), contiguousLayout); + SmallVector contiguousParts; + auto activeConstant = + op.getActiveElemsPerGroup().getDefiningOp(); + if (activeConstant) { + std::string contiguousReason; + FailureOr> + contiguousMaterializations = computeGroupMaskMaterializationForType( + op, contiguousType, &contiguousReason); + if (failed(contiguousMaterializations)) + return rewriter.notifyMatchFailure(op, Twine("create_group_mask ") + + contiguousReason); + + 
contiguousParts.reserve(contiguousMaterializations->size()); + for (const ConstantMaskChunkMaterialization &materialization : + *contiguousMaterializations) { + if (contiguousParts.size() >= resultTypes.size()) + return rewriter.notifyMatchFailure( + op, "create_group_mask produced too many contiguous masks"); + auto maskType = + dyn_cast(resultTypes[contiguousParts.size()]); + if (!maskType) + return rewriter.notifyMatchFailure( + op, "create_group_mask result must be mask"); + FailureOr mask = materializeConstantMaskChunk( + op.getLoc(), maskType, materialization.activeLanes, rewriter); + if (failed(mask)) + return rewriter.notifyMatchFailure( + op, "failed to materialize create_group_mask contiguous chunk"); + contiguousParts.push_back(*mask); + } + } else { + FailureOr active = getSingleValue( + op, adaptor.getActiveElemsPerGroup(), + "create_group_mask active_elems_per_group must convert to one " + "value", + rewriter); + if (failed(active)) + return failure(); + FailureOr> dynamicParts = + materializeDynamicContiguousGroupMask(op, *active, contiguousType, + resultTypes, rewriter); + if (failed(dynamicParts)) + return failure(); + contiguousParts = std::move(*dynamicParts); + } + + if (contiguousParts.size() != resultTypes.size()) + return rewriter.notifyMatchFailure( + op, "create_group_mask contiguous physical result count mismatch"); + FailureOr> results = materializeMaskLayoutConversion( + op, contiguousParts, resultTypes, contiguousLayout, resultLayout, + rewriter); + if (failed(results)) + return failure(); + rewriter.replaceOp(op, *results, adaptor.getResultMapping()); + return success(); + } + + auto activeConstant = + op.getActiveElemsPerGroup().getDefiningOp(); + if (!activeConstant && resultLayout && resultLayout.isContiguous()) { + FailureOr active = getSingleValue( + op, adaptor.getActiveElemsPerGroup(), + "create_group_mask active_elems_per_group must convert to one value", + rewriter); + if (failed(active)) + return failure(); + FailureOr> 
results = + materializeDynamicContiguousGroupMask(op, *active, resultVMIType, + resultTypes, rewriter); + if (failed(results)) + return failure(); + rewriter.replaceOp(op, *results, adaptor.getResultMapping()); + return success(); + } + + std::string reason; + FailureOr> materializations = + computeGroupMaskMaterialization(op, &reason); + if (failed(materializations)) + return rewriter.notifyMatchFailure(op, + Twine("create_group_mask ") + reason); + + SmallVector results; + results.reserve(resultTypes.size()); + for (const ConstantMaskChunkMaterialization &materialization : + *materializations) { + if (results.size() >= resultTypes.size()) + return rewriter.notifyMatchFailure( + op, "create_group_mask produced too many physical masks"); + auto maskType = dyn_cast(resultTypes[results.size()]); + if (!maskType) + return rewriter.notifyMatchFailure( + op, "create_group_mask result must be mask"); + FailureOr mask = materializeConstantMaskChunk( + op.getLoc(), maskType, materialization.activeLanes, rewriter); + if (failed(mask)) + return rewriter.notifyMatchFailure( + op, "failed to materialize create_group_mask physical chunk"); + results.push_back(*mask); + } + + if (results.size() != resultTypes.size()) + return rewriter.notifyMatchFailure( + op, "create_group_mask physical result count mismatch"); + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMILoadOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMILoadOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto resultVMIType = cast(op.getResult().getType()); + FailureOr source = + getSingleValue(op, adaptor.getSource(), + "load source must convert to one value", rewriter); + FailureOr offset = + getSingleValue(op, adaptor.getOffset(), + "load offset must convert to one value", rewriter); + if (failed(source) || failed(offset)) + return 
failure(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + VMILayoutAttr resultLayout = resultVMIType.getLayoutAttr(); + if (std::optional dist = + getDenseLaneStrideLoadDistToken(resultVMIType)) { + SmallVector results; + results.reserve(resultTypes.size()); + int64_t semanticOffset = 0; + for (auto [index, resultType] : llvm::enumerate(resultTypes)) { + if (!isa(resultType)) + return rewriter.notifyMatchFailure(op, "load result must be vreg"); + Value chunkOffset = + createChunkOffset(op.getLoc(), *offset, semanticOffset, rewriter); + results.push_back(rewriter + .create(op.getLoc(), resultType, + /*updated_base=*/Type{}, *source, + chunkOffset, + rewriter.getStringAttr(*dist)) + .getResult()); + FailureOr activeLanes = + getActiveDataLanesInPhysicalChunk(resultVMIType, index); + if (failed(activeLanes)) + return rewriter.notifyMatchFailure( + op, "failed to compute lane_stride load active lanes"); + semanticOffset += *activeLanes; + } + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } + + FailureOr lanesPerPart = verifyFullOrSafeReadVRegChunks( + op, resultVMIType, op.getSource().getType(), *offset, rewriter); + if (failed(lanesPerPart)) + return failure(); + + if (resultLayout && resultLayout.isDeinterleaved() && + resultLayout.getFactor() == 2) { + std::optional dist = + getX2MemoryDistToken(resultVMIType.getElementType(), "DINTLV"); + if (dist && !resultTypes.empty() && resultTypes.size() % 2 == 0) { + int64_t groups = resultTypes.size() / 2; + SmallVector lows; + SmallVector highs; + lows.reserve(groups); + highs.reserve(groups); + for (int64_t group = 0; group < groups; ++group) { + Type lowType = resultTypes[group]; + Type highType = resultTypes[groups + group]; + if (lowType != highType) + return rewriter.notifyMatchFailure( + op, "vldsx2 requires matching low/high result types"); + Value chunkOffset = createChunkOffset( + op.getLoc(), *offset, group * 2 * *lanesPerPart, rewriter); + auto 
load = rewriter.create(op.getLoc(), lowType, highType, + *source, chunkOffset, + rewriter.getStringAttr(*dist)); + lows.push_back(load.getLow()); + highs.push_back(load.getHigh()); + } + SmallVector results; + results.reserve(resultTypes.size()); + results.append(lows); + results.append(highs); + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } + } + + if (resultLayout && resultLayout.isDeinterleaved() && + resultLayout.getFactor() == 4 && resultLayout.getBlockElems() == 1) { + std::optional dist = + getX2MemoryDistToken(resultVMIType.getElementType(), "DINTLV"); + if (dist && !resultTypes.empty() && resultTypes.size() % 4 == 0) { + int64_t groups = resultTypes.size() / 4; + SmallVector part0; + SmallVector part1; + SmallVector part2; + SmallVector part3; + part0.reserve(groups); + part1.reserve(groups); + part2.reserve(groups); + part3.reserve(groups); + for (int64_t group = 0; group < groups; ++group) { + Type part0Type = resultTypes[group]; + Type part1Type = resultTypes[groups + group]; + Type part2Type = resultTypes[2 * groups + group]; + Type part3Type = resultTypes[3 * groups + group]; + if (part0Type != part1Type || part0Type != part2Type || + part0Type != part3Type) + return rewriter.notifyMatchFailure( + op, "vldsx2 deinterleaved=4 load requires matching part " + "types"); + + Value firstOffset = createChunkOffset( + op.getLoc(), *offset, group * 4 * *lanesPerPart, rewriter); + Value secondOffset = createChunkOffset( + op.getLoc(), *offset, (group * 4 + 2) * *lanesPerPart, rewriter); + auto first = rewriter.create( + op.getLoc(), part0Type, part1Type, *source, firstOffset, + rewriter.getStringAttr(*dist)); + auto second = rewriter.create( + op.getLoc(), part2Type, part3Type, *source, secondOffset, + rewriter.getStringAttr(*dist)); + + auto even = + rewriter.create(op.getLoc(), part0Type, part2Type, + first.getLow(), second.getLow()); + auto odd = + rewriter.create(op.getLoc(), part1Type, part3Type, + first.getHigh(), 
second.getHigh()); + part0.push_back(even.getLow()); + part1.push_back(odd.getLow()); + part2.push_back(even.getHigh()); + part3.push_back(odd.getHigh()); + } + + SmallVector results; + results.reserve(resultTypes.size()); + results.append(part0); + results.append(part1); + results.append(part2); + results.append(part3); + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } + } + + SmallVector contiguousParts; + contiguousParts.reserve(resultTypes.size()); + for (auto [index, resultType] : llvm::enumerate(resultTypes)) { + auto vregType = dyn_cast(resultType); + if (!vregType) + return rewriter.notifyMatchFailure(op, "load result must be vreg"); + Value chunkOffset = createChunkOffset(op.getLoc(), *offset, + index * *lanesPerPart, rewriter); + contiguousParts.push_back(rewriter + .create(op.getLoc(), resultType, + /*updated_base=*/Type{}, + *source, chunkOffset, + /*dist=*/nullptr) + .getResult()); + } + + FailureOr> results = materializeDataLayoutConversion( + op, contiguousParts, resultTypes, + VMILayoutAttr::getContiguous(rewriter.getContext()), + resultVMIType.getLayoutAttr(), rewriter); + if (failed(results)) + return failure(); + + rewriter.replaceOp(op, *results, adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMIDeinterleaveLoadOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern< + VMIDeinterleaveLoadOp>::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIDeinterleaveLoadOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto lowVMIType = cast(op.getLow().getType()); + FailureOr source = + getSingleValue(op, adaptor.getSource(), + "deinterleave_load source must convert to one value", + rewriter); + FailureOr offset = + getSingleValue(op, adaptor.getOffset(), + "deinterleave_load offset must convert to one value", + rewriter); + if (failed(source) || failed(offset)) + return failure(); + + FailureOr lanesPerPart = + 
getDataLanesPerPart(lowVMIType.getElementType()); + if (failed(lanesPerPart)) + return rewriter.notifyMatchFailure( + op, "deinterleave_load requires known physical lanes per part"); + + std::optional dist = + getX2MemoryDistToken(lowVMIType.getElementType(), "DINTLV"); + if (!dist) + return rewriter.notifyMatchFailure( + op, "deinterleave_load requires vldsx2 DINTLV element support"); + + TypeRange lowTypes = adaptor.getResultMapping().getConvertedTypes(0); + TypeRange highTypes = adaptor.getResultMapping().getConvertedTypes(1); + if (lowTypes.size() != highTypes.size()) + return rewriter.notifyMatchFailure( + op, "deinterleave_load requires matching low/high physical arity"); + + SmallVector lows; + SmallVector highs; + lows.reserve(lowTypes.size()); + highs.reserve(highTypes.size()); + for (size_t index = 0, e = lowTypes.size(); index < e; ++index) { + Type lowType = lowTypes[index]; + Type highType = highTypes[index]; + if (lowType != highType) + return rewriter.notifyMatchFailure( + op, "deinterleave_load requires matching low/high physical types"); + Value chunkOffset = createChunkOffset( + op.getLoc(), *offset, static_cast(index) * 2 * *lanesPerPart, + rewriter); + auto load = rewriter.create( + op.getLoc(), lowType, highType, *source, chunkOffset, + rewriter.getStringAttr(*dist)); + lows.push_back(load.getLow()); + highs.push_back(load.getHigh()); + } + + SmallVector results; + results.reserve(lows.size() + highs.size()); + results.append(lows); + results.append(highs); + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMIGroupLoadOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIGroupLoadOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto resultVMIType = cast(op.getResult().getType()); + FailureOr source = + getSingleValue(op, adaptor.getSource(), + "group_load source must 
convert to one value", rewriter); + FailureOr offset = + getSingleValue(op, adaptor.getOffset(), + "group_load offset must convert to one value", rewriter); + FailureOr rowStride = getSingleValue( + op, adaptor.getRowStride(), + "group_load row_stride must convert to one value", rewriter); + if (failed(source) || failed(offset) || failed(rowStride)) + return failure(); + + VMILayoutAttr resultLayout = resultVMIType.getLayoutAttr(); + if (resultLayout && resultLayout.isDeinterleaved() && + resultLayout.getBlockElems() == 8 && + resultVMIType.getElementType().isF32()) { + FailureOr groupSize = getGroupSizeFromNumGroups( + resultVMIType, op.getNumGroupsAttr().getInt()); + if (failed(groupSize)) + return rewriter.notifyMatchFailure( + op, "group_load requires num_groups to evenly divide lane count"); + if ((*groupSize != 16 || resultLayout.getFactor() != 2) && + (*groupSize != 32 || resultLayout.getFactor() != 4)) + return rewriter.notifyMatchFailure( + op, "block8 group_load requires S=16/factor=2 or S=32/factor=4"); + if (op.getNumGroupsAttr().getInt() % 8 != 0) + return rewriter.notifyMatchFailure( + op, "block8 group_load requires num_groups multiple of 8"); + std::optional constantRowStride = + getConstantIndexValue(op.getRowStride()); + if (!constantRowStride || *constantRowStride <= 0 || + *constantRowStride % 8 != 0) + return rewriter.notifyMatchFailure( + op, "block8 group_load requires constant positive row_stride " + "divisible by 8 f32 elements"); + if (!isa((*source).getType())) + return rewriter.notifyMatchFailure( + op, "block8 group_load requires !pto.ptr source"); + + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + int64_t factor = resultLayout.getFactor(); + FailureOr chunksPerPart = getDataChunksInPart(resultVMIType, 0); + if (failed(chunksPerPart) || *chunksPerPart <= 0) + return rewriter.notifyMatchFailure( + op, "block8 group_load requires known chunks per part"); + for (int64_t part = 1; part < factor; ++part) { + 
FailureOr currentChunks = + getDataChunksInPart(resultVMIType, part); + if (failed(currentChunks) || *currentChunks != *chunksPerPart) + return rewriter.notifyMatchFailure( + op, "block8 group_load requires uniform chunks per part"); + } + if (static_cast(resultTypes.size()) != factor * *chunksPerPart) + return rewriter.notifyMatchFailure(op, + "block8 group_load arity mismatch"); + + auto makeI16 = [&](int64_t value) -> Value { + return rewriter.create(op.getLoc(), value, 16); + }; + Value blockStride = makeI16(*constantRowStride / 8); + Value zeroI16 = makeI16(0); + auto makePtr = [&](Value elementOffset) -> Value { + return rewriter + .create(op.getLoc(), (*source).getType(), *source, + elementOffset) + .getResult(); + }; + + SmallVector results; + results.reserve(resultTypes.size()); + constexpr int64_t kGroupsPerBlock8Load = 8; + for (int64_t part = 0; part < factor; ++part) { + for (int64_t chunk = 0; chunk < *chunksPerPart; ++chunk) { + int64_t flatIndex = part * *chunksPerPart + chunk; + auto vregType = dyn_cast(resultTypes[flatIndex]); + if (!vregType) + return rewriter.notifyMatchFailure( + op, "block8 group_load result must be vreg"); + FailureOr allMask = + createAllTrueMaskForVReg(op.getLoc(), vregType, rewriter); + if (failed(allMask)) + return rewriter.notifyMatchFailure( + op, "failed to create block8 group_load mask"); + Value chunkOffset = createGroupChunkOffset( + op.getLoc(), *offset, *rowStride, chunk * kGroupsPerBlock8Load, + part * resultLayout.getBlockElems(), rewriter); + Value chunkBase = makePtr(chunkOffset); + results.push_back(rewriter + .create(op.getLoc(), vregType, + chunkBase, blockStride, + zeroI16, *allMask) + .getResult()); + } + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } + + int64_t lanesPerPart = 0; + int64_t groupCount = 0; + int64_t chunksPerGroup = 0; + FailureOr groupSize = getGroupSizeFromNumGroups( + resultVMIType, op.getNumGroupsAttr().getInt()); + if (failed(groupSize)) + 
return rewriter.notifyMatchFailure( + op, "group_load requires num_groups to evenly divide lane count"); + if (failed(checkContiguousFullGroupChunks(op, resultVMIType, *groupSize, + &lanesPerPart, &groupCount, + &chunksPerGroup, rewriter))) + return failure(); + + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (static_cast(resultTypes.size()) != groupCount * chunksPerGroup) + return rewriter.notifyMatchFailure(op, "group_load arity mismatch"); + + SmallVector results; + results.reserve(resultTypes.size()); + for (auto [index, resultType] : llvm::enumerate(resultTypes)) { + auto vregType = dyn_cast(resultType); + if (!vregType) + return rewriter.notifyMatchFailure(op, + "group_load result must be vreg"); + int64_t group = index / chunksPerGroup; + int64_t chunkInGroup = index % chunksPerGroup; + Value chunkOffset = + createGroupChunkOffset(op.getLoc(), *offset, *rowStride, group, + chunkInGroup * lanesPerPart, rewriter); + results.push_back(rewriter + .create(op.getLoc(), resultType, + /*updated_base=*/Type{}, *source, + chunkOffset, + /*dist=*/nullptr) + .getResult()); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + +static LogicalResult lowerGroupSlotLoadParts( + Operation *op, Value source, Value offset, Value sourceGroupStride, + VMIVRegType resultVMIType, TypeRange resultTypes, int64_t numGroups, + OneToNPatternRewriter &rewriter, SmallVectorImpl &results) { + VMILayoutAttr layout = resultVMIType.getLayoutAttr(); + if (!layout || !layout.isGroupSlots() || layout.getSlots() <= 0) + return rewriter.notifyMatchFailure( + op, "group_slot_load requires explicit group_slots layout"); + if (!isa(source.getType())) + return rewriter.notifyMatchFailure(op, + "group_slot_load requires !pto.ptr source"); + + int64_t slots = layout.getSlots(); + int64_t expectedArity = ceilDivNonNegative(numGroups, slots); + if (static_cast(resultTypes.size()) != expectedArity) + return 
rewriter.notifyMatchFailure(op, "group_slot_load arity mismatch"); + + auto makeI16 = [&](int64_t value) -> Value { + return rewriter.create(op->getLoc(), value, 16); + }; + Value zeroI16 = makeI16(0); + auto makePtr = [&](Value elementOffset) -> Value { + return rewriter + .create(op->getLoc(), source.getType(), source, + elementOffset) + .getResult(); + }; + + results.reserve(results.size() + resultTypes.size()); + + if (slots == 8) { + std::optional stride = getConstantIndexValue(sourceGroupStride); + if (!stride || *stride != 1) + return rewriter.notifyMatchFailure( + op, "slots=8 group_slot_load requires constant unit stride"); + for (auto [chunk, resultType] : llvm::enumerate(resultTypes)) { + auto vregType = dyn_cast(resultType); + if (!vregType) + return rewriter.notifyMatchFailure(op, + "group_slot_load result must be vreg"); + FailureOr maskType = + getMaskTypeForVReg(vregType, rewriter.getContext()); + if (failed(maskType)) + return rewriter.notifyMatchFailure( + op, "unsupported element type for group_slot_load mask"); + int64_t groupBegin = static_cast(chunk) * slots; + int64_t activeGroups = std::min(slots, numGroups - groupBegin); + if (activeGroups <= 0) + return rewriter.notifyMatchFailure( + op, "slots=8 group_slot_load has no active groups for chunk"); + std::string pattern = (Twine("PAT_VL") + Twine(activeGroups)).str(); + FailureOr slotMask = + createPrefixMask(op->getLoc(), *maskType, pattern, rewriter); + if (failed(slotMask)) + return rewriter.notifyMatchFailure( + op, "failed to create slots=8 group_slot_load mask"); + Value groupOffset = + createChunkOffset(op->getLoc(), offset, groupBegin, rewriter); + Value slotBase = makePtr(groupOffset); + results.push_back(rewriter + .create(op->getLoc(), vregType, slotBase, + zeroI16, zeroI16, *slotMask) + .getResult()); + } + return success(); + } + + if (slots != 1) + return rewriter.notifyMatchFailure( + op, "group_slot_load supports only slots=8 or slots=1"); + unsigned elementBits = + 
pto::getPTOStorageElemBitWidth(resultVMIType.getElementType()); + if (elementBits == 0 || 256 % elementBits != 0) + return rewriter.notifyMatchFailure( + op, "slots=1 group_slot_load requires supported element width"); + int64_t alignedStrideElems = 256 / elementBits; + std::optional constantStride = + getConstantIndexValue(sourceGroupStride); + if (!constantStride || *constantStride <= 0 || + *constantStride % alignedStrideElems != 0) + return rewriter.notifyMatchFailure( + op, Twine("slots=1 group_slot_load requires constant positive " + "source_group_stride divisible by ") + + Twine(alignedStrideElems) + + " elements for 32B lane-0 vsldb alignment"); + + for (auto [group, resultType] : llvm::enumerate(resultTypes)) { + auto vregType = dyn_cast(resultType); + if (!vregType) + return rewriter.notifyMatchFailure(op, + "group_slot_load result must be vreg"); + FailureOr maskType = + getMaskTypeForVReg(vregType, rewriter.getContext()); + if (failed(maskType)) + return rewriter.notifyMatchFailure( + op, "unsupported element type for group_slot_load mask"); + FailureOr oneBlockMask = + createPrefixMask(op->getLoc(), *maskType, "PAT_VL1", rewriter); + if (failed(oneBlockMask)) + return rewriter.notifyMatchFailure(op, + "failed to create group_slot_load mask"); + Value groupOffset = offset; + if (group != 0) { + Value groupIndex = + rewriter.create(op->getLoc(), group); + Value rowOffset = + rewriter + .create(op->getLoc(), sourceGroupStride, + groupIndex) + .getResult(); + groupOffset = + rewriter.create(op->getLoc(), groupOffset, rowOffset) + .getResult(); + } + Value slotBase = makePtr(groupOffset); + results.push_back(rewriter + .create(op->getLoc(), vregType, slotBase, + zeroI16, zeroI16, *oneBlockMask) + .getResult()); + } + return success(); +} + +static LogicalResult lowerGroupBroadcastParts( + Operation *op, ValueRange sourceParts, VMIVRegType sourceVMIType, + VMIVRegType resultVMIType, TypeRange resultTypes, int64_t numGroups, + OneToNPatternRewriter &rewriter, 
SmallVectorImpl &results) { + FailureOr groupSize = + getGroupSizeFromNumGroups(resultVMIType, numGroups); + if (failed(groupSize)) + return rewriter.notifyMatchFailure( + op, "group_broadcast requires num_groups to evenly divide lane count"); + int64_t lanesPerPart = 0; + int64_t groupCount = 0; + if (failed(checkFullGroupSlotSourceShape(op, sourceVMIType, *groupSize, + numGroups, &lanesPerPart, + &groupCount, rewriter))) + return failure(); + int64_t resultLayoutFactor = 0; + int64_t resultGroupCount = 0; + if (failed(checkFullGroupBroadcastResultShape( + op, resultVMIType, *groupSize, lanesPerPart, &resultLayoutFactor, + &resultGroupCount, rewriter))) + return failure(); + if (resultGroupCount != groupCount) + return rewriter.notifyMatchFailure( + op, "group_broadcast requires matching source/result group slots"); + + if (sourceParts.empty() || resultTypes.empty()) + return rewriter.notifyMatchFailure(op, "group_broadcast arity mismatch"); + + auto firstSourceType = dyn_cast(sourceParts.front().getType()); + if (!firstSourceType) + return rewriter.notifyMatchFailure(op, + "group_broadcast source must be vreg"); + unsigned indexBits = + pto::getPTOStorageElemBitWidth(firstSourceType.getElementType()); + if (indexBits != 8 && indexBits != 16 && indexBits != 32) + return rewriter.notifyMatchFailure( + op, "group_broadcast requires 8/16/32-bit index elements"); + auto indexElementType = IntegerType::get(rewriter.getContext(), indexBits); + auto indexType = + VRegType::get(rewriter.getContext(), firstSourceType.getElementCount(), + indexElementType); + FailureOr allMask = + createAllTrueMaskForVReg(op->getLoc(), firstSourceType, rewriter); + if (failed(allMask)) + return rewriter.notifyMatchFailure( + op, "failed to create group_broadcast all mask"); + VMILayoutAttr resultLayout = resultVMIType.getLayoutAttr(); + VMILayoutAttr sourceLayout = sourceVMIType.getLayoutAttr(); + int64_t selectionGroupSize = *groupSize; + if (resultLayoutFactor != 1 && resultLayout && + 
resultLayout.isDeinterleaved() && resultLayout.getBlockElems() > 1 && + *groupSize < lanesPerPart) + selectionGroupSize = resultLayout.getBlockElems(); + auto resolveLargeGroupSource = [&](int64_t group, int64_t chunksPerGroup, + int64_t &sourceChunk, + int64_t &baseGroupSlot) { + int64_t slots = sourceLayout.getSlots(); + if (slots > 0) { + sourceChunk = group / slots; + baseGroupSlot = group % slots; + return; + } + sourceChunk = group * chunksPerGroup; + baseGroupSlot = 0; + }; + + results.clear(); + results.resize(resultTypes.size()); + for (auto [flatIndex, resultType] : llvm::enumerate(resultTypes)) { + auto resultVRegType = dyn_cast(resultType); + if (!resultVRegType || resultVRegType != firstSourceType) + return rewriter.notifyMatchFailure( + op, "group_broadcast requires uniform physical vreg types"); + int64_t sourceChunk = flatIndex; + int64_t baseGroupSlot = 0; + Value mappedGroupSlotIndex; + if (resultLayoutFactor == 1) { + if (*groupSize >= lanesPerPart) { + int64_t chunksPerGroup = *groupSize / lanesPerPart; + int64_t group = flatIndex / chunksPerGroup; + resolveLargeGroupSource(group, chunksPerGroup, sourceChunk, + baseGroupSlot); + } else { + VMILayoutAttr sourceLayout = sourceVMIType.getLayoutAttr(); + int64_t slots = sourceLayout.getSlots(); + if (slots <= 0) { + if (sourceParts.empty() || + groupCount % static_cast(sourceParts.size()) != 0) + return rewriter.notifyMatchFailure( + op, "group_broadcast small-group source requires explicit " + "group_slots slots or derivable legacy slot count"); + slots = groupCount / sourceParts.size(); + } + int64_t groupsPerResultChunk = lanesPerPart / *groupSize; + int64_t firstGroup = flatIndex * groupsPerResultChunk; + sourceChunk = firstGroup / slots; + baseGroupSlot = firstGroup % slots; + } + } else { + bool blockFragmentSmallGroup = + resultLayout && resultLayout.isDeinterleaved() && + resultLayout.getBlockElems() > 1 && *groupSize < lanesPerPart; + bool deinterleavedSmallGroup = + resultLayout && 
resultLayout.isDeinterleaved() && + resultLayout.getBlockElems() == 1 && *groupSize < lanesPerPart; + if (blockFragmentSmallGroup) { + int64_t runningFlatIndex = 0; + bool found = false; + for (int64_t part = 0; part < resultLayoutFactor && !found; ++part) { + FailureOr chunks = getDataChunksInPart(resultVMIType, part); + if (failed(chunks)) + return rewriter.notifyMatchFailure( + op, "group_broadcast failed to enumerate result chunks"); + for (int64_t chunk = 0; chunk < *chunks; + ++chunk, ++runningFlatIndex) { + if (runningFlatIndex != static_cast(flatIndex)) + continue; + int64_t groupsPerResultChunk = + lanesPerPart / resultLayout.getBlockElems(); + int64_t firstGroup = chunk * groupsPerResultChunk; + int64_t slots = sourceLayout.getSlots(); + if (slots <= 0) { + if (sourceParts.empty() || + groupCount % static_cast(sourceParts.size()) != 0) + return rewriter.notifyMatchFailure( + op, + "group_broadcast block-fragment source requires explicit " + "group_slots slots or derivable legacy slot count"); + slots = groupCount / sourceParts.size(); + } + sourceChunk = firstGroup / slots; + baseGroupSlot = firstGroup % slots; + found = true; + break; + } + } + if (!found) + return rewriter.notifyMatchFailure( + op, "group_broadcast result chunk index is out of range"); + } else if (deinterleavedSmallGroup) { + int64_t runningFlatIndex = 0; + bool found = false; + for (int64_t part = 0; part < resultLayoutFactor && !found; ++part) { + FailureOr chunks = getDataChunksInPart(resultVMIType, part); + if (failed(chunks)) + return rewriter.notifyMatchFailure( + op, "group_broadcast failed to enumerate result chunks"); + for (int64_t chunk = 0; chunk < *chunks; + ++chunk, ++runningFlatIndex) { + if (runningFlatIndex != static_cast(flatIndex)) + continue; + int64_t slots = sourceLayout.getSlots(); + if (slots <= 0) { + if (sourceParts.empty() || + groupCount % static_cast(sourceParts.size()) != 0) + return rewriter.notifyMatchFailure( + op, "group_broadcast deinterleaved 
small-group source " + "requires explicit group_slots slots or derivable " + "legacy slot count"); + slots = groupCount / sourceParts.size(); + } + FailureOr index = createMappedGroupSlotIndexVector( + op->getLoc(), resultVMIType, part, chunk, indexType, + *groupSize, slots, sourceChunk, rewriter); + if (failed(index)) + return rewriter.notifyMatchFailure( + op, + "failed to create group_broadcast mapped group-slot index " + "vector"); + mappedGroupSlotIndex = *index; + found = true; + break; + } + } + if (!found) + return rewriter.notifyMatchFailure( + op, "group_broadcast result chunk index is out of range"); + } else { + int64_t runningFlatIndex = 0; + bool found = false; + for (int64_t part = 0; part < resultLayoutFactor && !found; ++part) { + FailureOr chunks = getDataChunksInPart(resultVMIType, part); + if (failed(chunks)) + return rewriter.notifyMatchFailure( + op, "group_broadcast failed to enumerate result chunks"); + for (int64_t chunk = 0; chunk < *chunks; + ++chunk, ++runningFlatIndex) { + if (runningFlatIndex != static_cast(flatIndex)) + continue; + FailureOr firstLogical = + mapPhysicalLaneToLogical(resultVMIType, part, chunk, 0); + FailureOr lastLogical = mapPhysicalLaneToLogical( + resultVMIType, part, chunk, lanesPerPart - 1); + if (failed(firstLogical) || failed(lastLogical)) + return rewriter.notifyMatchFailure( + op, "group_broadcast failed to map result chunk lanes"); + int64_t firstGroup = *firstLogical / *groupSize; + int64_t lastGroup = *lastLogical / *groupSize; + if (firstGroup != lastGroup) + return rewriter.notifyMatchFailure( + op, "group_broadcast result chunk crosses logical groups"); + int64_t chunksPerGroup = *groupSize / lanesPerPart; + resolveLargeGroupSource(firstGroup, chunksPerGroup, sourceChunk, + baseGroupSlot); + found = true; + break; + } + } + if (!found) + return rewriter.notifyMatchFailure( + op, "group_broadcast result chunk index is out of range"); + } + } + if (*groupSize >= lanesPerPart) { + if (sourceChunk < 0 || + 
sourceChunk >= static_cast(sourceParts.size())) + return rewriter.notifyMatchFailure( + op, "group_broadcast source chunk is out of range"); + if (sourceLayout.getSlots() > 1) { + FailureOr groupSlotIndex = createGroupSlotIndexVector( + op->getLoc(), indexType, selectionGroupSize, baseGroupSlot, + rewriter); + if (failed(groupSlotIndex)) + return rewriter.notifyMatchFailure( + op, "failed to create group_broadcast group-slot index vector"); + results[flatIndex] = + rewriter + .create(op->getLoc(), resultType, + sourceParts[sourceChunk], *groupSlotIndex) + .getResult(); + } else { + results[flatIndex] = + rewriter + .create(op->getLoc(), resultType, + sourceParts[sourceChunk], *allMask, + rewriter.getStringAttr("LOWEST")) + .getResult(); + } + } else { + bool blockFragmentSmallGroup = resultLayout && + resultLayout.isDeinterleaved() && + resultLayout.getBlockElems() > 1; + bool deinterleavedSmallGroup = resultLayout && + resultLayout.isDeinterleaved() && + resultLayout.getBlockElems() == 1; + if (resultLayoutFactor != 1 && !blockFragmentSmallGroup && + !deinterleavedSmallGroup) + return rewriter.notifyMatchFailure( + op, "group_broadcast small-group deinterleaved result is not " + "supported"); + if (sourceChunk < 0 || + sourceChunk >= static_cast(sourceParts.size())) + return rewriter.notifyMatchFailure( + op, "group_broadcast source chunk is out of range"); + FailureOr groupSlotIndex = + mappedGroupSlotIndex + ? 
FailureOr(mappedGroupSlotIndex) + : createGroupSlotIndexVector(op->getLoc(), indexType, + selectionGroupSize, baseGroupSlot, + rewriter); + if (failed(groupSlotIndex)) + return rewriter.notifyMatchFailure( + op, "failed to create group_broadcast group-slot index vector"); + results[flatIndex] = + rewriter + .create(op->getLoc(), resultType, + sourceParts[sourceChunk], *groupSlotIndex) + .getResult(); + } + } + return success(); +} + +struct OneToNVMIGroupSlotLoadOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern< + VMIGroupSlotLoadOp>::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIGroupSlotLoadOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto resultVMIType = cast(op.getResult().getType()); + VMILayoutAttr layout = resultVMIType.getLayoutAttr(); + if (!layout || !layout.isGroupSlots() || layout.getSlots() <= 0) + return rewriter.notifyMatchFailure( + op, "group_slot_load requires explicit group_slots layout"); + + FailureOr source = getSingleValue( + op, adaptor.getSource(), + "group_slot_load source must convert to one value", rewriter); + FailureOr offset = getSingleValue( + op, adaptor.getOffset(), + "group_slot_load offset must convert to one value", rewriter); + FailureOr sourceGroupStride = getSingleValue( + op, adaptor.getSourceGroupStride(), + "group_slot_load source_group_stride must convert to one value", + rewriter); + if (failed(source) || failed(offset) || failed(sourceGroupStride)) + return failure(); + + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + int64_t numGroups = op.getNumGroupsAttr().getInt(); + + SmallVector results; + if (failed(lowerGroupSlotLoadParts(op, *source, *offset, *sourceGroupStride, + resultVMIType, resultTypes, numGroups, + rewriter, results))) + return failure(); + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMIMaskedLoadOpPattern + : OneToNOpConversionPattern { 
+ using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIMaskedLoadOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto resultVMIType = cast(op.getResult().getType()); + FailureOr source = getSingleValue( + op, adaptor.getSource(), "masked_load source must convert to one value", + rewriter); + FailureOr offset = getSingleValue( + op, adaptor.getOffset(), "masked_load offset must convert to one value", + rewriter); + if (failed(source) || failed(offset)) + return failure(); + + FailureOr lanesPerPart = verifyFullOrSafeReadVRegChunks( + op, resultVMIType, (*source).getType(), *offset, rewriter); + if (failed(lanesPerPart)) + return failure(); + + ValueRange maskParts = adaptor.getMask(); + ValueRange passthruParts = adaptor.getPassthru(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (maskParts.size() != passthruParts.size() || + passthruParts.size() != resultTypes.size()) + return rewriter.notifyMatchFailure(op, + "masked_load physical arity mismatch"); + + SmallVector results; + results.reserve(resultTypes.size()); + for (auto [index, maskPassthruAndType] : llvm::enumerate( + llvm::zip_equal(maskParts, passthruParts, resultTypes))) { + auto [mask, passthru, resultType] = maskPassthruAndType; + if (!isa(mask.getType()) || passthru.getType() != resultType || + !isa(resultType)) + return rewriter.notifyMatchFailure( + op, "masked_load physical part type mismatch"); + + Value chunkOffset = createChunkOffset(op.getLoc(), *offset, + index * *lanesPerPart, rewriter); + Value loaded = + rewriter + .create(op.getLoc(), resultType, + /*updated_base=*/Type{}, *source, chunkOffset, + /*dist=*/nullptr) + .getResult(); + results.push_back( + rewriter + .create(op.getLoc(), resultType, loaded, passthru, mask) + .getResult()); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMIGatherOpPattern : 
OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIGatherOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + FailureOr source = + getSingleValue(op, adaptor.getSource(), + "gather source must convert to one value", rewriter); + if (failed(source)) + return failure(); + + ValueRange indicesParts = adaptor.getIndices(); + ValueRange maskParts = adaptor.getMask(); + ValueRange passthruParts = adaptor.getPassthru(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (indicesParts.size() != maskParts.size() || + indicesParts.size() != passthruParts.size() || + indicesParts.size() != resultTypes.size()) + return rewriter.notifyMatchFailure(op, "gather physical arity mismatch"); + + SmallVector results; + results.reserve(resultTypes.size()); + for (auto [indices, mask, passthru, resultType] : + llvm::zip_equal(indicesParts, maskParts, passthruParts, resultTypes)) { + if (!isa(indices.getType()) || !isa(mask.getType()) || + passthru.getType() != resultType || !isa(resultType)) + return rewriter.notifyMatchFailure( + op, "gather physical part type mismatch"); + + unsigned resultBits = pto::getPTOStorageElemBitWidth( + cast(resultType).getElementType()); + Value gathered = resultBits == 16 + ? 
rewriter + .create(op.getLoc(), resultType, + *source, indices, mask) + .getResult() + : rewriter + .create(op.getLoc(), resultType, + *source, indices, mask) + .getResult(); + results.push_back( + rewriter + .create(op.getLoc(), resultType, gathered, passthru, mask) + .getResult()); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMIExpandLoadOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIExpandLoadOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto resultVMIType = cast(op.getResult().getType()); + FailureOr source = getSingleValue( + op, adaptor.getSource(), "expand_load source must convert to one value", + rewriter); + FailureOr offset = getSingleValue( + op, adaptor.getOffset(), "expand_load offset must convert to one value", + rewriter); + if (failed(source) || failed(offset)) + return failure(); + + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (isStaticAllActiveMask(op.getMask(), resultVMIType.getElementCount())) { + FailureOr lanesPerPart = verifyFullOrSafeReadVRegChunks( + op, resultVMIType, (*source).getType(), *offset, rewriter); + if (failed(lanesPerPart)) + return failure(); + + SmallVector results; + results.reserve(resultTypes.size()); + for (auto [index, resultType] : llvm::enumerate(resultTypes)) { + if (!isa(resultType)) + return rewriter.notifyMatchFailure(op, + "expand_load result must be vreg"); + Value chunkOffset = createChunkOffset(op.getLoc(), *offset, + index * *lanesPerPart, rewriter); + results.push_back(rewriter + .create(op.getLoc(), resultType, + /*updated_base=*/Type{}, *source, + chunkOffset, + /*dist=*/nullptr) + .getResult()); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } + + ValueRange maskParts = adaptor.getMask(); + ValueRange passthruParts = 
adaptor.getPassthru(); + if (resultTypes.size() != 1 || maskParts.size() != 1 || + passthruParts.size() != 1) + return rewriter.notifyMatchFailure( + op, "runtime expand_load supports only one physical chunk"); + + auto resultType = dyn_cast(resultTypes.front()); + auto maskType = dyn_cast(maskParts.front().getType()); + if (!resultType || !maskType || + passthruParts.front().getType() != resultType) + return rewriter.notifyMatchFailure( + op, "runtime expand_load requires physical result/passthru/mask"); + + auto baseType = dyn_cast((*source).getType()); + if (!baseType) + return rewriter.notifyMatchFailure(op, + "runtime expand_load requires ptr"); + Value gatherBase = rewriter + .create(op.getLoc(), (*source).getType(), + *source, *offset) + .getResult(); + auto indexType = + VRegType::get(rewriter.getContext(), resultType.getElementCount(), + rewriter.getI32Type()); + FailureOr indexSeedMask = + createAllTrueMaskForVReg(op.getLoc(), indexType, rewriter); + if (failed(indexSeedMask)) + return rewriter.notifyMatchFailure( + op, "failed to create runtime expand_load index seed mask"); + Value zero = rewriter.create(op.getLoc(), 0, 32); + Value carrier = + rewriter + .create(op.getLoc(), indexType, zero, *indexSeedMask, + /*position=*/nullptr) + .getResult(); + Value indices = + rewriter + .create(op.getLoc(), indexType, carrier, maskParts.front()) + .getResult(); + Value gathered = + rewriter + .create(op.getLoc(), resultType, gatherBase, indices, + maskParts.front()) + .getResult(); + Value result = rewriter + .create(op.getLoc(), resultType, gathered, + passthruParts.front(), maskParts.front()) + .getResult(); + rewriter.replaceOp(op, SmallVector{result}, + adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMIStoreOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIStoreOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + 
auto valueVMIType = cast(op.getValue().getType()); + FailureOr lanesPerPart = + getDataLanesPerPart(valueVMIType.getElementType()); + if (failed(lanesPerPart)) + return rewriter.notifyMatchFailure( + op, "store requires known physical lanes per part"); + bool fullPhysicalChunks = + succeeded(checkFullDataPhysicalChunks(valueVMIType, nullptr)); + FailureOr destination = + getSingleValue(op, adaptor.getDestination(), + "store destination must convert to one value", rewriter); + FailureOr offset = + getSingleValue(op, adaptor.getOffset(), + "store offset must convert to one value", rewriter); + if (failed(destination) || failed(offset)) + return failure(); + + ValueRange valueParts = adaptor.getValue(); + if (std::optional dist = + getDenseLaneStrideStoreDistToken(valueVMIType)) { + std::optional maskGranularity = + getDenseLaneStrideStoreMaskGranularity(valueVMIType); + if (!maskGranularity) + return rewriter.notifyMatchFailure( + op, "unsupported lane_stride store mask granularity"); + int64_t semanticOffset = 0; + for (auto [index, value] : llvm::enumerate(valueParts)) { + auto vregType = dyn_cast(value.getType()); + if (!vregType) + return rewriter.notifyMatchFailure(op, "store value must be vreg"); + FailureOr activeLanes = + getActiveDataLanesInPhysicalChunk(valueVMIType, index); + if (failed(activeLanes)) + return rewriter.notifyMatchFailure( + op, "failed to compute lane_stride store active lanes"); + if (*activeLanes == 0) + continue; + auto maskType = MaskType::get(rewriter.getContext(), *maskGranularity); + FailureOr mask = createPrefixMaskForActiveLanes( + op.getLoc(), maskType, *activeLanes, rewriter); + if (failed(mask)) + return rewriter.notifyMatchFailure( + op, "failed to create lane_stride store mask"); + Value chunkOffset = + createChunkOffset(op.getLoc(), *offset, semanticOffset, rewriter); + rewriter.create(op.getLoc(), + /*updated_base=*/Type{}, value, *destination, + chunkOffset, rewriter.getStringAttr(*dist), + *mask); + semanticOffset += 
*activeLanes; + } + rewriter.eraseOp(op); + return success(); + } + + VMILayoutSupport localSupports; + FailureOr storeSupport = + localSupports.getContiguousStoreSupport(valueVMIType); + if (succeeded(storeSupport) && + storeSupport->kind == + VMIContiguousStoreSupportKind::Deinterleaved2Vstsx2) { + std::optional dist = + getX2MemoryDistToken(valueVMIType.getElementType(), "INTLV"); + if (dist && !valueParts.empty() && valueParts.size() % 2 == 0) { + int64_t groups = valueParts.size() / 2; + for (int64_t group = 0; group < groups; ++group) { + Value low = valueParts[group]; + Value high = valueParts[groups + group]; + if (low.getType() != high.getType()) + return rewriter.notifyMatchFailure( + op, "vstsx2 requires matching low/high value types"); + auto vregType = dyn_cast(low.getType()); + if (!vregType) + return rewriter.notifyMatchFailure(op, "store value must be vreg"); + FailureOr mask = + createAllTrueMaskForVReg(op.getLoc(), vregType, rewriter); + if (failed(mask)) + return rewriter.notifyMatchFailure( + op, "unsupported element type for store mask"); + Value chunkOffset = createChunkOffset( + op.getLoc(), *offset, group * 2 * *lanesPerPart, rewriter); + rewriter.create(op.getLoc(), low, high, *destination, + chunkOffset, rewriter.getStringAttr(*dist), + *mask); + } + rewriter.eraseOp(op); + return success(); + } + } + + SmallVector contiguousTypes; + contiguousTypes.reserve(valueParts.size()); + for (Value value : valueParts) + contiguousTypes.push_back(value.getType()); + + FailureOr> storeParts = materializeDataLayoutConversion( + op, valueParts, contiguousTypes, valueVMIType.getLayoutAttr(), + VMILayoutAttr::getContiguous(rewriter.getContext()), rewriter); + if (failed(storeParts)) + return failure(); + + for (auto [index, value] : llvm::enumerate(*storeParts)) { + auto vregType = dyn_cast(value.getType()); + if (!vregType) + return rewriter.notifyMatchFailure(op, "store value must be vreg"); + if (!fullPhysicalChunks) { + FailureOr activeLanes = + 
getContiguousActiveDataLanes(valueVMIType, index); + if (failed(activeLanes)) + return rewriter.notifyMatchFailure( + op, "failed to compute store active lanes"); + if (*activeLanes == 0) + continue; + } + FailureOr mask = + fullPhysicalChunks + ? createAllTrueMaskForVReg(op.getLoc(), vregType, rewriter) + : createContiguousStoreMask(op.getLoc(), valueVMIType, index, + vregType, rewriter); + if (failed(mask)) + return rewriter.notifyMatchFailure( + op, "unsupported element type for store mask"); + Value chunkOffset = createChunkOffset(op.getLoc(), *offset, + index * *lanesPerPart, rewriter); + rewriter.create(op.getLoc(), + /*updated_base=*/Type{}, value, *destination, + chunkOffset, /*dist=*/nullptr, *mask); + } + + rewriter.eraseOp(op); + return success(); + } +}; + +struct OneToNVMIInterleaveStoreOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern< + VMIInterleaveStoreOp>::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIInterleaveStoreOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto lowVMIType = cast(op.getLow().getType()); + FailureOr lanesPerPart = + getDataLanesPerPart(lowVMIType.getElementType()); + if (failed(lanesPerPart)) + return rewriter.notifyMatchFailure( + op, "interleave_store requires known physical lanes per part"); + + std::optional dist = + getX2MemoryDistToken(lowVMIType.getElementType(), "INTLV"); + if (!dist) + return rewriter.notifyMatchFailure( + op, "interleave_store requires vstsx2 INTLV element support"); + + FailureOr destination = + getSingleValue(op, adaptor.getDestination(), + "interleave_store destination must convert to one value", + rewriter); + FailureOr offset = + getSingleValue(op, adaptor.getOffset(), + "interleave_store offset must convert to one value", + rewriter); + if (failed(destination) || failed(offset)) + return failure(); + + ValueRange lowParts = adaptor.getLow(); + ValueRange highParts = adaptor.getHigh(); + if (lowParts.size() != 
highParts.size()) + return rewriter.notifyMatchFailure( + op, "interleave_store requires matching low/high physical arity"); + + for (size_t index = 0, e = lowParts.size(); index < e; ++index) { + Value low = lowParts[index]; + Value high = highParts[index]; + if (low.getType() != high.getType()) + return rewriter.notifyMatchFailure( + op, "interleave_store requires matching low/high physical types"); + auto vregType = dyn_cast(low.getType()); + if (!vregType) + return rewriter.notifyMatchFailure( + op, "interleave_store value must be vreg"); + FailureOr mask = + createAllTrueMaskForVReg(op.getLoc(), vregType, rewriter); + if (failed(mask)) + return rewriter.notifyMatchFailure( + op, "unsupported element type for interleave_store mask"); + Value chunkOffset = createChunkOffset( + op.getLoc(), *offset, static_cast(index) * 2 * *lanesPerPart, + rewriter); + rewriter.create(op.getLoc(), low, high, *destination, + chunkOffset, rewriter.getStringAttr(*dist), + *mask); + } + + rewriter.eraseOp(op); + return success(); + } +}; + +struct OneToNVMIGroupStoreOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIGroupStoreOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto valueVMIType = cast(op.getValue().getType()); + VMILayoutAttr layout = valueVMIType.getLayoutAttr(); + + FailureOr destination = getSingleValue( + op, adaptor.getDestination(), + "group_store destination must convert to one value", rewriter); + FailureOr offset = getSingleValue( + op, adaptor.getOffset(), "group_store offset must convert to one value", + rewriter); + FailureOr rowStride = getSingleValue( + op, adaptor.getRowStride(), + "group_store row_stride must convert to one value", rewriter); + if (failed(destination) || failed(offset) || failed(rowStride)) + return failure(); + + if (layout && layout.isGroupSlots() && layout.getSlots() == 1 && + layout.getNumGroups() == 
op.getNumGroupsAttr().getInt()) { + ValueRange valueParts = adaptor.getValue(); + if (static_cast(valueParts.size()) != layout.getNumGroups()) + return rewriter.notifyMatchFailure( + op, "slots=1 group_store arity mismatch"); + unsigned elementBits = + pto::getPTOStorageElemBitWidth(valueVMIType.getElementType()); + if (elementBits == 0 || 256 % elementBits != 0) + return rewriter.notifyMatchFailure( + op, "slots=1 group_store requires supported element width"); + std::optional constantRowStride = + getConstantIndexValue(op.getRowStride()); + FailureOr lanesPerPart = + getDataLanesPerPart(valueVMIType.getElementType()); + int64_t alignedStoreElems = 256 / elementBits; + if (constantRowStride && *constantRowStride == 1 && + succeeded(lanesPerPart) && layout.getNumGroups() <= *lanesPerPart && + isKnownIndexMultipleOf(op.getOffset(), alignedStoreElems)) { + auto firstType = dyn_cast(valueParts.front().getType()); + if (!firstType) + return rewriter.notifyMatchFailure(op, + "group_store value must be vreg"); + FailureOr maskType = + getMaskTypeForVReg(firstType, rewriter.getContext()); + FailureOr allMask = + createAllTrueMaskForVReg(op.getLoc(), firstType, rewriter); + if (failed(maskType) || failed(allMask)) + return rewriter.notifyMatchFailure( + op, "unsupported element type for packed group_store mask"); + + Value packed = + rewriter + .create(op.getLoc(), firstType, valueParts.front(), + *allMask, rewriter.getStringAttr("LOWEST")) + .getResult(); + for (int64_t group = 1; group < layout.getNumGroups(); ++group) { + auto vregType = dyn_cast(valueParts[group].getType()); + if (!vregType || vregType != firstType) + return rewriter.notifyMatchFailure( + op, "packed group_store requires uniform vreg parts"); + Value splat = + rewriter + .create(op.getLoc(), firstType, valueParts[group], + *allMask, rewriter.getStringAttr("LOWEST")) + .getResult(); + FailureOr laneMask = createLaneRangeMask( + op.getLoc(), *maskType, group, group + 1, rewriter); + if (failed(laneMask)) 
+ return rewriter.notifyMatchFailure( + op, "failed to create packed group_store lane mask"); + packed = rewriter + .create(op.getLoc(), firstType, splat, packed, + *laneMask) + .getResult(); + } + + FailureOr storeMask = createPrefixMaskForActiveLanes( + op.getLoc(), *maskType, layout.getNumGroups(), rewriter); + if (failed(storeMask)) + return rewriter.notifyMatchFailure( + op, "failed to create packed group_store store mask"); + rewriter.create(op.getLoc(), + /*updated_base=*/Type{}, packed, *destination, + *offset, /*dist=*/nullptr, *storeMask); + rewriter.eraseOp(op); + return success(); + } + if (constantRowStride && *constantRowStride <= 0) + return rewriter.notifyMatchFailure( + op, "slots=1 group_store requires positive row_stride when " + "row_stride is constant"); + std::optional pointDist = + getPointStoreDistToken(valueVMIType.getElementType()); + if (!pointDist) + return rewriter.notifyMatchFailure( + op, "slots=1 group_store requires 1PT_B8/B16/B32 store support"); + + for (auto [group, value] : llvm::enumerate(valueParts)) { + auto vregType = dyn_cast(value.getType()); + if (!vregType) + return rewriter.notifyMatchFailure(op, + "group_store value must be vreg"); + FailureOr maskType = + getMaskTypeForVReg(vregType, rewriter.getContext()); + if (failed(maskType)) + return rewriter.notifyMatchFailure( + op, "unsupported element type for group_store mask"); + FailureOr mask = + createPrefixMask(op.getLoc(), *maskType, "PAT_VL1", rewriter); + if (failed(mask)) + return rewriter.notifyMatchFailure( + op, "failed to create slots=1 group_store mask"); + Value groupOffset = + createGroupChunkOffset(op.getLoc(), *offset, *rowStride, group, + /*chunkLaneOffset=*/0, rewriter); + rewriter.create(op.getLoc(), + /*updated_base=*/Type{}, value, *destination, + groupOffset, rewriter.getStringAttr(*pointDist), + *mask); + } + + rewriter.eraseOp(op); + return success(); + } + + if (layout && layout.isGroupSlots() && layout.getSlots() == 8 && + layout.getNumGroups() 
== op.getNumGroupsAttr().getInt()) { + int64_t numGroups = layout.getNumGroups(); + std::optional constantRowStride = + getConstantIndexValue(op.getRowStride()); + if (!constantRowStride || *constantRowStride != 1) + return rewriter.notifyMatchFailure( + op, "slots=8 group_store requires constant unit row_stride"); + + ValueRange valueParts = adaptor.getValue(); + if (static_cast(valueParts.size()) != + ceilDivNonNegative(numGroups, 8)) + return rewriter.notifyMatchFailure( + op, "slots=8 group_store arity mismatch"); + + if (!valueParts.empty()) { + auto firstVRegType = dyn_cast(valueParts.front().getType()); + if (!firstVRegType) + return rewriter.notifyMatchFailure(op, + "group_store value must be vreg"); + bool packedByteStore = isPackedByteGroupStore( + op.getDestination().getType(), firstVRegType); + if (packedByteStore) { + bool laneStridedPackedByteStore = layout.hasLaneStride(); + for (Value value : valueParts) { + auto vregType = dyn_cast(value.getType()); + if (!vregType || vregType != firstVRegType) + return rewriter.notifyMatchFailure( + op, "packed slots=8 group_store requires uniform vreg parts"); + } + + FailureOr maskType = + getMaskTypeForVReg(firstVRegType, rewriter.getContext()); + if (failed(maskType)) + return rewriter.notifyMatchFailure( + op, "unsupported element type for packed group_store mask"); + if (!laneStridedPackedByteStore && numGroups == 8 && + valueParts.size() == 1 && + isKnownIndexMultipleOf(*offset, 32)) { + MLIRContext *ctx = rewriter.getContext(); + auto ui16 = IntegerType::get( + ctx, 16, IntegerType::SignednessSemantics::Unsigned); + auto ui8 = IntegerType::get( + ctx, 8, IntegerType::SignednessSemantics::Unsigned); + auto packed16Type = VRegType::get(ctx, 128, ui16); + auto packed8Type = VRegType::get(ctx, 256, ui8); + Value packed16 = + rewriter + .create(op.getLoc(), packed16Type, + valueParts.front(), + rewriter.getStringAttr("LOWER")) + .getResult(); + Value packed8 = + rewriter + .create(op.getLoc(), packed8Type, 
packed16, + rewriter.getStringAttr("LOWER")) + .getResult(); + FailureOr packedMaskType = + getMaskTypeForVReg(packed8Type, ctx); + if (failed(packedMaskType)) + return rewriter.notifyMatchFailure( + op, "failed to create packed byte group_store mask type"); + FailureOr storeMask = createPrefixMaskForActiveLanes( + op.getLoc(), *packedMaskType, numGroups, rewriter); + if (failed(storeMask)) + return rewriter.notifyMatchFailure( + op, "failed to create packed byte group_store mask"); + rewriter.create( + op.getLoc(), /*updated_base=*/Type{}, packed8, *destination, + *offset, rewriter.getStringAttr("NORM_B8"), *storeMask); + rewriter.eraseOp(op); + return success(); + } + + auto indexElementType = IntegerType::get( + rewriter.getContext(), + pto::getPTOStorageElemBitWidth(firstVRegType.getElementType())); + auto indexType = + VRegType::get(rewriter.getContext(), + firstVRegType.getElementCount(), indexElementType); + FailureOr slotIndex = createGroupSlotIndexVector( + op.getLoc(), indexType, /*groupSize=*/8, /*baseGroupSlot=*/0, + rewriter); + FailureOr allMask = + createAllTrueMaskForVReg(op.getLoc(), firstVRegType, rewriter); + if (failed(slotIndex) || failed(allMask)) + return rewriter.notifyMatchFailure( + op, "failed to create packed group_store lane selector"); + + for (int64_t blockStart = 0; blockStart < numGroups; + blockStart += 32) { + FailureOr zero = + createZeroVector(op.getLoc(), firstVRegType, rewriter); + if (failed(zero)) + return rewriter.notifyMatchFailure( + op, "failed to create packed group_store accumulator"); + Value merged = *zero; + for (int64_t localPart = 0; localPart < 4; ++localPart) { + int64_t partIndex = blockStart / 8 + localPart; + if (partIndex >= static_cast(valueParts.size())) + break; + int64_t remainingGroups = numGroups - partIndex * 8; + int64_t activeGroups = std::min(8, remainingGroups); + if (activeGroups <= 0) + break; + Value selected = + rewriter + .create(op.getLoc(), firstVRegType, + valueParts[partIndex], 
*slotIndex) + .getResult(); + FailureOr laneMask = + createLaneRangeMask(op.getLoc(), *maskType, localPart * 8, + localPart * 8 + activeGroups, rewriter); + if (failed(laneMask)) + return rewriter.notifyMatchFailure( + op, "failed to create packed group_store lane mask"); + merged = rewriter + .create(op.getLoc(), firstVRegType, selected, + merged, *laneMask) + .getResult(); + } + + int64_t activeGroups = + std::min(32, numGroups - blockStart); + FailureOr storeMask = createPrefixMaskForActiveLanes( + op.getLoc(), *maskType, activeGroups, rewriter); + if (failed(storeMask)) + return rewriter.notifyMatchFailure( + op, "failed to create packed group_store store mask"); + Value groupOffset = createGroupChunkOffset( + op.getLoc(), *offset, *rowStride, blockStart / 4, + /*chunkLaneOffset=*/0, rewriter); + rewriter.create( + op.getLoc(), /*updated_base=*/Type{}, merged, *destination, + groupOffset, rewriter.getStringAttr("PK4_B32"), *storeMask); + } + + rewriter.eraseOp(op); + return success(); + } + } + + for (auto [slotBlock, value] : llvm::enumerate(valueParts)) { + auto vregType = dyn_cast(value.getType()); + if (!vregType) + return rewriter.notifyMatchFailure(op, + "group_store value must be vreg"); + FailureOr maskType = + getMaskTypeForVReg(vregType, rewriter.getContext()); + if (failed(maskType)) + return rewriter.notifyMatchFailure( + op, "unsupported element type for group_store mask"); + int64_t activeGroups = std::min(8, numGroups - slotBlock * 8); + FailureOr mask = createPrefixMaskForActiveLanes( + op.getLoc(), *maskType, activeGroups, rewriter); + if (failed(mask)) + return rewriter.notifyMatchFailure( + op, "failed to create slots=8 group_store mask"); + Value groupOffset = createGroupChunkOffset( + op.getLoc(), *offset, *rowStride, slotBlock * 8, + /*chunkLaneOffset=*/0, rewriter); + rewriter.create(op.getLoc(), + /*updated_base=*/Type{}, value, *destination, + groupOffset, /*dist=*/nullptr, *mask); + } + + rewriter.eraseOp(op); + return success(); + } + 
+ int64_t lanesPerPart = 0; + int64_t groupCount = 0; + int64_t chunksPerGroup = 0; + FailureOr groupSize = + getGroupSizeFromNumGroups(valueVMIType, op.getNumGroupsAttr().getInt()); + if (failed(groupSize)) + return rewriter.notifyMatchFailure( + op, "group_store requires num_groups to evenly divide lane count"); + if (failed(checkContiguousFullGroupChunks(op, valueVMIType, *groupSize, + &lanesPerPart, &groupCount, + &chunksPerGroup, rewriter))) + return failure(); + + ValueRange valueParts = adaptor.getValue(); + if (static_cast(valueParts.size()) != groupCount * chunksPerGroup) + return rewriter.notifyMatchFailure(op, "group_store arity mismatch"); + + for (auto [index, value] : llvm::enumerate(valueParts)) { + auto vregType = dyn_cast(value.getType()); + if (!vregType) + return rewriter.notifyMatchFailure(op, + "group_store value must be vreg"); + FailureOr mask = + createAllTrueMaskForVReg(op.getLoc(), vregType, rewriter); + if (failed(mask)) + return rewriter.notifyMatchFailure( + op, "unsupported element type for group_store mask"); + int64_t group = index / chunksPerGroup; + int64_t chunkInGroup = index % chunksPerGroup; + Value chunkOffset = + createGroupChunkOffset(op.getLoc(), *offset, *rowStride, group, + chunkInGroup * lanesPerPart, rewriter); + rewriter.create(op.getLoc(), + /*updated_base=*/Type{}, value, *destination, + chunkOffset, /*dist=*/nullptr, *mask); + } + + rewriter.eraseOp(op); + return success(); + } +}; + +struct OneToNVMIMaskedStoreOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIMaskedStoreOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto valueVMIType = cast(op.getValue().getType()); + FailureOr lanesPerPart = + getDataLanesPerPart(valueVMIType.getElementType()); + if (failed(lanesPerPart)) + return rewriter.notifyMatchFailure( + op, "masked_store requires known physical lanes per part"); + + FailureOr 
destination = getSingleValue( + op, adaptor.getDestination(), + "masked_store destination must convert to one value", rewriter); + FailureOr offset = getSingleValue( + op, adaptor.getOffset(), + "masked_store offset must convert to one value", rewriter); + if (failed(destination) || failed(offset)) + return failure(); + + ValueRange valueParts = adaptor.getValue(); + ValueRange maskParts = adaptor.getMask(); + if (valueParts.size() != maskParts.size()) + return rewriter.notifyMatchFailure( + op, "masked_store value/mask physical arity mismatch"); + + auto maskVMIType = cast(op.getMask().getType()); + if (std::optional dist = + getDenseLaneStrideStoreDistToken(valueVMIType)) { + std::optional maskGranularity = + getDenseLaneStrideMaskedStoreMaskGranularity(valueVMIType); + VMILayoutAttr valueLayout = valueVMIType.getLayoutAttr(); + VMILayoutAttr maskLayout = maskVMIType.getLayoutAttr(); + if (maskGranularity && valueLayout && maskLayout && + valueLayout == maskLayout) { + int64_t semanticOffset = 0; + for (auto [index, valueAndMask] : + llvm::enumerate(llvm::zip_equal(valueParts, maskParts))) { + auto [value, mask] = valueAndMask; + auto vregType = dyn_cast(value.getType()); + if (!vregType || !isa(mask.getType())) + return rewriter.notifyMatchFailure( + op, "lane_stride masked_store parts must be vreg/mask"); + FailureOr activeLanes = + getActiveDataLanesInPhysicalChunk(valueVMIType, index); + if (failed(activeLanes)) + return rewriter.notifyMatchFailure( + op, "failed to compute lane_stride masked_store active lanes"); + if (*activeLanes == 0) + continue; + FailureOr storeMask = createDenseLaneStrideStorePredicate( + op.getLoc(), valueVMIType, index, mask, *maskGranularity, + rewriter); + if (failed(storeMask)) + return rewriter.notifyMatchFailure( + op, "failed to compact lane_stride masked_store predicate"); + Value chunkOffset = + createChunkOffset(op.getLoc(), *offset, semanticOffset, rewriter); + rewriter.create(op.getLoc(), + /*updated_base=*/Type{}, value, 
*destination, + chunkOffset, rewriter.getStringAttr(*dist), + *storeMask); + semanticOffset += *activeLanes; + } + + rewriter.eraseOp(op); + return success(); + } + } + + SmallVector contiguousValueTypes; + contiguousValueTypes.reserve(valueParts.size()); + for (Value value : valueParts) + contiguousValueTypes.push_back(value.getType()); + FailureOr> storeParts = materializeDataLayoutConversion( + op, valueParts, contiguousValueTypes, valueVMIType.getLayoutAttr(), + VMILayoutAttr::getContiguous(rewriter.getContext()), rewriter); + if (failed(storeParts)) + return failure(); + + SmallVector contiguousMaskTypes; + contiguousMaskTypes.reserve(maskParts.size()); + for (Value mask : maskParts) + contiguousMaskTypes.push_back(mask.getType()); + FailureOr> storeMasks = materializeMaskLayoutConversion( + op, maskParts, contiguousMaskTypes, maskVMIType.getLayoutAttr(), + VMILayoutAttr::getContiguous(rewriter.getContext()), rewriter); + if (failed(storeMasks)) + return failure(); + + if (storeParts->size() != storeMasks->size()) + return rewriter.notifyMatchFailure( + op, "masked_store converted value/mask arity mismatch"); + + for (auto [index, valueAndMask] : + llvm::enumerate(llvm::zip_equal(*storeParts, *storeMasks))) { + auto [value, mask] = valueAndMask; + auto vregType = dyn_cast(value.getType()); + if (!vregType || !isa(mask.getType())) + return rewriter.notifyMatchFailure( + op, "masked_store converted parts must be vreg/mask"); + FailureOr activeLanes = + getContiguousActiveDataLanes(valueVMIType, index); + if (failed(activeLanes)) + return rewriter.notifyMatchFailure( + op, "failed to compute masked_store active lanes"); + if (*activeLanes == 0) + continue; + FailureOr storeMask = createMaskedStorePredicate( + op.getLoc(), valueVMIType, index, mask, vregType, rewriter); + if (failed(storeMask)) + return rewriter.notifyMatchFailure( + op, "failed to materialize masked_store predicate"); + Value chunkOffset = createChunkOffset(op.getLoc(), *offset, + index * 
*lanesPerPart, rewriter); + rewriter.create(op.getLoc(), + /*updated_base=*/Type{}, value, *destination, + chunkOffset, /*dist=*/nullptr, *storeMask); + } + + rewriter.eraseOp(op); + return success(); + } +}; + +struct OneToNVMIGroupBroadcastLoadOpPattern + : OneToNOpConversionPattern { + OneToNVMIGroupBroadcastLoadOpPattern( + TypeConverter &typeConverter, MLIRContext *context, + const VMITargetCapabilityRegistry &capabilities) + : OneToNOpConversionPattern(typeConverter, + context), + capabilities(capabilities) {} + + LogicalResult + matchAndRewrite(VMIGroupBroadcastLoadOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto resultVMIType = cast(op.getResult().getType()); + int64_t numGroups = op.getNumGroupsAttr().getInt(); + FailureOr source = + getSingleValue(op, adaptor.getSource(), + "group_broadcast_load source must convert to one value", + rewriter); + FailureOr offset = + getSingleValue(op, adaptor.getOffset(), + "group_broadcast_load offset must convert to one value", + rewriter); + FailureOr sourceGroupStride = getSingleValue( + op, adaptor.getSourceGroupStride(), + "group_broadcast_load source_group_stride must convert to one value", + rewriter); + if (failed(source) || failed(offset) || failed(sourceGroupStride)) + return failure(); + + VMILayoutSupport supports; + std::string supportReason; + FailureOr support = + supports.getGroupBroadcastLoadSupport(capabilities, op, &supportReason); + if (failed(support)) + return rewriter.notifyMatchFailure( + op, Twine("group_broadcast_load has no registered support: ") + + supportReason); + + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (support->kind == + VMIGroupBroadcastLoadSupportKind::SlotLoadThenBroadcast) { + std::optional stride = + getConstantIndexValue(op.getSourceGroupStride()); + int64_t slots = (stride && *stride == 1) ? 
8 : 1; + auto sourceVMIType = VMIVRegType::get( + rewriter.getContext(), numGroups, resultVMIType.getElementType(), + VMILayoutAttr::getGroupSlots(rewriter.getContext(), numGroups, + slots)); + + FailureOr sourceArity = getVMIPhysicalArity(sourceVMIType); + FailureOr sourceElementType = + getVMIVRegPhysicalElementType(sourceVMIType); + if (failed(sourceArity) || failed(sourceElementType)) + return rewriter.notifyMatchFailure( + op, "group_broadcast_load fallback cannot derive physical types"); + + SmallVector sourceTypes; + sourceTypes.reserve(*sourceArity); + FailureOr sourceLanesPerPart = + getDataLanesPerPart(*sourceElementType); + if (failed(sourceLanesPerPart)) + return rewriter.notifyMatchFailure( + op, "group_broadcast_load fallback cannot derive source lanes"); + for (int64_t i = 0; i < *sourceArity; ++i) + sourceTypes.push_back(VRegType::get(rewriter.getContext(), + *sourceLanesPerPart, + *sourceElementType)); + + SmallVector sourceParts; + if (failed(lowerGroupSlotLoadParts( + op, *source, *offset, *sourceGroupStride, sourceVMIType, + sourceTypes, numGroups, rewriter, sourceParts))) + return failure(); + + SmallVector results; + if (failed(lowerGroupBroadcastParts(op, sourceParts, sourceVMIType, + resultVMIType, resultTypes, + numGroups, rewriter, results))) + return failure(); + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } + + VMILayoutAttr layout = resultVMIType.getLayoutAttr(); + bool contiguousPacketLayout = layout && layout.isContiguous(); + bool splitPacketLayout = layout && layout.isDeinterleaved() && + layout.getFactor() == 2 && + layout.getBlockElems() == 1; + if (!contiguousPacketLayout && !splitPacketLayout) + return rewriter.notifyMatchFailure( + op, "group_broadcast_load E2B lowering requires " + "contiguous result layout for direct group size or " + "deinterleaved=2, block_elems=1 result layout for split " + "group size"); + + unsigned elementBits = + 
pto::getPTOStorageElemBitWidth(resultVMIType.getElementType()); + if (elementBits != 16 && elementBits != 32) + return rewriter.notifyMatchFailure( + op, "group_broadcast_load E2B lowering requires b16 or b32 " + "element type"); + int64_t directGroupSize = 256 / elementBits; + StringRef e2bDist = elementBits == 16 ? "E2B_B16" : "E2B_B32"; + + if (numGroups <= 0 || resultVMIType.getElementCount() % numGroups != 0) + return rewriter.notifyMatchFailure( + op, "group_broadcast_load requires valid num_groups"); + int64_t groupSize = resultVMIType.getElementCount() / numGroups; + if (numGroups % 8 != 0) + return rewriter.notifyMatchFailure( + op, "group_broadcast_load E2B lowering requires num_groups " + "multiple of 8"); + if (contiguousPacketLayout && groupSize != directGroupSize) + return rewriter.notifyMatchFailure( + op, "group_broadcast_load E2B contiguous lowering requires " + "element-width direct group size"); + if (splitPacketLayout && groupSize != 2 * directGroupSize) + return rewriter.notifyMatchFailure( + op, "group_broadcast_load E2B deinterleaved=2 lowering requires " + "element-width split group size"); + + std::optional stride = + getConstantIndexValue(op.getSourceGroupStride()); + if (!stride || *stride != 1) + return rewriter.notifyMatchFailure( + op, "group_broadcast_load E2B lowering requires constant unit " + "source_group_stride"); + + if (!isa((*source).getType())) + return rewriter.notifyMatchFailure( + op, "group_broadcast_load E2B lowering requires !pto.ptr source"); + + FailureOr chunksPerPart = getDataChunksInPart(resultVMIType, 0); + if (failed(chunksPerPart) || *chunksPerPart <= 0) + return rewriter.notifyMatchFailure( + op, "group_broadcast_load requires known chunks per part"); + int64_t factor = layout.getFactor(); + for (int64_t part = 1; part < factor; ++part) { + FailureOr currentChunks = + getDataChunksInPart(resultVMIType, part); + if (failed(currentChunks) || *currentChunks != *chunksPerPart) + return rewriter.notifyMatchFailure( 
+ op, "group_broadcast_load requires uniform chunks per part"); + } + if (static_cast(resultTypes.size()) != + factor * *chunksPerPart) + return rewriter.notifyMatchFailure( + op, "group_broadcast_load physical arity mismatch"); + if (*chunksPerPart != numGroups / 8) + return rewriter.notifyMatchFailure( + op, "group_broadcast_load expected one E2B packet per 8 groups in " + "each part"); + + SmallVector packets; + packets.reserve(*chunksPerPart); + for (int64_t chunk = 0; chunk < *chunksPerPart; ++chunk) { + Type packetType = resultTypes[chunk]; + auto vregType = dyn_cast(packetType); + if (!vregType) + return rewriter.notifyMatchFailure( + op, "group_broadcast_load result must be vreg"); + Value packetOffset = + createChunkOffset(op.getLoc(), *offset, chunk * 8, rewriter); + packets.push_back( + rewriter + .create(op.getLoc(), packetType, + /*updated_base=*/Type{}, *source, packetOffset, + rewriter.getStringAttr(e2bDist)) + .getResult()); + } + + SmallVector results; + results.reserve(resultTypes.size()); + for (int64_t part = 0; part < factor; ++part) { + for (int64_t chunk = 0; chunk < *chunksPerPart; ++chunk) { + int64_t flatIndex = part * *chunksPerPart + chunk; + if (resultTypes[flatIndex] != resultTypes[chunk]) + return rewriter.notifyMatchFailure( + op, "group_broadcast_load E2B reused packet type mismatch"); + results.push_back(packets[chunk]); + } + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } + +private: + const VMITargetCapabilityRegistry &capabilities; +}; + +struct OneToNVMIStrideLoadOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIStrideLoadOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + FailureOr source = getSingleValue( + op, adaptor.getSource(), "stride_load source must convert to one value", + rewriter); + FailureOr offset = getSingleValue( + op, adaptor.getOffset(), 
"stride_load offset must convert to one value", + rewriter); + FailureOr blockStride = getSingleValue( + op, adaptor.getBlockStride(), + "stride_load block_stride must convert to one value", rewriter); + FailureOr repeatStride = getSingleValue( + op, adaptor.getRepeatStride(), + "stride_load repeat_stride must convert to one value", rewriter); + if (failed(source) || failed(offset) || failed(blockStride) || + failed(repeatStride)) + return failure(); + + ValueRange maskParts = adaptor.getMask(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (resultTypes.size() != 1 || maskParts.size() != 1) + return rewriter.notifyMatchFailure( + op, "stride_load supports one physical result/mask chunk"); + auto resultType = dyn_cast(resultTypes.front()); + if (!resultType || !isa(maskParts.front().getType())) + return rewriter.notifyMatchFailure( + op, "stride_load requires physical vreg/mask parts"); + + Value base = rewriter + .create(op.getLoc(), (*source).getType(), + *source, *offset) + .getResult(); + Value loaded = + rewriter + .create(op.getLoc(), resultType, base, *blockStride, + *repeatStride, maskParts.front()) + .getResult(); + rewriter.replaceOp(op, SmallVector{loaded}, + adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMIStrideStoreOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIStrideStoreOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + FailureOr destination = getSingleValue( + op, adaptor.getDestination(), + "stride_store destination must convert to one value", rewriter); + FailureOr offset = getSingleValue( + op, adaptor.getOffset(), + "stride_store offset must convert to one value", rewriter); + FailureOr blockStride = getSingleValue( + op, adaptor.getBlockStride(), + "stride_store block_stride must convert to one value", rewriter); + FailureOr repeatStride = getSingleValue( + op, 
adaptor.getRepeatStride(), + "stride_store repeat_stride must convert to one value", rewriter); + if (failed(destination) || failed(offset) || failed(blockStride) || + failed(repeatStride)) + return failure(); + + /* stride_store handles exactly one physical value chunk and one mask chunk; any other split is reported as a match failure. */ ValueRange valueParts = adaptor.getValue(); + ValueRange maskParts = adaptor.getMask(); + if (valueParts.size() != 1 || maskParts.size() != 1) + return rewriter.notifyMatchFailure( + op, "stride_store supports one physical value/mask chunk"); + if (!isa(valueParts.front().getType()) || + !isa(maskParts.front().getType())) + return rewriter.notifyMatchFailure( + op, "stride_store requires physical vreg/mask parts"); + + /* Derive the store base from the converted destination and offset, then issue the store with both strides and the mask. The source op has no replacement values, so it is erased. NOTE(review): the concrete op types passed to rewriter.create are not visible in this text. */ Value base = rewriter + .create(op.getLoc(), (*destination).getType(), + *destination, *offset) + .getResult(); + rewriter.create(op.getLoc(), base.getType(), valueParts.front(), + base, *blockStride, *repeatStride, + maskParts.front()); + rewriter.eraseOp(op); + return success(); + } +}; + +/* One-to-N conversion of VMIScatterOp. The destination must convert to a single value; value, indices and mask must convert to the same number of physical parts. One op is created per (value, indices, mask) triple against that destination, and the source op is then erased. */ struct OneToNVMIScatterOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIScatterOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + FailureOr destination = getSingleValue( + op, adaptor.getDestination(), + "scatter destination must convert to one value", rewriter); + if (failed(destination)) + return failure(); + + ValueRange valueParts = adaptor.getValue(); + ValueRange indicesParts = adaptor.getIndices(); + ValueRange maskParts = adaptor.getMask(); + if (valueParts.size() != indicesParts.size() || + valueParts.size() != maskParts.size()) + return rewriter.notifyMatchFailure(op, "scatter physical arity mismatch"); + + for (auto [value, indices, mask] : + llvm::zip_equal(valueParts, indicesParts, maskParts)) { + if (!isa(value.getType()) || + !isa(indices.getType()) || !isa(mask.getType())) + return rewriter.notifyMatchFailure( + op, "scatter physical part type mismatch"); + rewriter.create(op.getLoc(), value, *destination, indices, + 
mask); + } + + rewriter.eraseOp(op); + return success(); + } +}; + + +/* Generic one-to-N conversion for element-wise binary ops. lhs, rhs and the converted result must split into the same number of physical parts, and every lhs/rhs part must have exactly the result part type. Each part is combined under an all-true mask obtained from createAllTrueMaskForVReg. */ template +struct OneToNVMIBinaryOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult matchAndRewrite( + SourceOp op, + typename OneToNOpConversionPattern::OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange lhsParts = adaptor.getLhs(); + ValueRange rhsParts = adaptor.getRhs(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (lhsParts.size() != rhsParts.size() || + lhsParts.size() != resultTypes.size()) + return rewriter.notifyMatchFailure(op, "physical binary arity mismatch"); + + SmallVector results; + results.reserve(resultTypes.size()); + for (auto [lhs, rhs, resultType] : + llvm::zip_equal(lhsParts, rhsParts, resultTypes)) { + auto vregType = dyn_cast(resultType); + if (!vregType || lhs.getType() != resultType || + rhs.getType() != resultType) + return rewriter.notifyMatchFailure( + op, "physical binary part type mismatch"); + FailureOr mask = + createAllTrueMaskForVReg(op.getLoc(), vregType, rewriter); + if (failed(mask)) + return rewriter.notifyMatchFailure( + op, "unsupported element type for all-true binary mask"); + results.push_back( + rewriter.create(op.getLoc(), resultType, lhs, rhs, *mask) + .getResult()); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + +/* One-to-N conversion of VMIFmaOp. lhs, rhs, acc and the converted result must all split into the same number of physical parts; the per-part work follows the arity check. */ struct OneToNVMIFmaOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIFmaOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange lhsParts = adaptor.getLhs(); + ValueRange rhsParts = adaptor.getRhs(); + ValueRange accParts = adaptor.getAcc(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (lhsParts.size() != rhsParts.size() || + lhsParts.size() != accParts.size() || + lhsParts.size() != 
resultTypes.size()) + return rewriter.notifyMatchFailure(op, "fma physical arity mismatch"); + + SmallVector results; + results.reserve(resultTypes.size()); + for (auto [lhs, rhs, acc, resultType] : + llvm::zip_equal(lhsParts, rhsParts, accParts, resultTypes)) { + auto vregType = dyn_cast(resultType); + if (!vregType || lhs.getType() != resultType || + rhs.getType() != resultType || acc.getType() != resultType) + return rewriter.notifyMatchFailure( + op, "fma requires matching physical vreg parts"); + FailureOr mask = + createAllTrueMaskForVReg(op.getLoc(), vregType, rewriter); + if (failed(mask)) + return rewriter.notifyMatchFailure(op, + "unsupported element type for fma"); + /* The created op receives its operands as (acc, lhs, rhs), accumulator first, followed by the all-true mask. */ results.push_back( + rewriter + .create(op.getLoc(), resultType, acc, lhs, rhs, *mask) + .getResult()); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + +/* Generic one-to-N conversion for element-wise unary ops. The source and the converted result must split into the same number of physical parts with identical types; each part is processed under an all-true mask obtained from createAllTrueMaskForVReg. */ template +struct OneToNVMIUnaryOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult matchAndRewrite( + SourceOp op, + typename OneToNOpConversionPattern::OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange sourceParts = adaptor.getSource(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (sourceParts.size() != resultTypes.size()) + return rewriter.notifyMatchFailure(op, "physical unary arity mismatch"); + + SmallVector results; + results.reserve(resultTypes.size()); + for (auto [source, resultType] : + llvm::zip_equal(sourceParts, resultTypes)) { + auto vregType = dyn_cast(resultType); + if (!vregType || source.getType() != resultType) + return rewriter.notifyMatchFailure(op, + "physical unary part type mismatch"); + FailureOr mask = + createAllTrueMaskForVReg(op.getLoc(), vregType, rewriter); + if (failed(mask)) + return rewriter.notifyMatchFailure( + op, "unsupported element type for all-true unary mask"); + results.push_back( + rewriter.create(op.getLoc(), 
resultType, source, *mask) + .getResult()); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + +/* Generic one-to-N conversion for binary ops on masks. lhs, rhs and the converted result must split into the same number of physical mask parts of identical type. Each created op also takes an all-true seed mask built by createAllTrueMask for the part type. */ template +struct OneToNVMIMaskBinaryOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult matchAndRewrite( + SourceOp op, + typename OneToNOpConversionPattern::OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange lhsParts = adaptor.getLhs(); + ValueRange rhsParts = adaptor.getRhs(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (lhsParts.size() != rhsParts.size() || + lhsParts.size() != resultTypes.size()) + return rewriter.notifyMatchFailure(op, + "physical mask binary arity mismatch"); + + SmallVector results; + results.reserve(resultTypes.size()); + for (auto [lhs, rhs, resultType] : + llvm::zip_equal(lhsParts, rhsParts, resultTypes)) { + auto maskType = dyn_cast(resultType); + if (!maskType || lhs.getType() != resultType || + rhs.getType() != resultType) + return rewriter.notifyMatchFailure( + op, "physical mask binary part type mismatch"); + FailureOr seedMask = + createAllTrueMask(op.getLoc(), maskType, rewriter); + if (failed(seedMask)) + return rewriter.notifyMatchFailure( + op, "unsupported mask type for all-true mask binary seed"); + results.push_back( + rewriter + .create(op.getLoc(), resultType, lhs, rhs, *seedMask) + .getResult()); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + +/* Generic one-to-N conversion for unary ops on masks; the single-operand counterpart of the mask binary pattern, using the same all-true seed mask. */ template +struct OneToNVMIMaskUnaryOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult matchAndRewrite( + SourceOp op, + typename OneToNOpConversionPattern::OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange sourceParts = adaptor.getSource(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (sourceParts.size() != 
resultTypes.size()) + return rewriter.notifyMatchFailure(op, + "physical mask unary arity mismatch"); + + SmallVector results; + results.reserve(resultTypes.size()); + for (auto [source, resultType] : + llvm::zip_equal(sourceParts, resultTypes)) { + auto maskType = dyn_cast(resultType); + if (!maskType || source.getType() != resultType) + return rewriter.notifyMatchFailure( + op, "physical mask unary part type mismatch"); + FailureOr seedMask = + createAllTrueMask(op.getLoc(), maskType, rewriter); + if (failed(seedMask)) + return rewriter.notifyMatchFailure( + op, "unsupported mask type for all-true mask unary seed"); + results.push_back( + rewriter.create(op.getLoc(), resultType, source, *seedMask) + .getResult()); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + +/* Generic one-to-N conversion for compare ops. The predicate is first translated by getVPTOCmpMode; lhs, rhs and the converted mask result must then split into the same number of physical parts, and one compare is created per part with the mode passed as a string attribute. */ template +struct OneToNVMICmpOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult matchAndRewrite( + SourceOp op, + typename OneToNOpConversionPattern::OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + /* An untranslatable predicate is reported through op.emitOpError, i.e. as a diagnostic on the op, whereas the structural checks further down use notifyMatchFailure. */ std::optional cmpMode = getVPTOCmpMode(op.getPredicate()); + if (!cmpMode) + return op.emitOpError() + << kVMIDiagUnsupportedPrefix << "compare predicate " + << op.getPredicate() + << " cannot be lowered to pto.vcmp; supported predicates are " + "eq/ne/lt/le/gt/ge, ordered FP forms " + "oeq/one/olt/ole/ogt/oge, and signed integer forms " + "slt/sle/sgt/sge"; + + ValueRange lhsParts = adaptor.getLhs(); + ValueRange rhsParts = adaptor.getRhs(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (lhsParts.size() != rhsParts.size() || + lhsParts.size() != resultTypes.size()) + return rewriter.notifyMatchFailure(op, "physical cmp arity mismatch"); + + SmallVector results; + results.reserve(resultTypes.size()); + for (auto [lhs, rhs, resultType] : + llvm::zip_equal(lhsParts, rhsParts, resultTypes)) { + auto maskType = dyn_cast(resultType); + if 
(!maskType || lhs.getType() != rhs.getType() || + !isa(lhs.getType())) + return rewriter.notifyMatchFailure(op, + "physical cmp part type mismatch"); + FailureOr seedMask = + createAllTrueMask(op.getLoc(), maskType, rewriter); + if (failed(seedMask)) + return rewriter.notifyMatchFailure( + op, "unsupported mask type for all-true cmp seed"); + results.push_back(rewriter + .create(op.getLoc(), resultType, lhs, rhs, + *seedMask, + rewriter.getStringAttr(*cmpMode)) + .getResult()); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + +/* One-to-N conversion of VMISelectOp. Mask, true value, false value and the converted result must split into the same number of physical parts; both value parts must have the result part type. One select is created per part with operands (trueValue, falseValue, mask). */ struct OneToNVMISelectOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMISelectOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange maskParts = adaptor.getMask(); + ValueRange trueParts = adaptor.getTrueValue(); + ValueRange falseParts = adaptor.getFalseValue(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (maskParts.size() != trueParts.size() || + trueParts.size() != falseParts.size() || + trueParts.size() != resultTypes.size()) + return rewriter.notifyMatchFailure(op, "physical select arity mismatch"); + + SmallVector results; + results.reserve(resultTypes.size()); + for (auto [mask, trueValue, falseValue, resultType] : + llvm::zip_equal(maskParts, trueParts, falseParts, resultTypes)) { + if (!isa(mask.getType()) || trueValue.getType() != resultType || + falseValue.getType() != resultType || !isa(resultType)) + return rewriter.notifyMatchFailure( + op, "physical select part type mismatch"); + results.push_back(rewriter + .create(op.getLoc(), resultType, trueValue, + falseValue, mask) + .getResult()); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + +/* One-to-N conversion of VMIActivePrefixIndexOp. Only the case of one physical mask part and one physical result part is handled. */ struct OneToNVMIActivePrefixIndexOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern< + 
VMIActivePrefixIndexOp>::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIActivePrefixIndexOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange maskParts = adaptor.getMask(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (maskParts.size() != 1 || resultTypes.size() != 1) + return rewriter.notifyMatchFailure( + op, "active_prefix_index supports only one physical part"); + + auto resultType = dyn_cast(resultTypes.front()); + auto maskType = dyn_cast(maskParts.front().getType()); + if (!resultType || !maskType) + return rewriter.notifyMatchFailure( + op, "active_prefix_index requires physical vreg/mask parts"); + + auto intType = dyn_cast(resultType.getElementType()); + if (!intType || !intType.isSignless()) + return rewriter.notifyMatchFailure( + op, "active_prefix_index requires signless integer result part"); + + FailureOr seedMask = + createAllTrueMaskForVReg(op.getLoc(), resultType, rewriter); + if (failed(seedMask)) + return rewriter.notifyMatchFailure( + op, "unsupported element type for active_prefix_index seed mask"); + + /* Three steps: create a zero constant of the element bit width, expand it into a carrier of the result type under the all-true seed mask, then combine the carrier with the input mask to produce the result. NOTE(review): the concrete op types are not visible in this text. */ Value zero = rewriter.create(op.getLoc(), 0, + intType.getWidth()); + Value carrier = + rewriter + .create(op.getLoc(), resultType, zero, *seedMask, + /*position=*/nullptr) + .getResult(); + Value result = rewriter + .create(op.getLoc(), resultType, carrier, + maskParts.front()) + .getResult(); + rewriter.replaceOp(op, SmallVector{result}, + adaptor.getResultMapping()); + return success(); + } +}; + +/* One-to-N conversion of VMICompressOp. Only the case of one physical source, mask and result part is handled; the arity and type checks follow. */ struct OneToNVMICompressOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMICompressOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange sourceParts = adaptor.getSource(); + ValueRange maskParts = adaptor.getMask(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (sourceParts.size() != 1 || 
maskParts.size() != 1 || + resultTypes.size() != 1) + return rewriter.notifyMatchFailure( + op, "compress supports only one physical part"); + + auto resultType = dyn_cast(resultTypes.front()); + if (!resultType || sourceParts.front().getType() != resultType || + !isa(maskParts.front().getType())) + return rewriter.notifyMatchFailure( + op, "compress requires physical source/mask/result parts"); + + Value result = rewriter + .create(op.getLoc(), resultType, + sourceParts.front(), maskParts.front()) + .getResult(); + rewriter.replaceOp(op, SmallVector{result}, + adaptor.getResultMapping()); + return success(); + } +}; + +/* One-to-N conversion of VMICompressStoreOp. Destination and offset must each convert to a single value, and the value and mask must each be exactly one physical part. The value is first squeezed under the mask and the squeezed vector is then stored relative to destination plus offset. */ struct OneToNVMICompressStoreOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern< + VMICompressStoreOp>::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMICompressStoreOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + FailureOr destination = getSingleValue( + op, adaptor.getDestination(), + "compress_store destination must convert to one value", rewriter); + FailureOr offset = getSingleValue( + op, adaptor.getOffset(), + "compress_store offset must convert to one value", rewriter); + if (failed(destination) || failed(offset)) + return failure(); + + ValueRange valueParts = adaptor.getValue(); + ValueRange maskParts = adaptor.getMask(); + if (valueParts.size() != 1 || maskParts.size() != 1) + return rewriter.notifyMatchFailure( + op, "compress_store supports only one physical part"); + + auto valueType = dyn_cast(valueParts.front().getType()); + if (!valueType || !isa(maskParts.front().getType()) || + !isa((*destination).getType())) + return rewriter.notifyMatchFailure( + op, "compress_store requires physical value/mask and ptr " + "destination"); + + /* The store base keeps the destination type and is built from destination and offset; the squeezed value keeps the type of the input value part. */ Value storeBase = + rewriter + .create(op.getLoc(), (*destination).getType(), + *destination, *offset) + .getResult(); + Value squeezed = rewriter + .create(op.getLoc(), valueType, + valueParts.front(), maskParts.front()) + 
.getResult(); + /* Store sequence: create an align value, store the squeezed vector through it with the POST_UPDATE mode string, then hand the align output of that store together with the base to one final op. NOTE(review): the concrete op types are not visible in this text. */ auto align = rewriter.create( + op.getLoc(), AlignType::get(rewriter.getContext())); + auto store = rewriter.create( + op.getLoc(), align.getResult().getType(), align.getResult(), squeezed, + storeBase, rewriter.getStringAttr("POST_UPDATE")); + rewriter.create(op.getLoc(), store.getAlignOut(), storeBase); + rewriter.eraseOp(op); + return success(); + } +}; + +/* One-to-N conversion of VMIReduceAddIOp. Accepts one or more source chunks with one mask chunk each, plus exactly one init chunk and one result chunk. Every source chunk and the init must have the result vreg type, and all mask chunks must share one type. */ struct OneToNVMIReduceAddIOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIReduceAddIOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange sourceParts = adaptor.getSource(); + ValueRange initParts = adaptor.getInit(); + ValueRange maskParts = adaptor.getMask(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (sourceParts.empty() || sourceParts.size() != maskParts.size() || + initParts.size() != 1 || resultTypes.size() != 1) + return rewriter.notifyMatchFailure( + op, "reduce_addi requires matching source/mask chunks and one " + "init/result chunk"); + + auto resultType = dyn_cast(resultTypes.front()); + auto maskType = dyn_cast(maskParts.front().getType()); + if (!resultType || !maskType || initParts.front().getType() != resultType) + return rewriter.notifyMatchFailure( + op, "reduce_addi requires matching physical source/init/result " + "vregs and one mask"); + + for (Value sourcePart : sourceParts) + if (sourcePart.getType() != resultType) + return rewriter.notifyMatchFailure( + op, "reduce_addi requires every source chunk to match result " + "vreg type"); + for (Value maskPart : maskParts) + if (maskPart.getType() != maskType) + return rewriter.notifyMatchFailure( + op, "reduce_addi requires every mask chunk to have the same " + "predicate type"); + + FailureOr firstLaneMask = + createPrefixMask(op.getLoc(), maskType, "PAT_VL1", rewriter); + if (failed(firstLaneMask)) + return rewriter.notifyMatchFailure( + op, "failed to create 
reduce_addi first-lane mask"); + + /* Fold chunk by chunk, starting from init: each source chunk is reduced under its own mask chunk, and the partial is merged into the running accumulator under the first-lane mask. */ Value accumulator = initParts.front(); + for (auto [sourcePart, maskPart] : + llvm::zip_equal(sourceParts, maskParts)) { + Value reduced = + rewriter + .create(op.getLoc(), resultType, sourcePart, maskPart) + .getResult(); + accumulator = rewriter + .create(op.getLoc(), resultType, reduced, + accumulator, *firstLaneMask) + .getResult(); + } + + rewriter.replaceOp(op, SmallVector{accumulator}, + adaptor.getResultMapping()); + return success(); + } +}; + +/* One-to-N conversion of VMIReduceAddFOp. It has the same checks and the same chunk-by-chunk accumulation structure as the reduce_addi pattern; only the op name in the diagnostics differs in the visible text. */ struct OneToNVMIReduceAddFOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIReduceAddFOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange sourceParts = adaptor.getSource(); + ValueRange initParts = adaptor.getInit(); + ValueRange maskParts = adaptor.getMask(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (sourceParts.empty() || sourceParts.size() != maskParts.size() || + initParts.size() != 1 || resultTypes.size() != 1) + return rewriter.notifyMatchFailure( + op, "reduce_addf requires matching source/mask chunks and one " + "init/result chunk"); + + auto resultType = dyn_cast(resultTypes.front()); + auto maskType = dyn_cast(maskParts.front().getType()); + if (!resultType || !maskType || initParts.front().getType() != resultType) + return rewriter.notifyMatchFailure( + op, "reduce_addf requires matching physical source/init/result " + "vregs and one mask"); + + for (Value sourcePart : sourceParts) + if (sourcePart.getType() != resultType) + return rewriter.notifyMatchFailure( + op, "reduce_addf requires every source chunk to match result " + "vreg type"); + for (Value maskPart : maskParts) + if (maskPart.getType() != maskType) + return rewriter.notifyMatchFailure( + op, "reduce_addf requires every mask chunk to have the same " + "predicate type"); + + FailureOr firstLaneMask = + createPrefixMask(op.getLoc(), maskType, 
"PAT_VL1", rewriter); + if (failed(firstLaneMask)) + return rewriter.notifyMatchFailure( + op, "failed to create reduce_addf first-lane mask"); + + Value accumulator = initParts.front(); + for (auto [sourcePart, maskPart] : + llvm::zip_equal(sourceParts, maskParts)) { + Value reduced = + rewriter + .create(op.getLoc(), resultType, sourcePart, maskPart) + .getResult(); + accumulator = rewriter + .create(op.getLoc(), resultType, reduced, + accumulator, *firstLaneMask) + .getResult(); + } + + rewriter.replaceOp(op, SmallVector{accumulator}, + adaptor.getResultMapping()); + return success(); + } +}; + +/* One-to-N conversion shared by the group-reduce ops; the getSupport overloads of this struct cover the AddF, AddI, MaxI and MaxF variants. The constructor stores a reference to the target capability registry. matchAndRewrite queries the layout support for the op and dispatches on support->kind: OneVLaneVcgadd, TwoVLaneDeinterleaved2VcgaddVadd, FourVLaneDeinterleaved4VcgaddTree, or ContiguousVcaddRows. NOTE(review): the diagnostic strings say group_reduce_addf regardless of OpTy. */ template +struct OneToNVMIGroupReduceOpPattern : OneToNOpConversionPattern { + OneToNVMIGroupReduceOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + const VMITargetCapabilityRegistry &capabilities) + : OneToNOpConversionPattern(typeConverter, context), + capabilities(capabilities) {} + + LogicalResult + matchAndRewrite(OpTy op, + typename OneToNOpConversionPattern::OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto sourceVMIType = cast(op.getSource().getType()); + auto resultVMIType = cast(op.getResult().getType()); + ValueRange sourceParts = adaptor.getSource(); + ValueRange maskParts = adaptor.getMask(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + + VMILayoutSupport supports; + std::string supportReason; + FailureOr support = + getSupport(supports, op, &supportReason); + if (failed(support)) + return rewriter.notifyMatchFailure( + op, Twine(op->getName().getStringRef()) + + " has no layout support: " + supportReason); + + FailureOr groupSize = getGroupSizeFromNumGroups( + sourceVMIType, op.getNumGroupsAttr().getInt()); + if (failed(groupSize)) + return rewriter.notifyMatchFailure( + op, "group reduce requires num_groups to evenly divide lane count"); + + if (support->kind == VMIGroupReduceAddFSupportKind::OneVLaneVcgadd) { + if (sourceParts.size() != maskParts.size() || + sourceParts.size() != 
resultTypes.size() || sourceParts.empty()) + return rewriter.notifyMatchFailure( + op, "vcgadd group_reduce_addf path requires matching physical " + "arity"); + auto resultType = dyn_cast(resultTypes.front()); + auto maskType = dyn_cast(maskParts.front().getType()); + if (!resultType || !maskType) + return rewriter.notifyMatchFailure( + op, "vcgadd group_reduce_addf path requires physical vreg/mask"); + for (auto [sourcePart, maskPart, physicalResultType] : + llvm::zip_equal(sourceParts, maskParts, resultTypes)) { + if (sourcePart.getType() != resultType || + maskPart.getType() != maskType || physicalResultType != resultType) + return rewriter.notifyMatchFailure( + op, "vcgadd group_reduce_addf path requires uniform physical " + "chunk types"); + } + + /* OneVLaneVcgadd: one created op per source chunk, each using the mask chunk at the same index; source, mask and result arity are equal. */ SmallVector results; + results.reserve(resultTypes.size()); + for (auto [sourceIndex, sourcePart] : llvm::enumerate(sourceParts)) { + results.push_back(rewriter + .create(op.getLoc(), resultType, + sourcePart, + maskParts[sourceIndex]) + .getResult()); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } + + /* Two-part deinterleaved path: there are twice as many source/mask parts as result parts. Result part i pairs source part i (lo) with source part resultPartCount + i (hi); both are reduced separately and then combined under a prefix mask covering min(8, numGroups - 8 * i) lanes. */ if (support->kind == + VMIGroupReduceAddFSupportKind::TwoVLaneDeinterleaved2VcgaddVadd) { + int64_t resultPartCount = resultTypes.size(); + if (static_cast(sourceParts.size()) != resultPartCount * 2 || + maskParts.size() != sourceParts.size()) + return rewriter.notifyMatchFailure( + op, "s16 block8 group_reduce_addf arity mismatch"); + + SmallVector results; + results.reserve(resultPartCount); + auto resultType = dyn_cast(resultTypes.front()); + auto maskType = dyn_cast(maskParts.front().getType()); + if (!resultType || !maskType) + return rewriter.notifyMatchFailure( + op, "s16 block8 group_reduce_addf requires physical vreg/mask"); + int64_t numGroups = op.getNumGroupsAttr().getInt(); + + for (int64_t resultIndex = 0; resultIndex < resultPartCount; + ++resultIndex) { + int64_t activeGroups = + std::min(8, numGroups - resultIndex * 8); + FailureOr combineMask = 
createPrefixMaskForActiveLanes( + op.getLoc(), maskType, activeGroups, rewriter); + if (failed(combineMask)) + return rewriter.notifyMatchFailure( + op, "failed to create s16 block8 combine mask"); + Value loSource = sourceParts[resultIndex]; + Value hiSource = sourceParts[resultPartCount + resultIndex]; + Value loMask = maskParts[resultIndex]; + Value hiMask = maskParts[resultPartCount + resultIndex]; + Type physicalResultType = resultTypes[resultIndex]; + if (physicalResultType != resultType || + loSource.getType() != resultType || + hiSource.getType() != resultType || loMask.getType() != maskType || + hiMask.getType() != maskType) + return rewriter.notifyMatchFailure( + op, "s16 block8 group_reduce_addf requires uniform physical " + "types"); + Value lo = rewriter + .create(op.getLoc(), resultType, + loSource, loMask) + .getResult(); + Value hi = rewriter + .create(op.getLoc(), resultType, + hiSource, hiMask) + .getResult(); + results.push_back(rewriter + .create(op.getLoc(), resultType, lo, + hi, *combineMask) + .getResult()); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } + + /* Four-part deinterleaved path: there are four times as many source/mask parts as result parts. Result part i uses source parts at part * resultPartCount + i for part 0..3, with the same per-result prefix combine mask as the two-part path. */ if (support->kind == + VMIGroupReduceAddFSupportKind::FourVLaneDeinterleaved4VcgaddTree) { + int64_t resultPartCount = resultTypes.size(); + if (static_cast(sourceParts.size()) != resultPartCount * 4 || + maskParts.size() != sourceParts.size()) + return rewriter.notifyMatchFailure( + op, "s32 block8 group_reduce_addf arity mismatch"); + + SmallVector results; + results.reserve(resultPartCount); + auto resultType = dyn_cast(resultTypes.front()); + auto maskType = dyn_cast(maskParts.front().getType()); + if (!resultType || !maskType) + return rewriter.notifyMatchFailure( + op, "s32 block8 group_reduce_addf requires physical vreg/mask"); + int64_t numGroups = op.getNumGroupsAttr().getInt(); + + for (int64_t resultIndex = 0; resultIndex < resultPartCount; + ++resultIndex) { + int64_t activeGroups = + std::min(8, numGroups - resultIndex * 8); + 
FailureOr combineMask = createPrefixMaskForActiveLanes( + op.getLoc(), maskType, activeGroups, rewriter); + if (failed(combineMask)) + return rewriter.notifyMatchFailure( + op, "failed to create s32 block8 combine mask"); + SmallVector partials; + partials.reserve(4); + for (int64_t part = 0; part < 4; ++part) { + int64_t sourceIndex = part * resultPartCount + resultIndex; + Value source = sourceParts[sourceIndex]; + Value mask = maskParts[sourceIndex]; + Type physicalResultType = resultTypes[resultIndex]; + if (physicalResultType != resultType || + source.getType() != resultType || mask.getType() != maskType) + return rewriter.notifyMatchFailure( + op, "s32 block8 group_reduce_addf requires uniform physical " + "types"); + partials.push_back(rewriter + .create( + op.getLoc(), resultType, source, mask) + .getResult()); + } + /* Combine the four partials as a two-level tree: (0, 1) and (2, 3) first, then the two intermediate sums, all under the combine mask. */ Value sum01 = + rewriter + .create(op.getLoc(), resultType, partials[0], + partials[1], *combineMask) + .getResult(); + Value sum23 = + rewriter + .create(op.getLoc(), resultType, partials[2], + partials[3], *combineMask) + .getResult(); + results.push_back(rewriter + .create(op.getLoc(), resultType, + sum01, sum23, *combineMask) + .getResult()); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } + + /* The only remaining supported kind is ContiguousVcaddRows; any other kind is rejected here. */ if (support->kind != VMIGroupReduceAddFSupportKind::ContiguousVcaddRows) + return rewriter.notifyMatchFailure(op, + "unknown group_reduce_add support"); + + int64_t lanesPerPart = 0; + int64_t groupCount = 0; + int64_t chunksPerGroup = 0; + if (failed(checkContiguousFullGroupChunks(op, sourceVMIType, *groupSize, + &lanesPerPart, &groupCount, + &chunksPerGroup, rewriter))) + return failure(); + VMILayoutAttr resultLayout = resultVMIType.getLayoutAttr(); + bool rowLocalSlots1Result = resultLayout && resultLayout.isGroupSlots() && + resultLayout.getNumGroups() == groupCount && + resultLayout.getSlots() == 1; + int64_t expectedResultParts = + rowLocalSlots1Result ? 
groupCount : groupCount * chunksPerGroup; + if (sourceParts.size() != maskParts.size() || + static_cast(sourceParts.size()) != + groupCount * chunksPerGroup || + static_cast(resultTypes.size()) != expectedResultParts) + return rewriter.notifyMatchFailure( + op, "group_reduce_addf requires matching source/mask/result arity"); + + /* Every result chunk starts as a zero vector; the group loop further down only overwrites the destination chunk of each group. */ SmallVector results; + results.reserve(resultTypes.size()); + for (Type resultType : resultTypes) { + auto vregType = dyn_cast(resultType); + if (!vregType) + return rewriter.notifyMatchFailure( + op, "group_reduce_addf result must be vreg"); + FailureOr zero = createZeroVector(op.getLoc(), vregType, rewriter); + if (failed(zero)) + return rewriter.notifyMatchFailure( + op, "failed to materialize group_reduce_addf zero result"); + results.push_back(*zero); + } + + auto resultType = dyn_cast(resultTypes.front()); + auto maskType = dyn_cast(maskParts.front().getType()); + if (!resultType || !maskType) + return rewriter.notifyMatchFailure( + op, "group_reduce_addf requires physical vreg result and mask"); + + FailureOr firstLaneMask = + createPrefixMask(op.getLoc(), maskType, "PAT_VL1", rewriter); + if (failed(firstLaneMask)) + return rewriter.notifyMatchFailure( + op, "failed to create group_reduce_addf masks"); + + /* Per group: reduce each of its chunksPerGroup chunks under the matching mask chunk, fold that partial to lane 0 with reduceVcgSlotsToLane0, and accumulate across chunks under the first-lane mask. The first chunk seeds the accumulator. */ for (int64_t group = 0; group < groupCount; ++group) { + Value accumulator; + + for (int64_t chunk = 0; chunk < chunksPerGroup; ++chunk) { + int64_t index = group * chunksPerGroup + chunk; + if (sourceParts[index].getType() != resultType || + maskParts[index].getType() != maskType) + return rewriter.notifyMatchFailure( + op, "group_reduce_addf requires uniform physical chunk types"); + Value reduced = + rewriter + .create(op.getLoc(), resultType, + sourceParts[index], maskParts[index]) + .getResult(); + FailureOr lane0Reduced = reduceVcgSlotsToLane0( + op.getLoc(), reduced, resultType, *firstLaneMask, rewriter); + if (failed(lane0Reduced)) + return rewriter.notifyMatchFailure( + op, "failed to fold group_reduce_addf VLane 
partials"); + reduced = *lane0Reduced; + if (!accumulator) { + accumulator = reduced; + continue; + } + accumulator = rewriter + .create(op.getLoc(), resultType, reduced, + accumulator, *firstLaneMask) + .getResult(); + } + + /* The destination chunk is the group index when the result layout is group-slots with slots == 1, otherwise the first chunk of the group. NOTE(review): accumulator would still be null here if chunksPerGroup were 0; presumably checkContiguousFullGroupChunks rules that out, confirm. */ int64_t destChunk = rowLocalSlots1Result ? group : group * chunksPerGroup; + results[destChunk] = + rewriter + .create(op.getLoc(), resultType, accumulator, + results[destChunk], *firstLaneMask) + .getResult(); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } + +private: + /* getSupport overloads: each forwards to the VMILayoutSupport query that matches the concrete op type, passing the stored capability registry and the reason out-parameter. */ FailureOr getSupport(VMILayoutSupport &supports, + VMIGroupReduceAddFOp op, + std::string *reason) const { + return supports.getGroupReduceAddFSupport(capabilities, op, reason); + } + + FailureOr getSupport(VMILayoutSupport &supports, + VMIGroupReduceAddIOp op, + std::string *reason) const { + return supports.getGroupReduceAddISupport(capabilities, op, reason); + } + + FailureOr getSupport(VMILayoutSupport &supports, + VMIGroupReduceMaxIOp op, + std::string *reason) const { + return supports.getGroupReduceMaxISupport(capabilities, op, reason); + } + + FailureOr getSupport(VMILayoutSupport &supports, + VMIGroupReduceMaxFOp op, + std::string *reason) const { + return supports.getGroupReduceMaxFSupport(capabilities, op, reason); + } + + const VMITargetCapabilityRegistry &capabilities; +}; + +struct OneToNVMIGroupBroadcastOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern< + VMIGroupBroadcastOp>::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIGroupBroadcastOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto sourceVMIType = cast(op.getSource().getType()); + auto resultVMIType = cast(op.getResult().getType()); + FailureOr groupSize = getGroupSizeFromNumGroups( + resultVMIType, op.getNumGroupsAttr().getInt()); + if (failed(groupSize)) + return rewriter.notifyMatchFailure( + op, + "group_broadcast requires num_groups to evenly divide lane count"); + int64_t 
lanesPerPart = 0; + int64_t groupCount = 0; + if (failed(checkFullGroupSlotSourceShape( + op, sourceVMIType, *groupSize, op.getNumGroupsAttr().getInt(), + &lanesPerPart, &groupCount, rewriter))) + return failure(); + int64_t resultLayoutFactor = 0; + int64_t resultGroupCount = 0; + if (failed(checkFullGroupBroadcastResultShape( + op, resultVMIType, *groupSize, lanesPerPart, &resultLayoutFactor, + &resultGroupCount, rewriter))) + return failure(); + if (resultGroupCount != groupCount) + return rewriter.notifyMatchFailure( + op, "group_broadcast requires matching source/result group slots"); + + ValueRange sourceParts = adaptor.getSource(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (sourceParts.empty() || resultTypes.empty()) + return rewriter.notifyMatchFailure(op, "group_broadcast arity mismatch"); + + auto firstSourceType = dyn_cast(sourceParts.front().getType()); + if (!firstSourceType) + return rewriter.notifyMatchFailure(op, + "group_broadcast source must be vreg"); + unsigned indexBits = + pto::getPTOStorageElemBitWidth(firstSourceType.getElementType()); + if (indexBits != 8 && indexBits != 16 && indexBits != 32) + return rewriter.notifyMatchFailure( + op, "group_broadcast requires 8/16/32-bit index elements"); + auto indexElementType = IntegerType::get(rewriter.getContext(), indexBits); + auto indexType = + VRegType::get(rewriter.getContext(), firstSourceType.getElementCount(), + indexElementType); + FailureOr allMask = + createAllTrueMaskForVReg(op.getLoc(), firstSourceType, rewriter); + if (failed(allMask)) + return rewriter.notifyMatchFailure( + op, "failed to create group_broadcast all mask"); + VMILayoutAttr resultLayout = resultVMIType.getLayoutAttr(); + VMILayoutAttr sourceLayout = sourceVMIType.getLayoutAttr(); + int64_t selectionGroupSize = *groupSize; + if (resultLayoutFactor != 1 && resultLayout && + resultLayout.isDeinterleaved() && resultLayout.getBlockElems() > 1 && + *groupSize < lanesPerPart) + 
selectionGroupSize = resultLayout.getBlockElems(); + auto resolveLargeGroupSource = [&](int64_t group, int64_t chunksPerGroup, + int64_t &sourceChunk, + int64_t &baseGroupSlot) { + int64_t slots = sourceLayout.getSlots(); + if (slots > 0) { + sourceChunk = group / slots; + baseGroupSlot = group % slots; + return; + } + sourceChunk = group * chunksPerGroup; + baseGroupSlot = 0; + }; + + SmallVector results; + results.resize(resultTypes.size()); + for (auto [flatIndex, resultType] : llvm::enumerate(resultTypes)) { + auto resultVRegType = dyn_cast(resultType); + if (!resultVRegType || resultVRegType != firstSourceType) + return rewriter.notifyMatchFailure( + op, "group_broadcast requires uniform physical vreg types"); + int64_t sourceChunk = flatIndex; + int64_t baseGroupSlot = 0; + Value mappedGroupSlotIndex; + if (resultLayoutFactor == 1) { + if (*groupSize >= lanesPerPart) { + int64_t chunksPerGroup = *groupSize / lanesPerPart; + int64_t group = flatIndex / chunksPerGroup; + resolveLargeGroupSource(group, chunksPerGroup, sourceChunk, + baseGroupSlot); + } else { + VMILayoutAttr sourceLayout = sourceVMIType.getLayoutAttr(); + int64_t slots = sourceLayout.getSlots(); + if (slots <= 0) { + if (sourceParts.empty() || + groupCount % static_cast(sourceParts.size()) != 0) + return rewriter.notifyMatchFailure( + op, "group_broadcast small-group source requires explicit " + "group_slots slots or derivable legacy slot count"); + slots = groupCount / sourceParts.size(); + } + int64_t groupsPerResultChunk = lanesPerPart / *groupSize; + int64_t firstGroup = flatIndex * groupsPerResultChunk; + sourceChunk = firstGroup / slots; + baseGroupSlot = firstGroup % slots; + } + } else { + bool blockFragmentSmallGroup = + resultLayout && resultLayout.isDeinterleaved() && + resultLayout.getBlockElems() > 1 && *groupSize < lanesPerPart; + bool deinterleavedSmallGroup = + resultLayout && resultLayout.isDeinterleaved() && + resultLayout.getBlockElems() == 1 && *groupSize < lanesPerPart; + if 
(blockFragmentSmallGroup) { + int64_t runningFlatIndex = 0; + bool found = false; + for (int64_t part = 0; part < resultLayoutFactor && !found; ++part) { + FailureOr chunks = + getDataChunksInPart(resultVMIType, part); + if (failed(chunks)) + return rewriter.notifyMatchFailure( + op, "group_broadcast failed to enumerate result chunks"); + for (int64_t chunk = 0; chunk < *chunks; + ++chunk, ++runningFlatIndex) { + if (runningFlatIndex != static_cast(flatIndex)) + continue; + int64_t groupsPerResultChunk = + lanesPerPart / resultLayout.getBlockElems(); + int64_t firstGroup = chunk * groupsPerResultChunk; + int64_t slots = sourceLayout.getSlots(); + if (slots <= 0) { + if (sourceParts.empty() || + groupCount % static_cast(sourceParts.size()) != 0) + return rewriter.notifyMatchFailure( + op, + "group_broadcast block-fragment source requires explicit " + "group_slots slots or derivable legacy slot count"); + slots = groupCount / sourceParts.size(); + } + sourceChunk = firstGroup / slots; + baseGroupSlot = firstGroup % slots; + found = true; + break; + } + } + if (!found) + return rewriter.notifyMatchFailure( + op, "group_broadcast result chunk index is out of range"); + } else if (deinterleavedSmallGroup) { + int64_t runningFlatIndex = 0; + bool found = false; + for (int64_t part = 0; part < resultLayoutFactor && !found; ++part) { + FailureOr chunks = + getDataChunksInPart(resultVMIType, part); + if (failed(chunks)) + return rewriter.notifyMatchFailure( + op, "group_broadcast failed to enumerate result chunks"); + for (int64_t chunk = 0; chunk < *chunks; + ++chunk, ++runningFlatIndex) { + if (runningFlatIndex != static_cast(flatIndex)) + continue; + int64_t slots = sourceLayout.getSlots(); + if (slots <= 0) { + if (sourceParts.empty() || + groupCount % static_cast(sourceParts.size()) != 0) + return rewriter.notifyMatchFailure( + op, "group_broadcast deinterleaved small-group source " + "requires explicit group_slots slots or derivable " + "legacy slot count"); + slots = 
groupCount / sourceParts.size(); + } + FailureOr index = createMappedGroupSlotIndexVector( + op.getLoc(), resultVMIType, part, chunk, indexType, + *groupSize, slots, sourceChunk, rewriter); + if (failed(index)) + return rewriter.notifyMatchFailure( + op, + "failed to create group_broadcast mapped group-slot index " + "vector"); + mappedGroupSlotIndex = *index; + found = true; + break; + } + } + if (!found) + return rewriter.notifyMatchFailure( + op, "group_broadcast result chunk index is out of range"); + } else { + int64_t runningFlatIndex = 0; + bool found = false; + for (int64_t part = 0; part < resultLayoutFactor && !found; ++part) { + FailureOr chunks = + getDataChunksInPart(resultVMIType, part); + if (failed(chunks)) + return rewriter.notifyMatchFailure( + op, "group_broadcast failed to enumerate result chunks"); + for (int64_t chunk = 0; chunk < *chunks; + ++chunk, ++runningFlatIndex) { + if (runningFlatIndex != static_cast(flatIndex)) + continue; + FailureOr firstLogical = + mapPhysicalLaneToLogical(resultVMIType, part, chunk, 0); + FailureOr lastLogical = mapPhysicalLaneToLogical( + resultVMIType, part, chunk, lanesPerPart - 1); + if (failed(firstLogical) || failed(lastLogical)) + return rewriter.notifyMatchFailure( + op, "group_broadcast failed to map result chunk lanes"); + int64_t firstGroup = *firstLogical / *groupSize; + int64_t lastGroup = *lastLogical / *groupSize; + if (firstGroup != lastGroup) + return rewriter.notifyMatchFailure( + op, "group_broadcast result chunk crosses logical groups"); + int64_t chunksPerGroup = *groupSize / lanesPerPart; + resolveLargeGroupSource(firstGroup, chunksPerGroup, sourceChunk, + baseGroupSlot); + found = true; + break; + } + } + if (!found) + return rewriter.notifyMatchFailure( + op, "group_broadcast result chunk index is out of range"); + } + } + if (*groupSize >= lanesPerPart) { + if (sourceChunk < 0 || + sourceChunk >= static_cast(sourceParts.size())) + return rewriter.notifyMatchFailure( + op, "group_broadcast 
source chunk is out of range"); + if (sourceLayout.getSlots() > 1) { + FailureOr groupSlotIndex = createGroupSlotIndexVector( + op.getLoc(), indexType, selectionGroupSize, baseGroupSlot, + rewriter); + if (failed(groupSlotIndex)) + return rewriter.notifyMatchFailure( + op, "failed to create group_broadcast group-slot index vector"); + results[flatIndex] = + rewriter + .create(op.getLoc(), resultType, + sourceParts[sourceChunk], *groupSlotIndex) + .getResult(); + } else { + results[flatIndex] = + rewriter + .create(op.getLoc(), resultType, + sourceParts[sourceChunk], *allMask, + rewriter.getStringAttr("LOWEST")) + .getResult(); + } + } else { + bool blockFragmentSmallGroup = resultLayout && + resultLayout.isDeinterleaved() && + resultLayout.getBlockElems() > 1; + bool deinterleavedSmallGroup = resultLayout && + resultLayout.isDeinterleaved() && + resultLayout.getBlockElems() == 1; + if (resultLayoutFactor != 1 && !blockFragmentSmallGroup && + !deinterleavedSmallGroup) + return rewriter.notifyMatchFailure( + op, "group_broadcast small-group deinterleaved result is not " + "supported"); + if (sourceChunk < 0 || + sourceChunk >= static_cast(sourceParts.size())) + return rewriter.notifyMatchFailure( + op, "group_broadcast source chunk is out of range"); + FailureOr groupSlotIndex = + mappedGroupSlotIndex + ? 
FailureOr(mappedGroupSlotIndex) + : createGroupSlotIndexVector(op.getLoc(), indexType, + selectionGroupSize, baseGroupSlot, + rewriter); + if (failed(groupSlotIndex)) + return rewriter.notifyMatchFailure( + op, "failed to create group_broadcast group-slot index vector"); + results[flatIndex] = + rewriter + .create(op.getLoc(), resultType, + sourceParts[sourceChunk], *groupSlotIndex) + .getResult(); + } + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMIDhistOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIDhistOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange accParts = adaptor.getAcc(); + ValueRange sourceParts = adaptor.getSource(); + ValueRange maskParts = adaptor.getMask(); + if (accParts.size() != 2 || sourceParts.empty() || + sourceParts.size() != maskParts.size()) + return rewriter.notifyMatchFailure( + op, "expected two accumulator parts and matching source/mask chunks"); + + auto loType = dyn_cast(accParts[0].getType()); + auto hiType = dyn_cast(accParts[1].getType()); + if (!loType || loType != hiType) + return rewriter.notifyMatchFailure(op, + "expected matching ui16 acc parts"); + auto sourceType = cast(op.getSource().getType()); + FailureOr lanesPerPart = + getDataLanesPerPart(sourceType.getElementType()); + if (failed(lanesPerPart)) + return rewriter.notifyMatchFailure(op, "failed to compute source lanes"); + + Location loc = op.getLoc(); + Value bin0 = createI32Constant(loc, 0, rewriter); + Value bin1 = createI32Constant(loc, 1, rewriter); + Value lo = accParts[0]; + Value hi = accParts[1]; + + for (size_t index = 0, e = sourceParts.size(); index < e; ++index) { + Value source = sourceParts[index]; + Value userMask = maskParts[index]; + auto maskType = dyn_cast(userMask.getType()); + if (!maskType || !maskType.isB8()) + return 
rewriter.notifyMatchFailure(op, "expected b8 source mask"); + + Value chunkMask = userMask; + int64_t firstLane = static_cast(index) * *lanesPerPart; + int64_t activeLanes = std::min( + *lanesPerPart, sourceType.getElementCount() - firstLane); + if (activeLanes < *lanesPerPart) { + FailureOr validMask = createPrefixMaskForActiveLanes( + loc, maskType, activeLanes, rewriter); + FailureOr allMask = createAllTrueMask(loc, maskType, rewriter); + if (failed(validMask) || failed(allMask)) + return rewriter.notifyMatchFailure( + op, "failed to materialize tail-valid b8 mask"); + chunkMask = + rewriter + .create(loc, maskType, chunkMask, *validMask, *allMask) + .getResult(); + } + + lo = rewriter.create(loc, loType, lo, source, chunkMask, bin0) + .getResult(); + hi = rewriter.create(loc, hiType, hi, source, chunkMask, bin1) + .getResult(); + } + + rewriter.replaceOp(op, SmallVector{lo, hi}, + adaptor.getResultMapping()); + return success(); + } +}; + +template +struct OneToNVMIReduceMinMaxFOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult matchAndRewrite( + SourceOp op, + typename OneToNOpConversionPattern::OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange sourceParts = adaptor.getSource(); + ValueRange initParts = adaptor.getInit(); + ValueRange maskParts = adaptor.getMask(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (sourceParts.empty() || sourceParts.size() != maskParts.size() || + initParts.size() != 1 || resultTypes.size() != 1) + return rewriter.notifyMatchFailure( + op, "floating min/max reduction requires matching source/mask chunks " + "and one init/result chunk"); + + auto resultType = dyn_cast(resultTypes.front()); + auto maskType = dyn_cast(maskParts.front().getType()); + if (!resultType || !maskType || initParts.front().getType() != resultType) + return rewriter.notifyMatchFailure( + op, "floating min/max reduction 
requires matching physical source/" + "init/result vregs and one mask"); + + for (Value sourcePart : sourceParts) + if (sourcePart.getType() != resultType) + return rewriter.notifyMatchFailure( + op, "floating min/max reduction requires every source chunk to " + "match result vreg type"); + for (Value maskPart : maskParts) + if (maskPart.getType() != maskType) + return rewriter.notifyMatchFailure( + op, "floating min/max reduction requires every mask chunk to have " + "the same predicate type"); + + FailureOr firstLaneMask = + createPrefixMask(op.getLoc(), maskType, "PAT_VL1", rewriter); + if (failed(firstLaneMask)) + return rewriter.notifyMatchFailure( + op, "failed to create floating min/max reduction first-lane mask"); + + Value accumulator = initParts.front(); + for (auto [sourcePart, maskPart] : + llvm::zip_equal(sourceParts, maskParts)) { + Value reduced = rewriter + .create(op.getLoc(), resultType, + sourcePart, maskPart) + .getResult(); + accumulator = rewriter + .create(op.getLoc(), resultType, reduced, + accumulator, *firstLaneMask) + .getResult(); + } + + rewriter.replaceOp(op, SmallVector{accumulator}, + adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMIExtFOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIExtFOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto sourceVMIType = cast(op.getSource().getType()); + auto resultVMIType = cast(op.getResult().getType()); + ValueRange sourceParts = adaptor.getSource(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (sourceParts.empty()) + return rewriter.notifyMatchFailure( + op, "extf requires at least one physical source chunk"); + + auto sourceType = dyn_cast(sourceParts.front().getType()); + if (!sourceType) + return rewriter.notifyMatchFailure(op, "expected physical extf source"); + for (Value sourcePart : sourceParts) { + auto 
currentSourceType = dyn_cast(sourcePart.getType()); + if (!currentSourceType || currentSourceType != sourceType) + return rewriter.notifyMatchFailure( + op, "extf source physical parts must have matching type"); + } + + SmallVector resultVRegTypes; + resultVRegTypes.reserve(resultTypes.size()); + for (Type resultType : resultTypes) { + auto resultVRegType = dyn_cast(resultType); + if (!resultVRegType || + (resultVRegTypes.empty() ? !resultVRegType.getElementType().isF32() + : resultVRegType != resultVRegTypes.front())) + return rewriter.notifyMatchFailure( + op, "unsupported physical extf result type"); + resultVRegTypes.push_back(resultVRegType); + } + + unsigned sourceBits = + pto::getPTOStorageElemBitWidth(sourceType.getElementType()); + VMILayoutAttr sourceLayout = sourceVMIType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultVMIType.getLayoutAttr(); + if (sourceLayout && resultLayout && sourceLayout.isContiguous() && + resultLayout.isContiguous() && resultLayout.getLaneStride() == 1 && + ((sourceBits == 16 && sourceLayout.getLaneStride() == 2) || + (sourceBits == 8 && sourceLayout.getLaneStride() == 4)) && + resultTypes.size() == sourceParts.size()) { + StringRef part = sourceBits == 16 ? 
StringRef("EVEN") : StringRef("P0"); + FailureOr mask = + createAllTrueMaskForVReg(op.getLoc(), sourceType, rewriter); + if (failed(mask)) + return rewriter.notifyMatchFailure(op, "failed to build extf seed mask"); + + SmallVector results; + results.reserve(resultTypes.size()); + for (auto [sourcePart, resultType] : + llvm::zip_equal(sourceParts, resultVRegTypes)) { + results.push_back(rewriter + .create(op.getLoc(), resultType, + sourcePart, *mask, + /*rnd=*/nullptr, /*sat=*/nullptr, + rewriter.getStringAttr(part)) + .getResult()); + } + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } + + ArrayRef parts; + int64_t factor = 0; + if (sourceBits == 16 && resultTypes.size() == 2 * sourceParts.size()) { + static constexpr StringRef kEvenOddParts[] = {"EVEN", "ODD"}; + parts = kEvenOddParts; + factor = 2; + } else if (sourceBits == 8 && + resultTypes.size() == 4 * sourceParts.size()) { + static constexpr StringRef kPacked4Parts[] = {"P0", "P1", "P2", "P3"}; + parts = kPacked4Parts; + factor = 4; + } else { + return rewriter.notifyMatchFailure( + op, "unsupported physical extf source/result width relation"); + } + + FailureOr mask = + createAllTrueMaskForVReg(op.getLoc(), sourceType, rewriter); + if (failed(mask)) + return rewriter.notifyMatchFailure(op, "failed to build extf seed mask"); + + SmallVector results; + results.reserve(resultTypes.size()); + for (int64_t partIndex = 0; partIndex < factor; ++partIndex) { + for (auto [chunkIndex, sourcePart] : llvm::enumerate(sourceParts)) { + VRegType resultType = + resultVRegTypes[partIndex * sourceParts.size() + chunkIndex]; + results.push_back( + rewriter + .create(op.getLoc(), resultType, sourcePart, *mask, + /*rnd=*/nullptr, /*sat=*/nullptr, + rewriter.getStringAttr(parts[partIndex])) + .getResult()); + } + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMITruncFOpPattern : OneToNOpConversionPattern { + using 
OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMITruncFOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto sourceVMIType = cast(op.getSource().getType()); + auto resultVMIType = cast(op.getResult().getType()); + ValueRange sourceParts = adaptor.getSource(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + + VMILayoutAttr sourceLayout = sourceVMIType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultVMIType.getLayoutAttr(); + if (sourceLayout && resultLayout && sourceLayout.isGroupSlots() && + resultLayout.isGroupSlots()) { + if (sourceLayout.getNumGroups() != resultLayout.getNumGroups() || + sourceLayout.getSlots() != 1 || resultLayout.getSlots() != 1 || + !sourceVMIType.getElementType().isF32() || + pto::getPTOStorageElemBitWidth(resultVMIType.getElementType()) != + 16 || + sourceParts.size() != resultTypes.size()) + return rewriter.notifyMatchFailure( + op, "unsupported group-slot truncf shape"); + + SmallVector results; + results.reserve(resultTypes.size()); + StringAttr rnd = rewriter.getStringAttr("R"); + StringAttr sat = rewriter.getStringAttr("SAT"); + StringAttr even = rewriter.getStringAttr("EVEN"); + FailureOr lane0Mask = createPrefixMask( + op.getLoc(), MaskType::get(rewriter.getContext(), "b32"), "PAT_VL1", + rewriter); + if (failed(lane0Mask)) + return rewriter.notifyMatchFailure( + op, "failed to build group-slot truncf lane0 mask"); + for (auto [sourcePart, physicalResultType] : + llvm::zip_equal(sourceParts, resultTypes)) { + auto sourceType = dyn_cast(sourcePart.getType()); + auto resultType = dyn_cast(physicalResultType); + if (!sourceType || !sourceType.getElementType().isF32() || + !resultType || + pto::getPTOStorageElemBitWidth(resultType.getElementType()) != 16) + return rewriter.notifyMatchFailure( + op, "unsupported group-slot truncf physical type"); + results.push_back(rewriter + .create(op.getLoc(), resultType, + sourcePart, 
*lane0Mask, rnd, sat, + even) + .getResult()); + } + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } + + if (resultTypes.empty()) + return rewriter.notifyMatchFailure(op, "truncf requires result chunks"); + + auto sourceType0 = dyn_cast(sourceParts.front().getType()); + if (!sourceType0 || !sourceType0.getElementType().isF32()) + return rewriter.notifyMatchFailure( + op, "unsupported physical truncf source/result type"); + for (Value sourcePart : sourceParts) { + auto sourceType = dyn_cast(sourcePart.getType()); + if (!sourceType || sourceType != sourceType0) + return rewriter.notifyMatchFailure( + op, "truncf source physical parts must have matching f32 type"); + } + + SmallVector resultVRegTypes; + resultVRegTypes.reserve(resultTypes.size()); + for (Type physicalResultType : resultTypes) { + auto resultType = dyn_cast(physicalResultType); + if (!resultType || + (resultVRegTypes.empty() ? pto::getPTOStorageElemBitWidth( + resultType.getElementType()) == 0 + : resultType != resultVRegTypes.front())) + return rewriter.notifyMatchFailure( + op, "unsupported physical truncf result type"); + resultVRegTypes.push_back(resultType); + } + + unsigned resultBits = pto::getPTOStorageElemBitWidth( + resultVRegTypes.front().getElementType()); + if (sourceLayout && resultLayout && sourceLayout.isContiguous() && + sourceLayout.getLaneStride() == 1 && resultLayout.isContiguous() && + resultLayout.getLaneStride() != 1 && + sourceParts.size() == resultTypes.size()) { + StringRef part; + if (resultBits == 16 && resultLayout.getLaneStride() == 2) + part = "EVEN"; + else if (resultBits == 8 && resultLayout.getLaneStride() == 4) + part = "P0"; + else + return rewriter.notifyMatchFailure( + op, "unsupported dense lane_stride truncf result layout"); + + FailureOr sourceMask = + createAllTrueMaskForVReg(op.getLoc(), sourceType0, rewriter); + if (failed(sourceMask)) + return rewriter.notifyMatchFailure(op, + "failed to build truncf masks"); + + 
StringAttr rnd = rewriter.getStringAttr( + getTruncFRoundMode(op, resultVRegTypes.front().getElementType())); + StringAttr sat = rewriter.getStringAttr("SAT"); + StringAttr partAttr = rewriter.getStringAttr(part); + SmallVector results; + results.reserve(resultTypes.size()); + for (auto [sourcePart, resultType] : + llvm::zip_equal(sourceParts, resultVRegTypes)) { + results.push_back( + rewriter + .create(op.getLoc(), resultType, sourcePart, + *sourceMask, rnd, sat, partAttr) + .getResult()); + } + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } + + ArrayRef allParts; + int64_t factor = 0; + if (resultBits == 16) { + static constexpr StringRef kEvenOddParts[] = {"EVEN", "ODD"}; + allParts = kEvenOddParts; + factor = 2; + } else if (resultBits == 8) { + static constexpr StringRef kPacked4Parts[] = {"P0", "P1", "P2", "P3"}; + allParts = kPacked4Parts; + factor = 4; + } else { + return rewriter.notifyMatchFailure( + op, "unsupported physical truncf source/result width relation"); + } + + int64_t resultLaneStride = + resultLayout && resultLayout.isContiguous() ? 
resultLayout.getLaneStride() + : 1; + if (resultLaneStride <= 0 || factor % resultLaneStride != 0) + return rewriter.notifyMatchFailure( + op, "unsupported physical truncf result lane stride"); + int64_t sourceFactor = factor / resultLaneStride; + if (sourceParts.size() != sourceFactor * resultTypes.size()) + return rewriter.notifyMatchFailure( + op, "unsupported physical truncf source/result arity relation"); + + FailureOr sourceMask = + createAllTrueMaskForVReg(op.getLoc(), sourceType0, rewriter); + if (failed(sourceMask)) + return rewriter.notifyMatchFailure(op, "failed to build truncf masks"); + + StringAttr rnd = rewriter.getStringAttr( + getTruncFRoundMode(op, resultVRegTypes.front().getElementType())); + StringAttr sat = rewriter.getStringAttr("SAT"); + SmallVector results; + results.reserve(resultTypes.size()); + for (auto [chunkIndex, resultType] : llvm::enumerate(resultVRegTypes)) { + FailureOr resultMask = + createAllTrueMaskForVReg(op.getLoc(), resultType, rewriter); + if (failed(resultMask)) + return rewriter.notifyMatchFailure( + op, "failed to build truncf result mask"); + + SmallVector partials; + partials.reserve(sourceFactor); + for (int64_t partIndex = 0; partIndex < sourceFactor; ++partIndex) { + Value sourcePart = + sourceParts[partIndex * resultTypes.size() + chunkIndex]; + partials.push_back( + rewriter + .create(op.getLoc(), resultType, sourcePart, + *sourceMask, rnd, sat, + rewriter.getStringAttr( + allParts[partIndex * resultLaneStride])) + .getResult()); + } + + Value merged = partials.front(); + for (Value partial : llvm::drop_begin(partials)) + merged = rewriter + .create(op.getLoc(), resultType, merged, partial, + *resultMask) + .getResult(); + results.push_back(merged); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + +template +struct OneToNVMIExtIOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + 
matchAndRewrite(OpT op, + typename OneToNOpConversionPattern::OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto sourceVMIType = cast(op.getSource().getType()); + auto resultVMIType = cast(op.getResult().getType()); + ValueRange sourceParts = adaptor.getSource(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (sourceParts.empty()) + return rewriter.notifyMatchFailure( + op, "integer extension requires at least one physical source chunk"); + + auto sourceType = dyn_cast(sourceParts.front().getType()); + if (!sourceType) + return rewriter.notifyMatchFailure( + op, "expected physical integer extension source"); + for (Value sourcePart : sourceParts) { + auto currentSourceType = dyn_cast(sourcePart.getType()); + if (!currentSourceType || currentSourceType != sourceType) + return rewriter.notifyMatchFailure( + op, "integer extension source physical parts must have matching " + "type"); + } + + VMILayoutAttr sourceLayout = sourceVMIType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultVMIType.getLayoutAttr(); + if (sourceLayout && resultLayout && sourceLayout.isGroupSlots() && + resultLayout.isGroupSlots()) { + if (sourceLayout.getNumGroups() != resultLayout.getNumGroups() || + sourceLayout.getSlots() != 8 || resultLayout.getSlots() != 8 || + sourceParts.size() != resultTypes.size()) + return rewriter.notifyMatchFailure( + op, "unsupported group-slot integer extension shape"); + + unsigned sourceBits = + pto::getPTOStorageElemBitWidth(sourceVMIType.getElementType()); + unsigned resultBits = + pto::getPTOStorageElemBitWidth(resultVMIType.getElementType()); + if ((sourceBits != 8 && sourceBits != 16) || resultBits != 32) + return rewriter.notifyMatchFailure( + op, "group-slot integer extension requires 8/16-bit source and " + "32-bit result element widths"); + + FailureOr maskType = + getMaskTypeForVReg(sourceType, rewriter.getContext()); + if (failed(maskType)) + return rewriter.notifyMatchFailure( + op, 
"failed to create group-slot integer extension mask type"); + FailureOr slotMask = createPrefixMaskForActiveLanes( + op.getLoc(), *maskType, sourceLayout.getSlots(), rewriter); + if (failed(slotMask)) + return rewriter.notifyMatchFailure( + op, "failed to build group-slot integer extension mask"); + + SmallVector partNames; + int64_t partFactor = 0; + if (sourceBits == 16) { + partNames.assign({"EVEN", "ODD"}); + partFactor = 2; + } else { + partNames.assign({"P0", "P1", "P2", "P3"}); + partFactor = 4; + } + + SmallVector results; + results.reserve(resultTypes.size()); + for (auto [chunkIndex, sourcePart, resultType] : + llvm::enumerate(sourceParts, resultTypes)) { + auto resultVRegType = dyn_cast(resultType); + if (!resultVRegType || pto::getPTOStorageElemBitWidth( + resultVRegType.getElementType()) != 32) + return rewriter.notifyMatchFailure( + op, "unsupported group-slot integer extension result type"); + + SmallVector convertedParts; + convertedParts.reserve(partNames.size()); + for (StringRef partName : partNames) { + convertedParts.push_back( + rewriter + .create(op.getLoc(), resultVRegType, sourcePart, + *slotMask, /*rnd=*/nullptr, /*sat=*/nullptr, + rewriter.getStringAttr(partName)) + .getResult()); + } + + FailureOr resultMaskType = + getMaskTypeForVReg(resultVRegType, rewriter.getContext()); + FailureOr resultAllMask = + createAllTrueMaskForVReg(op.getLoc(), resultVRegType, rewriter); + if (failed(resultMaskType) || failed(resultAllMask)) + return rewriter.notifyMatchFailure( + op, "failed to build group-slot integer extension result seed"); + + auto indexType = VRegType::get( + rewriter.getContext(), resultVRegType.getElementCount(), + IntegerType::get(rewriter.getContext(), 32)); + int64_t groupBegin = + static_cast(chunkIndex) * sourceLayout.getSlots(); + int64_t activeSlots = std::min( + sourceLayout.getSlots(), sourceLayout.getNumGroups() - groupBegin); + if (activeSlots <= 0) + return rewriter.notifyMatchFailure( + op, "group-slot integer extension 
has no active slots"); + Value assembled; + for (int64_t slot = 0; slot < activeSlots; ++slot) { + int64_t partIndex = slot % partFactor; + int64_t sourceLane = slot / partFactor; + FailureOr laneIndexScalar = createScalarOffsetConstant( + op.getLoc(), indexType.getElementType(), sourceLane, rewriter); + FailureOr laneMask = createLaneRangeMask( + op.getLoc(), *resultMaskType, slot, slot + 1, rewriter); + if (failed(laneIndexScalar) || failed(laneMask)) + return rewriter.notifyMatchFailure( + op, "failed to build group-slot integer extension slot mask"); + Value laneIndex = + rewriter + .create(op.getLoc(), indexType, *laneIndexScalar, + *resultAllMask, /*position=*/nullptr) + .getResult(); + Value selected = + rewriter + .create(op.getLoc(), resultVRegType, + convertedParts[partIndex], laneIndex) + .getResult(); + if (!assembled) { + assembled = selected; + continue; + } + assembled = rewriter + .create(op.getLoc(), resultVRegType, selected, + assembled, *laneMask) + .getResult(); + } + + results.push_back(assembled); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } + + SmallVector resultVRegTypes; + resultVRegTypes.reserve(resultTypes.size()); + for (Type resultType : resultTypes) { + auto resultVRegType = dyn_cast(resultType); + if (!resultVRegType || + !isa(resultVRegType.getElementType()) || + (!resultVRegTypes.empty() && + resultVRegType != resultVRegTypes.front())) + return rewriter.notifyMatchFailure( + op, "unsupported physical integer extension result type"); + resultVRegTypes.push_back(resultVRegType); + } + + unsigned sourceBits = + pto::getPTOStorageElemBitWidth(sourceType.getElementType()); + unsigned resultBits = pto::getPTOStorageElemBitWidth( + resultVRegTypes.front().getElementType()); + if (sourceLayout && resultLayout && sourceLayout.isContiguous() && + resultLayout.isContiguous() && resultLayout.getLaneStride() == 1 && + ((resultBits == sourceBits * 2 && + sourceLayout.getLaneStride() == 2) || + 
(resultBits == sourceBits * 4 && + sourceLayout.getLaneStride() == 4)) && + resultTypes.size() == sourceParts.size()) { + StringRef part = + resultBits == sourceBits * 2 ? StringRef("EVEN") : StringRef("P0"); + FailureOr mask = + createAllTrueMaskForVReg(op.getLoc(), sourceType, rewriter); + if (failed(mask)) + return rewriter.notifyMatchFailure( + op, "failed to build integer extension seed mask"); + + SmallVector results; + results.reserve(resultTypes.size()); + for (auto [sourcePart, resultType] : + llvm::zip_equal(sourceParts, resultVRegTypes)) { + results.push_back( + rewriter + .create(op.getLoc(), resultType, sourcePart, *mask, + /*rnd=*/nullptr, /*sat=*/nullptr, + rewriter.getStringAttr(part)) + .getResult()); + } + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } + + ArrayRef parts; + int64_t factor = 0; + if (resultBits == sourceBits * 2 && + resultTypes.size() == 2 * sourceParts.size()) { + static constexpr StringRef kEvenOddParts[] = {"EVEN", "ODD"}; + parts = kEvenOddParts; + factor = 2; + } else if (resultBits == sourceBits * 4 && + resultTypes.size() == 4 * sourceParts.size()) { + static constexpr StringRef kPacked4Parts[] = {"P0", "P1", "P2", "P3"}; + parts = kPacked4Parts; + factor = 4; + } else { + return rewriter.notifyMatchFailure( + op, "unsupported physical integer extension source/result width " + "relation"); + } + + FailureOr mask = + createAllTrueMaskForVReg(op.getLoc(), sourceType, rewriter); + if (failed(mask)) + return rewriter.notifyMatchFailure( + op, "failed to build integer extension seed mask"); + + SmallVector results; + results.reserve(resultTypes.size()); + for (int64_t partIndex = 0; partIndex < factor; ++partIndex) { + for (auto [chunkIndex, sourcePart] : llvm::enumerate(sourceParts)) { + VRegType resultType = + resultVRegTypes[partIndex * sourceParts.size() + chunkIndex]; + results.push_back( + rewriter + .create(op.getLoc(), resultType, sourcePart, *mask, + /*rnd=*/nullptr, 
/*sat=*/nullptr, + rewriter.getStringAttr(parts[partIndex])) + .getResult()); + } + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMITruncIOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMITruncIOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto sourceVMIType = cast(op.getSource().getType()); + auto resultVMIType = cast(op.getResult().getType()); + ValueRange sourceParts = adaptor.getSource(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + + VMILayoutAttr sourceLayout = sourceVMIType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultVMIType.getLayoutAttr(); + if (sourceLayout && resultLayout && sourceLayout.isGroupSlots() && + resultLayout.isGroupSlots()) { + if (sourceLayout.getNumGroups() != resultLayout.getNumGroups() || + sourceLayout.getSlots() != resultLayout.getSlots() || + (sourceLayout.getSlots() != 1 && sourceLayout.getSlots() != 8) || + pto::getPTOStorageElemBitWidth(sourceVMIType.getElementType()) != + 32 || + (pto::getPTOStorageElemBitWidth(resultVMIType.getElementType()) != + 16 && + pto::getPTOStorageElemBitWidth(resultVMIType.getElementType()) != + 8) || + sourceParts.size() != resultTypes.size()) + return rewriter.notifyMatchFailure( + op, "unsupported group-slot trunci shape"); + + SmallVector results; + results.reserve(resultTypes.size()); + StringAttr sat = rewriter.getStringAttr("SAT"); + const char *activeSlotPattern = + sourceLayout.getSlots() == 1 ? 
"PAT_VL1" : "PAT_VL8"; + FailureOr activeSlotMask = createPrefixMask( + op.getLoc(), MaskType::get(rewriter.getContext(), "b32"), + activeSlotPattern, rewriter); + if (failed(activeSlotMask)) + return rewriter.notifyMatchFailure( + op, "failed to build group-slot trunci active slot mask"); + for (auto [sourcePart, physicalResultType] : + llvm::zip_equal(sourceParts, resultTypes)) { + auto sourceType = dyn_cast(sourcePart.getType()); + auto resultType = dyn_cast(physicalResultType); + if (!sourceType || + pto::getPTOStorageElemBitWidth(sourceType.getElementType()) != 32 || + !resultType) + return rewriter.notifyMatchFailure( + op, "unsupported group-slot trunci physical type"); + + unsigned physicalResultBits = + pto::getPTOStorageElemBitWidth(resultType.getElementType()); + if (resultLayout.hasLaneStride() && + resultLayout.getLaneStride() == 4 && + pto::getPTOStorageElemBitWidth(resultVMIType.getElementType()) == + 8 && + physicalResultBits == 32) { + if (sourcePart.getType() == resultType) { + results.push_back(sourcePart); + } else { + results.push_back( + rewriter + .create(op.getLoc(), resultType, sourcePart) + .getResult()); + } + continue; + } + + if (physicalResultBits != 16 && physicalResultBits != 8) + return rewriter.notifyMatchFailure( + op, "unsupported group-slot trunci physical type"); + + StringAttr part = + physicalResultBits == 16 + ? 
rewriter.getStringAttr("EVEN") + : rewriter.getStringAttr("P0"); + results.push_back(rewriter + .create(op.getLoc(), resultType, + sourcePart, *activeSlotMask, + /*rnd=*/nullptr, sat, part) + .getResult()); + } + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } + + if (sourceParts.empty() || resultTypes.empty()) + return rewriter.notifyMatchFailure( + op, "trunci requires non-empty physical source and result parts"); + + auto sourceType0 = dyn_cast(sourceParts.front().getType()); + auto resultType0 = dyn_cast(resultTypes.front()); + if (!sourceType0 || !isa(sourceType0.getElementType()) || + !resultType0 || !isa(resultType0.getElementType())) + return rewriter.notifyMatchFailure( + op, "unsupported physical trunci source/result type"); + for (Value sourcePart : sourceParts) { + auto sourceType = dyn_cast(sourcePart.getType()); + if (!sourceType || sourceType != sourceType0) + return rewriter.notifyMatchFailure( + op, "trunci source physical parts must have matching integer type"); + } + for (Type resultType : resultTypes) { + auto resultVRegType = dyn_cast(resultType); + if (!resultVRegType || resultVRegType != resultType0) + return rewriter.notifyMatchFailure( + op, "trunci result physical parts must have matching integer type"); + } + + unsigned sourceBits = + pto::getPTOStorageElemBitWidth(sourceType0.getElementType()); + unsigned resultBits = + pto::getPTOStorageElemBitWidth(resultType0.getElementType()); + if (sourceBits == 0 || resultBits == 0 || sourceBits % resultBits != 0) + return rewriter.notifyMatchFailure( + op, "unsupported physical trunci source/result width relation"); + int64_t factor = sourceBits / resultBits; + if (sourceLayout && resultLayout && sourceLayout.isContiguous() && + sourceLayout.getLaneStride() == 1 && resultLayout.isContiguous() && + resultLayout.getLaneStride() == factor && + sourceParts.size() == resultTypes.size()) { + if (factor != 2 && factor != 4) + return rewriter.notifyMatchFailure( + op, 
"unsupported dense lane_stride trunci result layout"); + StringAttr part = rewriter.getStringAttr(factor == 2 ? "EVEN" : "P0"); + FailureOr sourceMask = + createAllTrueMaskForVReg(op.getLoc(), sourceType0, rewriter); + if (failed(sourceMask)) + return rewriter.notifyMatchFailure(op, "failed to build trunci masks"); + + StringAttr sat = rewriter.getStringAttr("SAT"); + SmallVector results; + results.reserve(resultTypes.size()); + for (auto [sourcePart, resultType] : + llvm::zip_equal(sourceParts, resultTypes)) { + results.push_back( + rewriter + .create(op.getLoc(), resultType, sourcePart, + *sourceMask, /*rnd=*/nullptr, sat, part) + .getResult()); + } + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } + + if ((factor != 2 && factor != 4) || + sourceParts.size() != resultTypes.size() * factor) + return rewriter.notifyMatchFailure( + op, "unsupported physical trunci source/result arity relation"); + + ArrayRef parts; + if (factor == 2) { + static constexpr StringRef kEvenOddParts[] = {"EVEN", "ODD"}; + parts = kEvenOddParts; + } else if (factor == 4) { + static constexpr StringRef kPacked4Parts[] = {"P0", "P1", "P2", "P3"}; + parts = kPacked4Parts; + } else { + return rewriter.notifyMatchFailure( + op, "unsupported physical trunci source/result width relation"); + } + + FailureOr sourceMask = + createAllTrueMaskForVReg(op.getLoc(), sourceType0, rewriter); + FailureOr resultMask = + createAllTrueMaskForVReg(op.getLoc(), resultType0, rewriter); + if (failed(sourceMask) || failed(resultMask)) + return rewriter.notifyMatchFailure(op, "failed to build trunci masks"); + + StringAttr sat = rewriter.getStringAttr("SAT"); + SmallVector results; + results.reserve(resultTypes.size()); + for (int64_t resultIndex = 0, resultCount = resultTypes.size(); + resultIndex < resultCount; ++resultIndex) { + Type resultType = resultTypes[resultIndex]; + SmallVector partials; + partials.reserve(parts.size()); + for (int64_t partIndex = 0; partIndex < 
factor; ++partIndex) { + Value sourcePart = sourceParts[resultIndex * factor + partIndex]; + partials.push_back( + rewriter + .create(op.getLoc(), resultType, sourcePart, + *sourceMask, /*rnd=*/nullptr, sat, + rewriter.getStringAttr(parts[partIndex])) + .getResult()); + } + + Value merged = partials.front(); + for (Value partial : llvm::drop_begin(partials)) + merged = rewriter + .create(op.getLoc(), resultType, merged, partial, + *resultMask) + .getResult(); + results.push_back(merged); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMIFPToSIOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIFPToSIOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange sourceParts = adaptor.getSource(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (sourceParts.size() != resultTypes.size()) + return rewriter.notifyMatchFailure( + op, "fptosi physical source/result arity mismatch"); + + SmallVector results; + results.reserve(resultTypes.size()); + StringAttr rnd = rewriter.getStringAttr("R"); + StringAttr sat = rewriter.getStringAttr("SAT"); + for (auto [sourcePart, resultType] : + llvm::zip_equal(sourceParts, resultTypes)) { + auto sourceType = dyn_cast(sourcePart.getType()); + auto resultVRegType = dyn_cast(resultType); + if (!sourceType || !sourceType.getElementType().isF32() || + !resultVRegType || + !isa(resultVRegType.getElementType()) || + pto::getPTOStorageElemBitWidth(resultVRegType.getElementType()) != 32) + return rewriter.notifyMatchFailure( + op, "fptosi requires physical f32 source and 32-bit integer " + "result chunks"); + + FailureOr mask = + createAllTrueMaskForVReg(op.getLoc(), sourceType, rewriter); + if (failed(mask)) + return rewriter.notifyMatchFailure(op, "failed to build fptosi mask"); + results.push_back(rewriter + .create(op.getLoc(), 
resultVRegType, + sourcePart, *mask, rnd, sat, + /*part=*/nullptr) + .getResult()); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMISIToFPOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMISIToFPOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange sourceParts = adaptor.getSource(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (sourceParts.size() != resultTypes.size()) + return rewriter.notifyMatchFailure( + op, "sitofp physical source/result arity mismatch"); + + SmallVector results; + results.reserve(resultTypes.size()); + StringAttr rnd = rewriter.getStringAttr("R"); + for (auto [sourcePart, resultType] : + llvm::zip_equal(sourceParts, resultTypes)) { + auto sourceType = dyn_cast(sourcePart.getType()); + auto resultVRegType = dyn_cast(resultType); + if (!sourceType || !isa(sourceType.getElementType()) || + pto::getPTOStorageElemBitWidth(sourceType.getElementType()) != 32 || + !resultVRegType || !resultVRegType.getElementType().isF32()) + return rewriter.notifyMatchFailure( + op, "sitofp requires physical 32-bit integer source and f32 " + "result chunks"); + + FailureOr mask = + createAllTrueMaskForVReg(op.getLoc(), sourceType, rewriter); + if (failed(mask)) + return rewriter.notifyMatchFailure(op, "failed to build sitofp mask"); + results.push_back(rewriter + .create(op.getLoc(), resultVRegType, + sourcePart, *mask, rnd, + /*sat=*/nullptr, /*part=*/nullptr) + .getResult()); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMIBitcastOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIBitcastOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange 
sourceParts = adaptor.getSource(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (sourceParts.size() != resultTypes.size()) + return rewriter.notifyMatchFailure(op, "physical bitcast arity mismatch"); + + SmallVector results; + results.reserve(resultTypes.size()); + for (auto [sourcePart, resultType] : + llvm::zip_equal(sourceParts, resultTypes)) { + if (!isa(sourcePart.getType()) || !isa(resultType)) + return rewriter.notifyMatchFailure( + op, "physical bitcast part type mismatch"); + results.push_back( + rewriter.create(op.getLoc(), resultType, sourcePart) + .getResult()); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMIChannelSplitOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIChannelSplitOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + int64_t channels = op.getNumResults(); + if (channels != 2 && channels != 4) + return rewriter.notifyMatchFailure( + op, "channel_split only supports 2 or 4 channels"); + + auto sourceType = cast(op.getSource().getType()); + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + auto channelLayout = + VMILayoutAttr::getDeinterleaved(rewriter.getContext(), channels); + if (!sourceLayout || + (!sourceLayout.isContiguous() && sourceLayout != channelLayout)) + return rewriter.notifyMatchFailure( + op, + "channel_split requires contiguous or matching deinterleaved source " + "layout"); + for (Value result : op.getResults()) { + auto resultType = cast(result.getType()); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + if (!resultLayout || !resultLayout.isContiguous()) + return rewriter.notifyMatchFailure( + op, "channel_split requires contiguous result layouts"); + } + + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(); + FailureOr> results = + 
materializeDataLayoutConversion(op, adaptor.getSource(), resultTypes, + sourceLayout, channelLayout, rewriter); + if (failed(results)) + return failure(); + + rewriter.replaceOp(op, *results, adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMIChannelMergeOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIChannelMergeOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + int64_t channels = op.getInputs().size(); + if (channels != 2 && channels != 4) + return rewriter.notifyMatchFailure( + op, "channel_merge only supports 2 or 4 channels"); + + for (Value input : op.getInputs()) { + auto inputType = cast(input.getType()); + VMILayoutAttr inputLayout = inputType.getLayoutAttr(); + if (!inputLayout || !inputLayout.isContiguous()) + return rewriter.notifyMatchFailure( + op, "channel_merge requires contiguous input layouts"); + } + auto resultType = cast(op.getResult().getType()); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + auto channelLayout = + VMILayoutAttr::getDeinterleaved(rewriter.getContext(), channels); + if (!resultLayout || + (!resultLayout.isContiguous() && resultLayout != channelLayout)) + return rewriter.notifyMatchFailure( + op, + "channel_merge requires contiguous or matching deinterleaved result " + "layout"); + + FailureOr> results = materializeDataLayoutConversion( + op, adaptor.getFlatOperands(), + adaptor.getResultMapping().getConvertedTypes(0), channelLayout, + resultLayout, rewriter); + if (failed(results)) + return failure(); + + rewriter.replaceOp(op, *results, adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMIShuffleOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIShuffleOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange sourceParts = 
adaptor.getSource(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + std::string reason; + FailureOr> sourceFlatIndices = + computeShuffleForwardingSourceParts(op, &reason); + if (succeeded(sourceFlatIndices)) { + SmallVector results; + results.reserve(resultTypes.size()); + for (int64_t sourceFlatIndex : *sourceFlatIndices) { + if (sourceFlatIndex >= static_cast(sourceParts.size())) + return rewriter.notifyMatchFailure( + op, "shuffle forwarding source part range is out of bounds"); + results.push_back(sourceParts[sourceFlatIndex]); + } + + if (failed( + verifyIdentityPartForwarding(op, results, resultTypes, rewriter))) + return failure(); + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } + + std::string splatReason; + FailureOr splatSource = + computeShuffleLane0SplatSourcePart(op, &splatReason); + if (succeeded(splatSource)) { + if (*splatSource >= static_cast(sourceParts.size())) + return rewriter.notifyMatchFailure( + op, "shuffle lane0 splat source part range is out of bounds"); + + SmallVector results; + results.reserve(resultTypes.size()); + Value sourcePart = sourceParts[*splatSource]; + for (Type resultType : resultTypes) { + auto sourceVRegType = dyn_cast(sourcePart.getType()); + auto resultVRegType = dyn_cast(resultType); + if (!sourceVRegType || !resultVRegType || + sourceVRegType != resultVRegType) + return rewriter.notifyMatchFailure( + op, "shuffle lane0 splat requires matching physical vreg type"); + FailureOr mask = + createAllTrueMaskForVReg(op.getLoc(), resultVRegType, rewriter); + if (failed(mask)) + return rewriter.notifyMatchFailure( + op, "failed to create shuffle lane0 splat mask"); + results.push_back(rewriter + .create(op.getLoc(), resultType, + sourcePart, *mask, + rewriter.getStringAttr("LOWEST")) + .getResult()); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } + + std::string vselrReason; + FailureOr> vselrPlans = + 
computeShuffleVselrPlans(op, &vselrReason); + if (failed(vselrPlans)) + return rewriter.notifyMatchFailure(op, + Twine("shuffle vselr ") + vselrReason); + + if (vselrPlans->size() != resultTypes.size()) + return rewriter.notifyMatchFailure(op, "shuffle vselr arity mismatch"); + + SmallVector results; + results.reserve(resultTypes.size()); + for (auto [plan, resultType] : llvm::zip_equal(*vselrPlans, resultTypes)) { + if (plan.sourceFlatIndex >= static_cast(sourceParts.size())) + return rewriter.notifyMatchFailure( + op, "shuffle vselr source part range is out of bounds"); + + auto sourceVRegType = + dyn_cast(sourceParts[plan.sourceFlatIndex].getType()); + auto resultVRegType = dyn_cast(resultType); + if (!sourceVRegType || !resultVRegType || + sourceVRegType.getElementCount() != + resultVRegType.getElementCount() || + sourceVRegType.getElementType() != resultVRegType.getElementType()) + return rewriter.notifyMatchFailure( + op, "shuffle vselr source/result type mismatch"); + + unsigned indexBits = + pto::getPTOStorageElemBitWidth(sourceVRegType.getElementType()); + if (indexBits != 8 && indexBits != 16 && indexBits != 32) + return rewriter.notifyMatchFailure( + op, "shuffle vselr requires 8/16/32-bit index elements"); + + auto indexElementType = + IntegerType::get(rewriter.getContext(), indexBits); + Type indexType = + VRegType::get(rewriter.getContext(), sourceVRegType.getElementCount(), + indexElementType); + FailureOr base = createScalarOffsetConstant( + op.getLoc(), indexElementType, plan.baseLane, rewriter); + if (failed(base)) + return rewriter.notifyMatchFailure( + op, "failed to materialize shuffle vselr index base"); + StringAttr orderAttr = + plan.descending ? 
rewriter.getStringAttr("DESC") : StringAttr{}; + Value indexVector = + rewriter.create(op.getLoc(), indexType, *base, orderAttr) + .getResult(); + results.push_back(rewriter + .create(op.getLoc(), resultType, + sourceParts[plan.sourceFlatIndex], + indexVector) + .getResult()); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + +Block *convertBranchDestBlock(Block *block, OneToNPatternRewriter &rewriter, + OneToNTypeConverter &typeConverter, + llvm::DenseMap &converted) { + auto [it, inserted] = converted.try_emplace(block, nullptr); + if (!inserted) + return it->second; + + OneToNTypeMapping argMapping(block->getArgumentTypes()); + if (failed(typeConverter.computeTypeMapping(block->getArgumentTypes(), + argMapping)) || + !argMapping.hasNonIdentityConversion()) { + it->second = block; + return block; + } + + Block *newBlock = rewriter.applySignatureConversion(block, argMapping); + it->second = newBlock; + return newBlock; +} + +struct OneToNCFBranchOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(cf::BranchOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto *converter = getTypeConverter(); + llvm::DenseMap convertedBlocks; + Block *dest = convertBranchDestBlock(op.getDest(), rewriter, *converter, + convertedBlocks); + + if (!adaptor.getOperandMapping().hasNonIdentityConversion() && + dest == op.getDest()) + return failure(); + + rewriter.replaceOpWithNewOp(op, dest, + adaptor.getFlatOperands()); + return success(); + } +}; + +struct OneToNCFCondBranchOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(cf::CondBranchOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto *converter = getTypeConverter(); + llvm::DenseMap convertedBlocks; + Block *trueDest = 
convertBranchDestBlock(op.getTrueDest(), rewriter, + *converter, convertedBlocks); + Block *falseDest = convertBranchDestBlock(op.getFalseDest(), rewriter, + *converter, convertedBlocks); + + if (!adaptor.getOperandMapping().hasNonIdentityConversion() && + trueDest == op.getTrueDest() && falseDest == op.getFalseDest()) + return failure(); + + ValueRange condition = adaptor.getCondition(); + if (condition.size() != 1) + return rewriter.notifyMatchFailure( + op, "condition converted to multiple values"); + + SmallVector trueOperands; + SmallVector falseOperands; + ValueRange flatOperands = adaptor.getFlatOperands(); + const OneToNTypeMapping &operandMapping = adaptor.getOperandMapping(); + unsigned operandIndex = 1; + for (unsigned i = 0, e = op.getNumTrueOperands(); i < e; ++i) + llvm::append_range(trueOperands, operandMapping.getConvertedValues( + flatOperands, operandIndex++)); + for (unsigned i = 0, e = op.getNumFalseOperands(); i < e; ++i) + llvm::append_range(falseOperands, operandMapping.getConvertedValues( + flatOperands, operandIndex++)); + + rewriter.replaceOpWithNewOp(op, condition.front(), + trueDest, trueOperands, + falseDest, falseOperands); + return success(); + } +}; + +struct OneToNCFSwitchOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(cf::SwitchOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto *converter = getTypeConverter(); + llvm::DenseMap convertedBlocks; + Block *defaultDest = convertBranchDestBlock( + op.getDefaultDestination(), rewriter, *converter, convertedBlocks); + + SmallVector caseDests; + caseDests.reserve(op.getCaseDestinations().size()); + for (Block *dest : op.getCaseDestinations()) + caseDests.push_back( + convertBranchDestBlock(dest, rewriter, *converter, convertedBlocks)); + + bool changed = defaultDest != op.getDefaultDestination(); + for (auto [oldDest, newDest] : + llvm::zip(op.getCaseDestinations(), 
caseDests)) + changed |= oldDest != newDest; + changed |= adaptor.getOperandMapping().hasNonIdentityConversion(); + if (!changed) + return failure(); + + ValueRange flag = adaptor.getFlag(); + if (flag.size() != 1) + return rewriter.notifyMatchFailure(op, + "flag converted to multiple values"); + + SmallVector defaultOperands; + SmallVector> caseOperandStorage; + SmallVector caseOperands; + ValueRange flatOperands = adaptor.getFlatOperands(); + const OneToNTypeMapping &operandMapping = adaptor.getOperandMapping(); + + unsigned operandIndex = 1; + for (unsigned i = 0, e = op.getDefaultOperands().size(); i < e; ++i) + llvm::append_range(defaultOperands, operandMapping.getConvertedValues( + flatOperands, operandIndex++)); + + caseOperandStorage.reserve(op.getCaseOperandSegments().size()); + caseOperands.reserve(op.getCaseOperandSegments().size()); + for (int32_t segmentSize : op.getCaseOperandSegments()) { + SmallVector operands; + for (int32_t i = 0; i < segmentSize; ++i) + llvm::append_range(operands, operandMapping.getConvertedValues( + flatOperands, operandIndex++)); + caseOperandStorage.push_back(std::move(operands)); + } + for (SmallVector &operands : caseOperandStorage) + caseOperands.push_back(operands); + + rewriter.replaceOpWithNewOp( + op, flag.front(), defaultDest, defaultOperands, op.getCaseValuesAttr(), + caseDests, caseOperands); + return success(); + } +}; + +struct OneToNSCFExecuteRegionOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern< + scf::ExecuteRegionOp>::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(scf::ExecuteRegionOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + SmallVector resultTypes; + const OneToNTypeMapping &resultMapping = adaptor.getResultMapping(); + for (unsigned i = 0, e = op->getNumResults(); i < e; ++i) + llvm::append_range(resultTypes, resultMapping.getConvertedTypes(i)); + if (resultTypes == op->getResultTypes()) + return failure(); + + auto newOp = + 
rewriter.create(op.getLoc(), resultTypes); + newOp->setAttrs(op->getAttrs()); + rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(), + newOp.getRegion().end()); + rewriter.replaceOp(op, newOp->getResults(), resultMapping); + return success(); + } +}; + +struct OneToNSCFIndexSwitchOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern< + scf::IndexSwitchOp>::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(scf::IndexSwitchOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange arg = adaptor.getArg(); + if (arg.size() != 1) + return rewriter.notifyMatchFailure( + op, "index_switch selector converted to multiple values"); + + SmallVector resultTypes; + const OneToNTypeMapping &resultMapping = adaptor.getResultMapping(); + for (unsigned i = 0, e = op->getNumResults(); i < e; ++i) + llvm::append_range(resultTypes, resultMapping.getConvertedTypes(i)); + if (resultTypes == op->getResultTypes()) + return failure(); + + auto newOp = rewriter.create( + op.getLoc(), resultTypes, arg.front(), op.getCases(), op.getNumCases()); + newOp->setAttrs(op->getAttrs()); + rewriter.inlineRegionBefore(op.getDefaultRegion(), newOp.getDefaultRegion(), + newOp.getDefaultRegion().end()); + for (auto [srcRegion, dstRegion] : + llvm::zip(op.getCaseRegions(), newOp.getCaseRegions())) + rewriter.inlineRegionBefore(srcRegion, dstRegion, dstRegion.end()); + rewriter.replaceOp(op, newOp->getResults(), resultMapping); + return success(); + } +}; + +void populateVMIOneToNConversionPatterns( + VMIToVPTOTypeConverter &typeConverter, RewritePatternSet &patterns, + const VMITargetCapabilityRegistry &capabilities) { + populateFuncTypeConversionPatterns(typeConverter, patterns); + scf::populateSCFStructuralOneToNTypeConversions(typeConverter, patterns); + patterns.add(typeConverter, patterns.getContext()); + patterns.add( + typeConverter, patterns.getContext()); + patterns.add( + typeConverter, patterns.getContext()); + 
patterns.add< + OneToNVMIEnsureLayoutOpPattern, OneToNVMIEnsureMaskLayoutOpPattern, + OneToNVMIBroadcastOpPattern, OneToNVMIIotaOpPattern, + OneToNVMIConstantOpPattern, OneToNVMIConstantMaskOpPattern, + OneToNVMICreateMaskOpPattern, OneToNVMICreateGroupMaskOpPattern, + OneToNVMIMaskBinaryOpPattern, + OneToNVMIMaskBinaryOpPattern, + OneToNVMIMaskBinaryOpPattern, + OneToNVMIMaskUnaryOpPattern, OneToNVMILoadOpPattern, + OneToNVMIDeinterleaveLoadOpPattern, OneToNVMIGroupLoadOpPattern, + OneToNVMIGroupSlotLoadOpPattern, OneToNVMIStrideLoadOpPattern, + OneToNVMIMaskedLoadOpPattern, OneToNVMIGatherOpPattern, + OneToNVMIExpandLoadOpPattern, OneToNVMIStoreOpPattern, + OneToNVMIInterleaveStoreOpPattern, OneToNVMIGroupStoreOpPattern, + OneToNVMIMaskedStoreOpPattern, OneToNVMIStrideStoreOpPattern, + OneToNVMIScatterOpPattern, + OneToNVMIBinaryOpPattern, + OneToNVMIBinaryOpPattern, + OneToNVMIBinaryOpPattern, + OneToNVMIBinaryOpPattern, + OneToNVMIBinaryOpPattern, + OneToNVMIBinaryOpPattern, OneToNVMIFmaOpPattern, + OneToNVMIBinaryOpPattern, + OneToNVMIBinaryOpPattern, + OneToNVMIBinaryOpPattern, + OneToNVMIUnaryOpPattern, + OneToNVMIUnaryOpPattern, + OneToNVMIUnaryOpPattern, + OneToNVMIUnaryOpPattern, + OneToNVMIUnaryOpPattern, + OneToNVMIUnaryOpPattern, + OneToNVMIUnaryOpPattern, + OneToNVMIBinaryOpPattern, + OneToNVMIBinaryOpPattern, + OneToNVMIBinaryOpPattern, + OneToNVMIBinaryOpPattern, + OneToNVMIBinaryOpPattern, + OneToNVMIUnaryOpPattern, + OneToNVMICmpOpPattern, OneToNVMICmpOpPattern, + OneToNVMISelectOpPattern, OneToNVMIActivePrefixIndexOpPattern, + OneToNVMICompressOpPattern, OneToNVMICompressStoreOpPattern, + OneToNVMIReduceAddIOpPattern, OneToNVMIReduceAddFOpPattern, + OneToNVMIGroupBroadcastOpPattern, OneToNVMIDhistOpPattern, + OneToNVMIReduceMinMaxFOpPattern, + OneToNVMIReduceMinMaxFOpPattern, + OneToNVMIExtFOpPattern, OneToNVMITruncFOpPattern, + OneToNVMIExtIOpPattern, OneToNVMIExtIOpPattern, + OneToNVMITruncIOpPattern, OneToNVMIFPToSIOpPattern, + 
OneToNVMISIToFPOpPattern, OneToNVMIBitcastOpPattern, + OneToNVMIChannelSplitOpPattern, OneToNVMIChannelMergeOpPattern, + OneToNVMIShuffleOpPattern>(typeConverter, patterns.getContext()); + patterns.add( + typeConverter, patterns.getContext(), capabilities); + patterns.add< + OneToNVMIGroupReduceOpPattern, + OneToNVMIGroupReduceOpPattern, + OneToNVMIGroupReduceOpPattern, + OneToNVMIGroupReduceOpPattern>( + typeConverter, patterns.getContext(), capabilities); + patterns.add( + typeConverter, patterns.getContext(), capabilities); +} + +LogicalResult verifyNoResidualVMIIR(ModuleOp module) { + WalkResult result = module.walk([&](Operation *op) { + if (isa(op)) { + op->emitError() << kVMIDiagResidualOpPrefix + << "unrealized conversion cast remains after vmi-to-vpto"; + return WalkResult::interrupt(); + } + if (auto createMask = dyn_cast(op)) { + if (!createMask.getActiveLanes().getDefiningOp()) { + createMask.emitError() + << kVMIDiagUnsupportedPrefix + << "dynamic pto.vmi.create_mask active_lanes could not be lowered " + "by the current runtime predicate generation plan"; + return WalkResult::interrupt(); + } + } + if (auto constant = dyn_cast(op)) { + auto denseAttr = dyn_cast(constant.getValue()); + if (denseAttr && !denseAttr.isSplat()) { + constant.emitError() + << kVMIDiagUnsupportedPrefix + << "non-splat pto.vmi.constant requires a vreg immediate or " + "scratch materialization plan"; + return WalkResult::interrupt(); + } + } + if (isVMIOp(op) || hasVMIType(op)) { + op->emitError() << kVMIDiagResidualOpPrefix + << "failed to convert all VMI ops/types to VPTO"; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + return failure(result.wasInterrupted()); +} + +LogicalResult checkSupportedExtFShape(VMIExtFOp op, + std::string *reason = nullptr) { + VMILayoutSupport supports; + if (failed(supports.getExtFSupport(op, reason))) + return failure(); + return success(); +} + +LogicalResult checkSupportedTruncFShape(VMITruncFOp op, + std::string 
*reason = nullptr) { + VMILayoutSupport supports; + if (failed(supports.getTruncFSupport(op, reason))) + return failure(); + return success(); +} + +LogicalResult checkSupportedExtSIShape(VMIExtSIOp op, + std::string *reason = nullptr) { + VMILayoutSupport supports; + if (failed(supports.getExtSISupport(op, reason))) + return failure(); + return success(); +} + +LogicalResult checkSupportedExtUIShape(VMIExtUIOp op, + std::string *reason = nullptr) { + VMILayoutSupport supports; + if (failed(supports.getExtUISupport(op, reason))) + return failure(); + return success(); +} + +LogicalResult checkSupportedTruncIShape(VMITruncIOp op, + std::string *reason = nullptr) { + VMILayoutSupport supports; + if (failed(supports.getTruncISupport(op, reason))) + return failure(); + return success(); +} + +LogicalResult checkSupportedFPToSIShape(VMIFPToSIOp op, + std::string *reason = nullptr) { + auto fail = [&](const Twine &message) { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto sourceType = cast(op.getSource().getType()); + auto resultType = cast(op.getResult().getType()); + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + if (!sourceLayout || !resultLayout) + return fail("requires assigned source/result layouts"); + if (sourceLayout != resultLayout) + return fail("requires source/result layouts to match"); + if (!sourceType.getElementType().isF32()) + return fail("requires f32 source element type"); + if (!isa(resultType.getElementType()) || + pto::getPTOStorageElemBitWidth(resultType.getElementType()) != 32) + return fail("requires 32-bit integer result element type"); + FailureOr sourceArity = getVMIPhysicalArity(sourceType); + FailureOr resultArity = getVMIPhysicalArity(resultType); + if (failed(sourceArity) || failed(resultArity) || + *sourceArity != *resultArity) + return fail("requires matching computable physical arity"); + return success(); +} + +LogicalResult 
checkSupportedSIToFPShape(VMISIToFPOp op, + std::string *reason = nullptr) { + auto fail = [&](const Twine &message) { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto sourceType = cast(op.getSource().getType()); + auto resultType = cast(op.getResult().getType()); + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + if (!sourceLayout || !resultLayout) + return fail("requires assigned source/result layouts"); + if (sourceLayout != resultLayout) + return fail("requires source/result layouts to match"); + if (!isa(sourceType.getElementType()) || + pto::getPTOStorageElemBitWidth(sourceType.getElementType()) != 32) + return fail("requires 32-bit integer source element type"); + if (!resultType.getElementType().isF32()) + return fail("requires f32 result element type"); + FailureOr sourceArity = getVMIPhysicalArity(sourceType); + FailureOr resultArity = getVMIPhysicalArity(resultType); + if (failed(sourceArity) || failed(resultArity) || + *sourceArity != *resultArity) + return fail("requires matching computable physical arity"); + return success(); +} + +LogicalResult checkSupportedBitcastShape(VMIBitcastOp op, std::string *reason) { + VMILayoutSupport supports; + if (failed(supports.getBitcastSupport(op, reason))) + return failure(); + return success(); +} + +LogicalResult +checkSupportedChannelSplitShape(const VMITargetCapabilityRegistry &capabilities, + VMIChannelSplitOp op, + std::string *reason = nullptr) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + int64_t channels = op.getNumResults(); + VMICapabilityResult channelCapability = + capabilities.supportsChannelCount("pto.vmi.channel_split", channels); + if (!channelCapability.isSupported()) + return fail(channelCapability.reason); + + auto sourceType = cast(op.getSource().getType()); + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); 
+ if (!sourceLayout) + return fail("requires assigned source layout"); + auto expectedLayout = + VMILayoutAttr::getDeinterleaved(op.getContext(), channels); + if (!sourceLayout.isContiguous() && sourceLayout != expectedLayout) + return fail("requires source layout to be contiguous or matching " + "deinterleaved channel layout"); + + for (Value result : op.getResults()) { + VMILayoutAttr resultLayout = + cast(result.getType()).getLayoutAttr(); + if (!resultLayout || !resultLayout.isContiguous()) + return fail("requires every result layout to be contiguous"); + } + + auto channelType = + VMIVRegType::get(op.getContext(), sourceType.getElementCount(), + sourceType.getElementType(), expectedLayout); + std::string materializationReason; + if (failed(checkSupportedLayoutMaterialization( + capabilities, sourceType, channelType, sourceLayout, expectedLayout, + &materializationReason))) + return fail(Twine("cannot materialize source to channel layout; ") + + materializationReason); + + FailureOr channelArity = getVMIPhysicalArity(channelType); + int64_t resultArity = 0; + for (Value result : op.getResults()) { + FailureOr arity = + getVMIPhysicalArity(cast(result.getType())); + if (failed(arity)) + return fail("requires computable result physical arity"); + resultArity += *arity; + } + if (failed(channelArity) || *channelArity != resultArity) + return fail("requires channel physical arity to match all result parts"); + + return success(); +} + +LogicalResult +checkSupportedChannelMergeShape(const VMITargetCapabilityRegistry &capabilities, + VMIChannelMergeOp op, + std::string *reason = nullptr) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + int64_t channels = op.getInputs().size(); + VMICapabilityResult channelCapability = + capabilities.supportsChannelCount("pto.vmi.channel_merge", channels); + if (!channelCapability.isSupported()) + return fail(channelCapability.reason); + + int64_t 
inputArity = 0; + for (Value input : op.getInputs()) { + auto inputType = cast(input.getType()); + VMILayoutAttr inputLayout = inputType.getLayoutAttr(); + if (!inputLayout || !inputLayout.isContiguous()) + return fail("requires every input layout to be contiguous"); + FailureOr arity = getVMIPhysicalArity(inputType); + if (failed(arity)) + return fail("requires computable input physical arity"); + inputArity += *arity; + } + + auto resultType = cast(op.getResult().getType()); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + if (!resultLayout) + return fail("requires assigned result layout"); + auto expectedLayout = + VMILayoutAttr::getDeinterleaved(op.getContext(), channels); + if (!resultLayout.isContiguous() && resultLayout != expectedLayout) + return fail("requires result layout to be contiguous or matching " + "deinterleaved channel layout"); + + auto channelType = + VMIVRegType::get(op.getContext(), resultType.getElementCount(), + resultType.getElementType(), expectedLayout); + FailureOr channelArity = getVMIPhysicalArity(channelType); + if (failed(channelArity) || *channelArity != inputArity) + return fail("requires channel physical arity to match all input parts"); + + std::string materializationReason; + if (failed(checkSupportedLayoutMaterialization( + capabilities, channelType, resultType, expectedLayout, resultLayout, + &materializationReason))) + return fail(Twine("cannot materialize channel layout to result; ") + + materializationReason); + + return success(); +} + +LogicalResult +checkSupportedActivePrefixIndexShape(VMIActivePrefixIndexOp op, + std::string *reason = nullptr) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto maskType = cast(op.getMask().getType()); + auto resultType = cast(op.getResult().getType()); + VMILayoutAttr maskLayout = maskType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + if (!maskLayout || 
!resultLayout) + return fail("requires assigned mask and result layouts"); + if (!maskLayout.isContiguous() || !resultLayout.isContiguous()) + return fail("requires contiguous mask and result layouts"); + + std::string resultFullReason; + if (failed(checkFullDataPhysicalChunks(resultType, &resultFullReason))) + return fail(Twine("requires full result physical chunks so padding mask " + "lanes cannot affect the observable prefix; ") + + resultFullReason); + + std::string maskFullReason; + if (failed(checkFullVMIPhysicalChunks(maskType, &maskFullReason))) + return fail(Twine("requires full mask physical chunks so padding mask " + "lanes cannot affect the observable prefix; ") + + maskFullReason); + + FailureOr maskArity = getVMIPhysicalArity(maskType); + FailureOr resultArity = getVMIPhysicalArity(resultType); + if (failed(maskArity) || failed(resultArity)) + return fail("requires computable mask and result physical arity"); + if (*maskArity != 1 || *resultArity != 1) + return fail("requires a single physical chunk; multi-chunk prefix needs " + "cross-chunk carry"); + + return success(); +} + +LogicalResult checkSupportedCompressShape(VMICompressOp op, + std::string *reason = nullptr) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto sourceType = cast(op.getSource().getType()); + auto maskType = cast(op.getMask().getType()); + auto resultType = cast(op.getResult().getType()); + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr maskLayout = maskType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + if (!sourceLayout || !maskLayout || !resultLayout) + return fail("requires assigned source, mask, and result layouts"); + if (!sourceLayout.isContiguous() || !maskLayout.isContiguous() || + !resultLayout.isContiguous()) + return fail("requires contiguous source, mask, and result layouts"); + + std::string fullChunkReason; + if 
(failed(checkFullDataPhysicalChunks(sourceType, &fullChunkReason))) + return fail(Twine("requires full source physical chunks so padding mask " + "lanes cannot be squeezed into the result; ") + + fullChunkReason); + + FailureOr sourceArity = getVMIPhysicalArity(sourceType); + FailureOr maskArity = getVMIPhysicalArity(maskType); + FailureOr resultArity = getVMIPhysicalArity(resultType); + if (failed(sourceArity) || failed(maskArity) || failed(resultArity)) + return fail("requires computable source, mask, and result physical arity"); + if (*sourceArity != 1 || *maskArity != 1 || *resultArity != 1) + return fail("requires a single physical chunk; multi-chunk compress needs " + "cross-chunk compaction"); + + return success(); +} + +LogicalResult checkSupportedCompressStoreShape( + const VMITargetCapabilityRegistry &capabilities, VMICompressStoreOp op, + std::string *reason = nullptr) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto valueType = cast(op.getValue().getType()); + auto maskType = cast(op.getMask().getType()); + VMILayoutAttr valueLayout = valueType.getLayoutAttr(); + VMILayoutAttr maskLayout = maskType.getLayoutAttr(); + if (!valueLayout || !maskLayout) + return fail("requires assigned value and mask layouts"); + if (!valueLayout.isContiguous() || !maskLayout.isContiguous()) + return fail("requires contiguous value and mask layouts"); + + VMICapabilityResult destinationCapability = + capabilities.supportsUBPointerMemory(op.getDestination().getType(), + "destination", "pto.vstur", + "pto.vstur stores only to UB"); + if (!destinationCapability.isSupported()) + return fail(destinationCapability.reason); + + std::string fullChunkReason; + if (failed(checkFullDataPhysicalChunks(valueType, &fullChunkReason))) + return fail(Twine("requires full physical chunks so padding mask lanes " + "cannot be squeezed into memory; ") + + fullChunkReason); + + FailureOr valueArity = 
getVMIPhysicalArity(valueType); + FailureOr maskArity = getVMIPhysicalArity(maskType); + if (failed(valueArity) || failed(maskArity)) + return fail("requires computable value and mask physical arity"); + if (*valueArity != 1 || *maskArity != 1) + return fail("requires a single physical chunk; multi-chunk " + "compress_store needs cross-chunk compaction and SQZN " + "state planning"); + + return success(); +} + +template +LogicalResult +checkSupportedReduceShape(const VMITargetCapabilityRegistry &capabilities, + OpTy op, VMIReductionKind kind, bool requiresReassoc, + std::string *reason = nullptr) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + if (requiresReassoc && !op->hasAttr("reassoc")) + return fail("requires reassoc attr for pair-wise floating-point vcadd"); + + auto sourceType = cast(op.getSource().getType()); + auto initType = cast(op.getInit().getType()); + auto maskType = cast(op.getMask().getType()); + auto resultType = cast(op.getResult().getType()); + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr initLayout = initType.getLayoutAttr(); + VMILayoutAttr maskLayout = maskType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + if (!sourceLayout || !initLayout || !maskLayout || !resultLayout) + return fail("requires assigned source, init, mask, and result layouts"); + if (!sourceLayout.isContiguous() || !initLayout.isContiguous() || + !maskLayout.isContiguous() || !resultLayout.isContiguous()) + return fail("requires contiguous source, init, mask, and result layouts"); + + VMICapabilityResult elementCapability = + capabilities.supportsReductionElementType(kind, + sourceType.getElementType()); + if (!elementCapability.isSupported()) + return fail(elementCapability.reason); + + std::string fullChunkReason; + if (failed(checkFullDataPhysicalChunks(sourceType, &fullChunkReason))) + return fail(Twine("requires full source 
physical chunks so padding lanes " + "do not participate in the reduction; ") + + fullChunkReason); + + FailureOr sourceArity = getVMIPhysicalArity(sourceType); + FailureOr initArity = getVMIPhysicalArity(initType); + FailureOr maskArity = getVMIPhysicalArity(maskType); + FailureOr resultArity = getVMIPhysicalArity(resultType); + if (failed(sourceArity) || failed(initArity) || failed(maskArity) || + failed(resultArity)) + return fail("requires computable physical arity"); + if (*sourceArity < 1 || *maskArity != *sourceArity) + return fail("requires source and mask physical arity to match and be " + "non-empty"); + if (*initArity != 1 || *resultArity != 1) + return fail("requires one init and result physical chunk"); + + return success(); +} + +template +LogicalResult +checkSupportedGroupReduceShape(const VMITargetCapabilityRegistry &capabilities, + OpTy op, std::string *reason = nullptr) { + VMILayoutSupport supports; + if constexpr (std::is_same_v) { + if (succeeded(supports.getGroupReduceAddFSupport(capabilities, op, reason))) + return success(); + } else if constexpr (std::is_same_v) { + if (succeeded(supports.getGroupReduceMaxFSupport(capabilities, op, reason))) + return success(); + } else if constexpr (std::is_same_v) { + if (succeeded(supports.getGroupReduceMaxISupport(capabilities, op, reason))) + return success(); + } else { + if (succeeded(supports.getGroupReduceAddISupport(capabilities, op, reason))) + return success(); + } + return failure(); +} + +LogicalResult checkSupportedGroupBroadcastShape( + const VMITargetCapabilityRegistry &capabilities, VMIGroupBroadcastOp op, + std::string *reason = nullptr) { + (void)capabilities; + auto sourceType = cast(op.getSource().getType()); + auto resultType = cast(op.getResult().getType()); + if (sourceType.getElementType() != resultType.getElementType()) { + if (reason) + *reason = "requires source/result element type to match"; + return failure(); + } + auto fail = [&](const Twine &message) -> LogicalResult { + if 
(reason) + *reason = message.str(); + return failure(); + }; + + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + if (!sourceLayout || !resultLayout) + return fail("requires assigned source/result layouts"); + VMILayoutSupport supports; + if (succeeded(supports.getGroupBroadcastSupport(capabilities, op, nullptr))) + return success(); + if (sourceType.getElementCount() != op.getNumGroupsAttr().getInt()) + return fail("requires source lane count to match num_groups"); + if (resultType.getElementCount() % op.getNumGroupsAttr().getInt() != 0) + return fail("requires num_groups to evenly divide result lane count"); + if (!sourceLayout.isGroupSlots() || + sourceLayout.getNumGroups() != op.getNumGroupsAttr().getInt()) + return fail("requires matching num_groups source layout"); + if (resultLayout.isGroupSlots()) + return fail("requires dense result layout"); + + if (sourceLayout.getSlots() > 0 && sourceLayout.getSlots() != 8 && + sourceLayout.getSlots() != 1) + return fail("supports only slots=8 or slots=1 group_broadcast source " + "layouts"); + + std::string fullChunkReason; + if (failed(checkFullDataPhysicalChunks(resultType, &fullChunkReason))) + return fail(Twine("requires full result physical chunks; ") + + fullChunkReason); + + FailureOr lanesPerPart = + getDataLanesPerPart(sourceType.getElementType()); + FailureOr resultLanesPerPart = + getDataLanesPerPart(resultType.getElementType()); + if (failed(lanesPerPart) || failed(resultLanesPerPart) || + *lanesPerPart != *resultLanesPerPart) + return fail("requires matching physical lanes per part"); + FailureOr groupSize = getGroupSizeFromNumGroups( + resultType, op.getNumGroupsAttr().getInt(), reason); + if (failed(groupSize)) + return failure(); + if (*lanesPerPart % *groupSize != 0 && *groupSize % *lanesPerPart != 0) + return fail("requires derived group size to divide or be a multiple of " + "physical lanes per part"); + + FailureOr resultFactor = 
getDataLayoutFactor(resultType); + if (failed(resultFactor)) + return fail("requires known result layout factor"); + if (*resultFactor == 1) + return success(); + bool blockFragmentSmallGroup = + resultLayout.isDeinterleaved() && resultLayout.getBlockElems() > 1 && + *groupSize < *lanesPerPart && + *lanesPerPart % resultLayout.getBlockElems() == 0; + if (blockFragmentSmallGroup) + return success(); + int64_t logicalSpanPerResultChunk = *lanesPerPart * *resultFactor; + if (*groupSize < *lanesPerPart || *groupSize % logicalSpanPerResultChunk != 0) + return fail("deinterleaved result requires every physical result chunk to " + "stay within one logical group"); + return success(); +} + +LogicalResult checkSupportedDhistShape(VMIDhistOp op, + std::string *reason = nullptr) { + VMILayoutSupport supports; + if (succeeded(supports.getDhistSupport(op, reason))) + return success(); + return failure(); +} + +LogicalResult checkSupportedChistShape(VMIChistOp op, + std::string *reason = nullptr) { + VMILayoutSupport supports; + if (succeeded(supports.getChistSupport(op, reason))) + return success(); + return failure(); +} + +LogicalResult +checkSupportedFmaShape(const VMITargetCapabilityRegistry &capabilities, + VMIFmaOp op, std::string *reason = nullptr) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto lhsType = cast(op.getLhs().getType()); + VMICapabilityResult elementCapability = capabilities.supportsElementType( + lhsType.getElementType(), VMIElementPurpose::VMula); + if (!elementCapability.isSupported()) + return fail(elementCapability.reason); + + FailureOr arity = getVMIPhysicalArity(lhsType); + if (failed(arity) || *arity < 1) + return fail("requires computable non-empty physical arity"); + + return success(); +} + +LogicalResult +checkSupportedReluShape(const VMITargetCapabilityRegistry &capabilities, + VMIReluOp op, std::string *reason = nullptr) { + auto fail = [&](const Twine 
&message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto resultType = cast(op.getResult().getType()); + if (failed(checkSupportedMaskableVReg(capabilities, resultType, reason))) + return failure(); + + VMICapabilityResult elementCapability = capabilities.supportsElementType( + resultType.getElementType(), VMIElementPurpose::VRelu); + if (!elementCapability.isSupported()) + return fail(elementCapability.reason); + + return success(); +} + +void emitEnsureLayoutMaterializationError(VMIEnsureLayoutOp ensure, + VMIVRegType sourceType, + VMIVRegType resultType, + StringRef reason) { + if (ensure.getResult().hasOneUse()) { + OpOperand &use = *ensure.getResult().use_begin(); + Operation *requester = use.getOwner(); + InFlightDiagnostic diag = + requester->emitError() + << kVMIDiagUnsupportedPrefix << requester->getName() << " operand #" + << use.getOperandNumber() << " has type " << sourceType + << " but requires " << resultType + << "; pto.vmi.ensure_layout cannot materialize this conversion"; + diag.attachNote(ensure.getLoc()) + << "failed helper conversion " << sourceType << " -> " << resultType + << " (" << reason + << "); partial/tail layout materialization requires an explicit " + "packing plan"; + return; + } + + ensure.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.ensure_layout cannot materialize the requested data " + "layout conversion (" + << reason + << "); partial/tail layout materialization requires an explicit " + "packing plan"; +} + +LogicalResult +verifySupportedVMIToVPTOOps(ModuleOp module, + const VMITargetCapabilityRegistry &capabilities, + bool enableStableGatherMaskedLoad) { + auto emitMemoryUnsupported = + [&](Operation *op, StringRef opName, VMIVRegType type, Value source, + std::optional constantOffset) -> WalkResult { + std::string reason; + if (succeeded(checkSupportedLoadShape(capabilities, type, source, + source.getType(), constantOffset, + &reason))) + return WalkResult::advance(); + 
+ op->emitError() + << kVMIDiagUnsupportedPrefix << opName + << " direct lowering requires a supported memory source (" + << reason << ")"; + return WalkResult::interrupt(); + }; + + auto emitMaskableUnsupported = [&](Operation *op, StringRef opName, + VMIVRegType type) -> WalkResult { + std::string reason; + if (succeeded(checkSupportedMaskableVReg(capabilities, type, &reason))) + return WalkResult::advance(); + + op->emitError() + << kVMIDiagUnsupportedPrefix << opName + << " direct lowering requires physical vreg parts with b8/b16/b32 " + "predicate masks (" + << reason << ")"; + return WalkResult::interrupt(); + }; + + auto emitTargetElementUnsupported = + [&](Operation *op, StringRef opName, VMIVRegType type, + VMIElementPurpose purpose, StringRef elementContract) -> WalkResult { + std::string reason; + if (succeeded(checkSupportedTargetElementVReg(capabilities, type, purpose, + elementContract, &reason))) + return WalkResult::advance(); + + op->emitError() + << kVMIDiagUnsupportedPrefix << opName << " direct lowering requires " + << elementContract + << " and physical vreg parts with b8/b16/b32 predicate masks (" + << reason << ")"; + return WalkResult::interrupt(); + }; + + WalkResult result = module.walk([&](Operation *op) { + if (auto constant = dyn_cast(op)) { + auto denseAttr = dyn_cast(constant.getValue()); + if (!denseAttr || !denseAttr.isSplat()) { + constant.emitError() + << kVMIDiagUnsupportedPrefix + << "non-splat pto.vmi.constant requires a vreg immediate or " + "scratch materialization plan"; + return WalkResult::interrupt(); + } + return emitMaskableUnsupported( + op, "pto.vmi.constant", + cast(constant.getResult().getType())); + } + + if (auto broadcast = dyn_cast(op)) + return emitMaskableUnsupported( + op, "pto.vmi.broadcast", + cast(broadcast.getResult().getType())); + if (auto broadcast = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedGroupBroadcastShape(capabilities, broadcast, + &reason))) + return 
WalkResult::advance(); + broadcast.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.group_broadcast requires full source chunks with " + "#pto.vmi.layout, a dense full result " + "layout, " + "and num_groups deriving a group size that divides or is a " + "multiple of physical chunk lanes (" + << reason << ")"; + return WalkResult::interrupt(); + } + if (auto hist = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedDhistShape(hist, &reason))) + return WalkResult::advance(); + hist.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.dhist requires contiguous Nxui8 source, contiguous b8 " + "mask, and contiguous 256xui16 acc/result (" + << reason << ")"; + return WalkResult::interrupt(); + } + if (auto hist = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedChistShape(hist, &reason))) + return WalkResult::advance(); + hist.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.chist requires a verified CHISTv2 range semantics " + "contract before lowering (" + << reason << ")"; + return WalkResult::interrupt(); + } + + if (auto load = dyn_cast(op)) { + return emitMemoryUnsupported( + op, "pto.vmi.load", cast(load.getResult().getType()), + load.getSource(), getConstantIndexValue(load.getOffset())); + } + if (auto load = dyn_cast(op)) { + std::string reason; + if (succeeded( + checkSupportedDeinterleaveLoadShape(capabilities, load, &reason))) + return WalkResult::advance(); + load.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.deinterleave_load lowers through pto.vldsx2 only for " + "matching contiguous full low/high result chunks with a supported " + "UB source and 8/16/32-bit element type (" + << reason << ")"; + return WalkResult::interrupt(); + } + if (auto load = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedStrideLoadShape(capabilities, load, &reason))) + return WalkResult::advance(); + load.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.stride_load lowers through 
pto.vsldb only for one " + "contiguous physical result/mask chunk and a supported UB source (" + << reason << ")"; + return WalkResult::interrupt(); + } + if (auto load = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedGroupLoadShape(capabilities, load, &reason))) + return WalkResult::advance(); + load.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.group_load requires contiguous full result chunks, a " + "supported UB source, and num_groups deriving a group size " + "aligned to physical chunks (" + << reason << ")"; + return WalkResult::interrupt(); + } + if (auto load = dyn_cast(op)) { + std::string reason; + if (succeeded( + checkSupportedGroupSlotLoadShape(capabilities, load, &reason))) + return WalkResult::advance(); + load.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.group_slot_load requires explicit group_slots result " + "layout matching num_groups, a supported UB pointer source, " + "and either slots=8 with constant unit source_group_stride or " + "slots=1 row-local lowering (" + << reason << ")"; + return WalkResult::interrupt(); + } + if (auto load = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedGroupBroadcastLoadShape(capabilities, load, + &reason))) + return WalkResult::advance(); + load.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.group_broadcast_load requires either the E2B packet " + "form for b16/b32 direct or split group size, or the generic " + "group-slot-load then group-broadcast fallback with supported UB " + "pointer source and source_group_stride (" + << reason << ")"; + return WalkResult::interrupt(); + } + if (auto load = dyn_cast(op)) { + if (enableStableGatherMaskedLoad) { + load.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.masked_load stable VGATHER-based lowering is reserved " + "for strict masked/tail loads but is not implemented yet"; + return WalkResult::interrupt(); + } + std::string reason; + if 
(succeeded(checkSupportedMaskedLoadShape(capabilities, load, &reason))) + return WalkResult::advance(); + load.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.masked_load direct lowering requires a supported memory " + "source, contiguous result/passthru/mask layouts, and either " + "full physical chunks or a statically safe full-read footprint (" + << reason << ")"; + return WalkResult::interrupt(); + } + if (auto gather = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedGatherShape(capabilities, gather, &reason))) + return WalkResult::advance(); + gather.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.gather lowers through pto.vgather2_bc + pto.vsel only " + "for UB pointer sources, contiguous full physical chunks, " + "32-bit result elements, i32 indices, and b32 masks (" + << reason << ")"; + return WalkResult::interrupt(); + } + if (auto load = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedExpandLoadShape(capabilities, load, &reason))) + return WalkResult::advance(); + load.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.expand_load direct lowering is currently supported for " + "either a static all-active mask lowered as pto.vlds, or a " + "one-full-chunk 32-bit UB runtime mask lowered through pto.vusqz " + "+ pto.vgather2_bc + pto.vsel (" + << reason << ")"; + return WalkResult::interrupt(); + } + if (auto store = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedStoreShape( + capabilities, cast(store.getValue().getType()), + store.getDestination(), store.getDestination().getType(), + &reason))) + return WalkResult::advance(); + store.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.store requires an 8/16/32-bit predicate-maskable " + "element type and either full physical chunks or contiguous " + "tail-store layout, with UB-backed destination (" + << reason << ")"; + return WalkResult::interrupt(); + } + if (auto store = dyn_cast(op)) { + std::string reason; + if 
(succeeded( + checkSupportedInterleaveStoreShape(capabilities, store, &reason))) + return WalkResult::advance(); + store.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.interleave_store lowers through pto.vstsx2 only for " + "matching contiguous full low/high input chunks with a supported " + "UB destination and 8/16/32-bit element type (" + << reason << ")"; + return WalkResult::interrupt(); + } + if (auto store = dyn_cast(op)) { + std::string reason; + if (succeeded( + checkSupportedGroupStoreShape(capabilities, store, &reason))) + return WalkResult::advance(); + store.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.group_store requires contiguous full value chunks, a " + "supported UB destination, and num_groups deriving a group size " + "aligned to physical chunks (" + << reason << ")"; + return WalkResult::interrupt(); + } + if (auto store = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedMaskedStoreShape( + capabilities, cast(store.getValue().getType()), + cast(store.getMask().getType()), + store.getDestination(), store.getDestination().getType(), + &reason))) + return WalkResult::advance(); + store.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.masked_store requires either full physical chunks or " + "contiguous tail-store value/mask layout, with UB-backed " + "destination (" + << reason << ")"; + return WalkResult::interrupt(); + } + if (auto store = dyn_cast(op)) { + std::string reason; + if (succeeded( + checkSupportedStrideStoreShape(capabilities, store, &reason))) + return WalkResult::advance(); + store.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.stride_store lowers through pto.vsstb only for one " + "contiguous physical value/mask chunk and a supported UB " + "destination (" + << reason << ")"; + return WalkResult::interrupt(); + } + if (auto scatter = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedScatterShape(capabilities, scatter, &reason))) + return 
WalkResult::advance(); + scatter.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.scatter lowers through pto.vscatter only with a UB " + "pointer destination, contiguous full physical chunks, 32-bit " + "value elements, i32 indices, and b32 masks (" + << reason << ")"; + return WalkResult::interrupt(); + } + if (auto ensure = dyn_cast(op)) { + auto sourceType = cast(ensure.getSource().getType()); + auto resultType = cast(ensure.getResult().getType()); + std::string reason; + VMILayoutSupport supports; + if (succeeded( + supports.canMaterializeDataLayout(sourceType, resultType, + &reason))) + return WalkResult::advance(); + + emitEnsureLayoutMaterializationError(ensure, sourceType, resultType, + reason); + return WalkResult::interrupt(); + } + + if (auto ensure = dyn_cast(op)) { + auto sourceType = cast(ensure.getSource().getType()); + auto resultType = cast(ensure.getResult().getType()); + std::string reason; + if (succeeded(checkSupportedLayoutMaterialization( + capabilities, sourceType, resultType, sourceType.getLayoutAttr(), + resultType.getLayoutAttr(), &reason))) + return WalkResult::advance(); + + ensure.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.ensure_mask_layout cannot materialize the requested " + "mask layout conversion (" + << reason + << "); partial/tail predicate layout materialization requires an " + "explicit packing plan"; + return WalkResult::interrupt(); + } + + if (auto ensure = dyn_cast(op)) { + auto sourceType = cast(ensure.getSource().getType()); + auto resultType = cast(ensure.getResult().getType()); + if (sourceType.getGranularity() == resultType.getGranularity()) + return WalkResult::advance(); + + std::string reason; + if (succeeded(checkSupportedMaskGranularityMaterialization( + capabilities, sourceType, resultType, &reason))) + return WalkResult::advance(); + + ensure.emitError() + << kVMIDiagUnsupportedPrefix + << "non-identity mask granularity materialization requires concrete " + "b8/b16/b32 masks with matching 
lane count and layout (" + << reason << ")"; + return WalkResult::interrupt(); + } + + if (auto addf = dyn_cast(op)) + return emitTargetElementUnsupported( + op, "pto.vmi.addf", cast(addf.getResult().getType()), + VMIElementPurpose::F16BF16F32, "f16/bf16/f32 element type"); + if (auto addi = dyn_cast(op)) + return emitMaskableUnsupported( + op, "pto.vmi.addi", cast(addi.getResult().getType())); + if (auto subf = dyn_cast(op)) + return emitTargetElementUnsupported( + op, "pto.vmi.subf", cast(subf.getResult().getType()), + VMIElementPurpose::F16BF16F32, "f16/bf16/f32 element type"); + if (auto subi = dyn_cast(op)) + return emitMaskableUnsupported( + op, "pto.vmi.subi", cast(subi.getResult().getType())); + if (auto mulf = dyn_cast(op)) + return emitTargetElementUnsupported( + op, "pto.vmi.mulf", cast(mulf.getResult().getType()), + VMIElementPurpose::F16BF16F32, "f16/bf16/f32 element type"); + if (auto muli = dyn_cast(op)) + return emitMaskableUnsupported( + op, "pto.vmi.muli", cast(muli.getResult().getType())); + if (auto divf = dyn_cast(op)) + return emitTargetElementUnsupported( + op, "pto.vmi.divf", cast(divf.getResult().getType()), + VMIElementPurpose::F16F32, "f16/f32 element type"); + if (auto minf = dyn_cast(op)) + return emitTargetElementUnsupported( + op, "pto.vmi.minf", cast(minf.getResult().getType()), + VMIElementPurpose::F16BF16F32, "f16/bf16/f32 element type"); + if (auto maxf = dyn_cast(op)) + return emitTargetElementUnsupported( + op, "pto.vmi.maxf", cast(maxf.getResult().getType()), + VMIElementPurpose::F16BF16F32, "f16/bf16/f32 element type"); + if (auto negf = dyn_cast(op)) + return emitTargetElementUnsupported( + op, "pto.vmi.negf", cast(negf.getResult().getType()), + VMIElementPurpose::F16F32, "f16/f32 element type"); + if (auto absf = dyn_cast(op)) + return emitTargetElementUnsupported( + op, "pto.vmi.absf", cast(absf.getResult().getType()), + VMIElementPurpose::F16F32, "f16/f32 element type"); + if (auto absi = dyn_cast(op)) + return 
emitTargetElementUnsupported( + op, "pto.vmi.absi", cast(absi.getResult().getType()), + VMIElementPurpose::SignlessOrSignedI8I16I32, + "signless/signed i8/i16/i32 element type"); + if (auto sqrt = dyn_cast(op)) + return emitTargetElementUnsupported( + op, "pto.vmi.sqrt", cast(sqrt.getResult().getType()), + VMIElementPurpose::F16F32, "f16/f32 element type"); + if (auto exp = dyn_cast(op)) + return emitTargetElementUnsupported( + op, "pto.vmi.exp", cast(exp.getResult().getType()), + VMIElementPurpose::F16F32, "f16/f32 element type"); + if (auto ln = dyn_cast(op)) + return emitTargetElementUnsupported( + op, "pto.vmi.ln", cast(ln.getResult().getType()), + VMIElementPurpose::F16F32, "f16/f32 element type"); + if (auto relu = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedReluShape(capabilities, relu, &reason))) + return WalkResult::advance(); + relu.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.relu direct lowering requires physical vreg parts with " + "b8/b16/b32 predicate masks and f16/f32 element type (" + << reason << ")"; + return WalkResult::interrupt(); + } + if (auto andi = dyn_cast(op)) + return emitMaskableUnsupported( + op, "pto.vmi.andi", cast(andi.getResult().getType())); + if (auto ori = dyn_cast(op)) + return emitMaskableUnsupported( + op, "pto.vmi.ori", cast(ori.getResult().getType())); + if (auto xori = dyn_cast(op)) + return emitMaskableUnsupported( + op, "pto.vmi.xori", cast(xori.getResult().getType())); + if (auto shli = dyn_cast(op)) + return emitMaskableUnsupported( + op, "pto.vmi.shli", cast(shli.getResult().getType())); + if (auto shrui = dyn_cast(op)) + return emitMaskableUnsupported( + op, "pto.vmi.shrui", cast(shrui.getResult().getType())); + if (auto notOp = dyn_cast(op)) + return emitMaskableUnsupported( + op, "pto.vmi.not", cast(notOp.getResult().getType())); + if (auto select = dyn_cast(op)) + return emitMaskableUnsupported( + op, "pto.vmi.select", + cast(select.getResult().getType())); + + if (auto cmpf = 
dyn_cast(op)) { + WalkResult target = emitTargetElementUnsupported( + op, "pto.vmi.cmpf", cast(cmpf.getLhs().getType()), + VMIElementPurpose::F16BF16F32, "f16/bf16/f32 element type"); + if (target.wasInterrupted()) + return target; + if (succeeded(checkSupportedComparePredicate(op, cmpf.getPredicate()))) + return WalkResult::advance(); + return WalkResult::interrupt(); + } + + if (auto cmpi = dyn_cast(op)) { + WalkResult target = emitTargetElementUnsupported( + op, "pto.vmi.cmpi", cast(cmpi.getLhs().getType()), + VMIElementPurpose::AnyI8I16I32, + "signless/signed/unsigned i8/i16/i32 element type"); + if (target.wasInterrupted()) + return target; + if (succeeded(checkSupportedComparePredicate(op, cmpi.getPredicate()))) + return WalkResult::advance(); + return WalkResult::interrupt(); + } + + if (auto activePrefix = dyn_cast(op)) { + std::string reason; + if (succeeded( + checkSupportedActivePrefixIndexShape(activePrefix, &reason))) + return WalkResult::advance(); + activePrefix.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.active_prefix_index lowers through pto.vusqz only for " + "one contiguous physical chunk (" + << reason << ")"; + return WalkResult::interrupt(); + } + + if (auto compress = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedCompressShape(compress, &reason))) + return WalkResult::advance(); + compress.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.compress lowers through pto.vsqz only for one " + "contiguous full physical chunk (" + << reason << ")"; + return WalkResult::interrupt(); + } + + if (auto compressStore = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedCompressStoreShape(capabilities, + compressStore, &reason))) + return WalkResult::advance(); + compressStore.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.compress_store lowers through pto.vsqz + pto.vstur " + "only for one contiguous full physical chunk with a UB pointer " + "destination (" + << reason << ")"; + return 
WalkResult::interrupt(); + } + + if (auto reduce = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedReduceShape( + capabilities, reduce, VMIReductionKind::AddI, + /*requiresReassoc=*/false, &reason))) + return WalkResult::advance(); + reduce.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.reduce_addi lowers through pto.vcadd only for " + "contiguous full 32-bit integer source chunks with matching " + "mask chunks and one init/result chunk (" + << reason << ")"; + return WalkResult::interrupt(); + } + + if (auto reduce = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedReduceShape( + capabilities, reduce, VMIReductionKind::AddF, + /*requiresReassoc=*/true, &reason))) + return WalkResult::advance(); + reduce.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.reduce_addf lowers through pto.vcadd only with " + "reassoc, f32 contiguous full source chunks, matching mask " + "chunks, and one init/result chunk (" + << reason << ")"; + return WalkResult::interrupt(); + } + + if (auto reduce = dyn_cast(op)) { + std::string reason; + if (succeeded( + checkSupportedGroupReduceShape(capabilities, reduce, &reason))) + return WalkResult::advance(); + reduce.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.group_reduce_addf lowers through pto.vcgadd for 32B " + "VLane groups or through pto.vcadd with reassoc, contiguous full " + "source/mask chunks, #pto.vmi.layout " + "result " + "chunks, and num_groups deriving a group size aligned to " + "physical chunks (" + << reason << ")"; + return WalkResult::interrupt(); + } + + if (auto reduce = dyn_cast(op)) { + std::string reason; + if (succeeded( + checkSupportedGroupReduceShape(capabilities, reduce, &reason))) + return WalkResult::advance(); + reduce.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.group_reduce_addi lowers through pto.vcgadd/vadd only " + "for i32 accumulator values; i8/i16 storage must be cast to i32 " + "before grouped reduction because narrow 
integer reductions " + "widen their result (" + << reason << ")"; + return WalkResult::interrupt(); + } + + if (auto reduce = dyn_cast(op)) { + std::string reason; + if (succeeded( + checkSupportedGroupReduceShape(capabilities, reduce, &reason))) + return WalkResult::advance(); + reduce.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.group_reduce_maxi lowers through pto.vcgmax/vmax only " + "for i32 accumulator values; i8/i16 storage must be cast to i32 " + "before grouped reduction because narrow integer reductions " + "widen their result (" + << reason << ")"; + return WalkResult::interrupt(); + } + + if (auto reduce = dyn_cast(op)) { + std::string reason; + if (succeeded( + checkSupportedGroupReduceShape(capabilities, reduce, &reason))) + return WalkResult::advance(); + reduce.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.group_reduce_maxf lowers through pto.vcgmax/vmax only " + "for f16/f32 values, matching source/mask chunks, " + "#pto.vmi.layout result chunks, and " + "num_groups deriving a group size aligned to physical chunks (" + << reason << ")"; + return WalkResult::interrupt(); + } + + if (auto reduce = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedReduceShape( + capabilities, reduce, VMIReductionKind::MaxF, + /*requiresReassoc=*/false, &reason))) + return WalkResult::advance(); + reduce.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.reduce_maxf lowers through pto.vcmax only for f16/f32 " + "contiguous full source chunks with matching mask chunks and one " + "init/result chunk (" + << reason << ")"; + return WalkResult::interrupt(); + } + + if (auto reduce = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedReduceShape( + capabilities, reduce, VMIReductionKind::MinF, + /*requiresReassoc=*/false, &reason))) + return WalkResult::advance(); + reduce.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.reduce_minf lowers through pto.vcmin only for f16/f32 " + "contiguous full source 
chunks with matching mask chunks and one " + "init/result chunk (" + << reason << ")"; + return WalkResult::interrupt(); + } + + if (auto fma = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedFmaShape(capabilities, fma, &reason))) + return WalkResult::advance(); + fma.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.fma lowers through pto.vmula only for f16/bf16/f32 " + "element types (" + << reason << ")"; + return WalkResult::interrupt(); + } + + if (auto extf = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedExtFShape(extf, &reason))) + return WalkResult::advance(); + + extf.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.extf supports contiguous 16-bit float-like or fp8-like " + "physical source chunks to f32 deinterleaved=2/4 results; " + "partial/tail is allowed only when source padding maps to result " + "padding (" + << reason << ")"; + return WalkResult::interrupt(); + } + + if (auto truncf = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedTruncFShape(truncf, &reason))) + return WalkResult::advance(); + + truncf.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.truncf supports only f32 deinterleaved=2 source parts " + "to dense f16 results, f32 source layouts whose factor times the " + "result lane_stride matches the fp8-like narrowing factor, or f32 " + "group_slots(num_groups=G, slots=1) to f16 " + "group_slots(num_groups=G, slots=1) (" + << reason << ")"; + return WalkResult::interrupt(); + } + + if (auto fptosi = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedFPToSIShape(fptosi, &reason))) + return WalkResult::advance(); + + fptosi.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.fptosi supports f32 source chunks to matching 32-bit " + "integer result chunks with identical assigned layouts (" + << reason << ")"; + return WalkResult::interrupt(); + } + + if (auto sitofp = dyn_cast(op)) { + std::string reason; + if 
(succeeded(checkSupportedSIToFPShape(sitofp, &reason))) + return WalkResult::advance(); + + sitofp.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.sitofp supports 32-bit integer source chunks to " + "matching f32 result chunks with identical assigned layouts (" + << reason << ")"; + return WalkResult::interrupt(); + } + + if (auto extsi = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedExtSIShape(extsi, &reason))) + return WalkResult::advance(); + + extsi.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.extsi supports contiguous signed/signless 8-bit or " + "16-bit integer physical source chunks to 2x/4x wider integer " + "deinterleaved results, or matching " + "group_slots(num_groups=G, slots=8) 8/16-bit integer source to " + "32-bit integer result (" + << reason << ")"; + return WalkResult::interrupt(); + } + + if (auto extui = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedExtUIShape(extui, &reason))) + return WalkResult::advance(); + + extui.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.extui supports contiguous unsigned 8-bit or 16-bit " + "integer physical source chunks to 2x/4x wider unsigned integer " + "deinterleaved results, or matching " + "group_slots(num_groups=G, slots=8) 8/16-bit integer source to " + "32-bit integer result (" + << reason << ")"; + return WalkResult::interrupt(); + } + + if (auto trunci = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedTruncIShape(trunci, &reason))) + return WalkResult::advance(); + + trunci.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.trunci supports integer deinterleaved source layouts " + "whose factor is the 2x/4x narrowing multiple of the contiguous " + "or deinterleaved result layout factor, or 32-bit integer " + "group_slots(num_groups=G, slots=1 or 8) to 8/16-bit integer " + "group_slots(num_groups=G, slots=1 or 8) (" + << reason << ")"; + return WalkResult::interrupt(); + } + + if (auto bitcast = 
dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedBitcastShape(bitcast, &reason))) + return WalkResult::advance(); + + bitcast.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.bitcast requires matching source/result layouts with " + "identical physical arity and matching per-chunk logical bit " + "footprints (" + << reason << ")"; + return WalkResult::interrupt(); + } + + if (auto split = dyn_cast(op)) { + int64_t channels = split.getNumResults(); + std::string reason; + if (succeeded( + checkSupportedChannelSplitShape(capabilities, split, &reason))) + return WalkResult::advance(); + + if (channels != 2 && channels != 4) + split.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.channel_split supports only 2 or 4 channels"; + else + split.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.channel_split requires source layout to be contiguous " + "or matching deinterleaved channel layout, every result layout " + "to be contiguous, and complete physical channel groups (" + << reason << ")"; + return WalkResult::interrupt(); + } + + if (auto merge = dyn_cast(op)) { + int64_t channels = merge.getInputs().size(); + std::string reason; + if (succeeded( + checkSupportedChannelMergeShape(capabilities, merge, &reason))) + return WalkResult::advance(); + + if (channels != 2 && channels != 4) + merge.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.channel_merge supports only 2 or 4 channels"; + else + merge.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.channel_merge requires every input layout to be " + "contiguous and result layout to be contiguous or matching " + "deinterleaved channel layout, with complete physical channel " + "groups (" + << reason << ")"; + return WalkResult::interrupt(); + } + + if (auto shuffle = dyn_cast(op)) { + std::string reason; + if (succeeded(computeShuffleForwardingSourceParts(shuffle, &reason))) + return WalkResult::advance(); + std::string splatReason; + if 
(succeeded(computeShuffleLane0SplatSourcePart(shuffle, &splatReason))) + return WalkResult::advance(); + std::string vselrReason; + if (succeeded(computeShuffleVselrPlans(shuffle, &vselrReason))) + return WalkResult::advance(); + + shuffle.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.shuffle requires physical chunk forwarding or " + "lane0 splat or vci-materializable vselr indices (forwarding: " + << reason << "; lane0 splat: " << splatReason + << "; vselr: " << vselrReason << ")"; + return WalkResult::interrupt(); + } + + if (auto constantMask = dyn_cast(op)) { + std::string reason; + if (succeeded(computeConstantMaskMaterialization(constantMask, &reason))) + return WalkResult::advance(); + + constantMask.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.constant_mask requires a dense bool constant with " + "concrete layout and b8/b16/b32 granularity (" + << reason << ")"; + return WalkResult::interrupt(); + } + + return WalkResult::advance(); + }); + return failure(result.wasInterrupted()); +} + +struct VMIToVPTOPass : public mlir::pto::impl::VMIToVPTOBase { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VMIToVPTOPass) + + void runOnOperation() override { + ModuleOp module = getOperation(); + if (failed(verifyVMIToVPTOInputIR(module))) { + signalPassFailure(); + return; + } + VMITargetCapabilityRegistry capabilities; + if (failed(verifySupportedVMIToVPTOOps(module, capabilities, + enableStableGatherMaskedLoad))) { + signalPassFailure(); + return; + } + + MLIRContext *context = module.getContext(); + VMIToVPTOTypeConverter typeConverter; + RewritePatternSet patterns(context); + + populateVMIOneToNConversionPatterns(typeConverter, patterns, capabilities); + if (failed(applyPartialOneToNConversion(module, typeConverter, + std::move(patterns)))) { + module.emitError() << kVMIDiagResidualOpPrefix + << "failed to convert all VMI ops/types to VPTO"; + signalPassFailure(); + return; + } + if (failed(verifyNoResidualVMIIR(module))) { + 
signalPassFailure(); + } + } +}; + +} // namespace + +std::unique_ptr mlir::pto::createVMIToVPTOPass() { + return std::make_unique(); +} diff --git a/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp b/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp index 2b881c6f6d..1ac6e7eaf3 100644 --- a/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp +++ b/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp @@ -3675,6 +3675,22 @@ static FailureOr buildVgather2Callee(MLIRContext *context, .getValue(); } +static FailureOr getVgather2OffsetsCarrierType(PatternRewriter &rewriter, + Type resultType, + Type offsetsType) { + Type elementType = getElementTypeFromVectorLike(resultType); + auto lanes = getElementCountFromVectorLike(resultType); + if (!elementType || !lanes) + return failure(); + + if (pto::getPTOStorageElemBitWidth(elementType) == 16) { + if (*lanes % 2 != 0) + return failure(); + return VectorType::get({*lanes / 2}, rewriter.getI32Type()); + } + return offsetsType; +} + static FailureOr buildVgather2BcCallee(MLIRContext *context, Type resultType) { return buildLaneTypedCallee(context, resultType, "vgather2.bc", ""); @@ -7210,13 +7226,22 @@ class LowerVgather2OpPattern final if (failed(calleeName)) return rewriter.notifyMatchFailure(op, "unsupported vgather2 signature"); + Value offsets = adaptor.getOffsets(); + FailureOr offsetsCarrierType = getVgather2OffsetsCarrierType( + rewriter, op.getResult().getType(), offsets.getType()); + if (failed(offsetsCarrierType)) + return rewriter.notifyMatchFailure(op, "unsupported vgather2 offsets carrier"); + if (offsets.getType() != *offsetsCarrierType) + offsets = rewriter.create(op.getLoc(), *offsetsCarrierType, + offsets); + auto funcType = rewriter.getFunctionType( - TypeRange{adaptor.getSource().getType(), adaptor.getOffsets().getType(), + TypeRange{adaptor.getSource().getType(), *offsetsCarrierType, adaptor.getMask().getType()}, TypeRange{resultType}); auto call = rewriter.create( op.getLoc(), *calleeName, TypeRange{resultType}, - 
ValueRange{adaptor.getSource(), adaptor.getOffsets(), adaptor.getMask()}); + ValueRange{adaptor.getSource(), offsets, adaptor.getMask()}); state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); rewriter.replaceOp(op, call.getResults()); return success(); @@ -9289,6 +9314,41 @@ class ConvertVPTOUnrealizedCastOp final } }; +class ConvertArithSelectOp final : public OpConversionPattern { +public: + ConvertArithSelectOp(TypeConverter &typeConverter, MLIRContext *context) + : OpConversionPattern(typeConverter, context, + PatternBenefit(2)) {} + + LogicalResult + matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!hasVPTOConvertibleType(op->getOperandTypes()) && + !hasVPTOConvertibleType(op->getResultTypes())) + return failure(); + if (!op.getCondition().getType().isInteger(1)) + return rewriter.notifyMatchFailure( + op, "only scalar i1 conditions supported for VPTO arith.select"); + + Type convertedResultType = + getTypeConverter()->convertType(op.getResult().getType()); + if (!convertedResultType) + return rewriter.notifyMatchFailure(op, "failed to convert result type"); + + Value trueValue = adaptor.getTrueValue(); + Value falseValue = adaptor.getFalseValue(); + if (trueValue.getType() != convertedResultType || + falseValue.getType() != convertedResultType) + return rewriter.notifyMatchFailure( + op, "converted true/false values must match result type"); + + rewriter.replaceOpWithNewOp( + op, convertedResultType, adaptor.getCondition(), trueValue, + falseValue); + return success(); + } +}; + class ConvertPtoAddPtrOp final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -10164,6 +10224,7 @@ static LogicalResult lowerVPTOTypes(ModuleOp module, llvm::raw_ostream &diagOS) patterns.add( typeConverter, context, state); + patterns.add(typeConverter, context); patterns.add(typeConverter, context); patterns.add(typeConverter, context); diff --git 
a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp index cd501fc420..3a3bbcc926 100644 --- a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp +++ b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp @@ -3611,6 +3611,22 @@ static FailureOr buildVgather2Callee(MLIRContext *context, .getValue(); } +static FailureOr getVgather2OffsetsCarrierType(PatternRewriter &rewriter, + Type resultType, + Type offsetsType) { + Type elementType = getElementTypeFromVectorLike(resultType); + auto lanes = getElementCountFromVectorLike(resultType); + if (!elementType || !lanes) + return failure(); + + if (pto::getPTOStorageElemBitWidth(elementType) == 16) { + if (*lanes % 2 != 0) + return failure(); + return VectorType::get({*lanes / 2}, rewriter.getI32Type()); + } + return offsetsType; +} + static FailureOr buildVgather2BcCallee(MLIRContext *context, Type resultType) { return buildLaneTypedCallee(context, resultType, "vgather2.bc", ""); @@ -7152,13 +7168,22 @@ class LowerVgather2OpPattern final if (failed(calleeName)) return rewriter.notifyMatchFailure(op, "unsupported vgather2 signature"); + Value offsets = adaptor.getOffsets(); + FailureOr offsetsCarrierType = getVgather2OffsetsCarrierType( + rewriter, op.getResult().getType(), offsets.getType()); + if (failed(offsetsCarrierType)) + return rewriter.notifyMatchFailure(op, "unsupported vgather2 offsets carrier"); + if (offsets.getType() != *offsetsCarrierType) + offsets = rewriter.create(op.getLoc(), *offsetsCarrierType, + offsets); + auto funcType = rewriter.getFunctionType( - TypeRange{adaptor.getSource().getType(), adaptor.getOffsets().getType(), + TypeRange{adaptor.getSource().getType(), *offsetsCarrierType, adaptor.getMask().getType()}, TypeRange{resultType}); auto call = rewriter.create( op.getLoc(), *calleeName, TypeRange{resultType}, - ValueRange{adaptor.getSource(), adaptor.getOffsets(), adaptor.getMask()}); + ValueRange{adaptor.getSource(), offsets, adaptor.getMask()}); 
state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); rewriter.replaceOp(op, call.getResults()); return success(); @@ -9300,6 +9325,41 @@ class ConvertVPTOUnrealizedCastOp final } }; +class ConvertArithSelectOp final : public OpConversionPattern { +public: + ConvertArithSelectOp(TypeConverter &typeConverter, MLIRContext *context) + : OpConversionPattern(typeConverter, context, + PatternBenefit(2)) {} + + LogicalResult + matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!hasVPTOConvertibleType(op->getOperandTypes()) && + !hasVPTOConvertibleType(op->getResultTypes())) + return failure(); + if (!op.getCondition().getType().isInteger(1)) + return rewriter.notifyMatchFailure( + op, "only scalar i1 conditions supported for VPTO arith.select"); + + Type convertedResultType = + getTypeConverter()->convertType(op.getResult().getType()); + if (!convertedResultType) + return rewriter.notifyMatchFailure(op, "failed to convert result type"); + + Value trueValue = adaptor.getTrueValue(); + Value falseValue = adaptor.getFalseValue(); + if (trueValue.getType() != convertedResultType || + falseValue.getType() != convertedResultType) + return rewriter.notifyMatchFailure( + op, "converted true/false values must match result type"); + + rewriter.replaceOpWithNewOp( + op, convertedResultType, adaptor.getCondition(), trueValue, + falseValue); + return success(); + } +}; + class ConvertPtoAddPtrOp final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -10177,6 +10237,7 @@ static LogicalResult lowerVPTOTypes(ModuleOp module, llvm::raw_ostream &diagOS) patterns.add( typeConverter, context, state); + patterns.add(typeConverter, context); patterns.add(typeConverter, context); patterns.add(typeConverter, context); diff --git a/lib/PTO/Transforms/VPTONormalizeEquivalentVcvt.cpp b/lib/PTO/Transforms/VPTONormalizeEquivalentVcvt.cpp new file mode 100644 index 
0000000000..22dbbfc094 --- /dev/null +++ b/lib/PTO/Transforms/VPTONormalizeEquivalentVcvt.cpp @@ -0,0 +1,96 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "PTO/IR/PTO.h" +#include "PTO/IR/PTOTypeUtils.h" +#include "PTO/Transforms/Passes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace pto { +#define GEN_PASS_DEF_VPTONORMALIZEEQUIVALENTVCVT +#include "PTO/Transforms/Passes.h.inc" +} // namespace pto +} // namespace mlir + +using namespace mlir; +using namespace mlir::pto; + +namespace { + +static bool isOddPart(StringRef part) { + return part == "ODD" || part == "PART_ODD"; +} + +static bool isAllTrueMask(Value mask) { + if (auto op = mask.getDefiningOp()) + return op.getPattern() == "PAT_ALL"; + if (auto op = mask.getDefiningOp()) + return op.getPattern() == "PAT_ALL"; + if (auto op = mask.getDefiningOp()) + return op.getPattern() == "PAT_ALL"; + return false; +} + +static bool isPairEquivalentLoadDist(StringRef dist) { + return dist == "BRC_B8" || dist == "BRC_B16" || dist == "BRC_B32" || + dist == "US_B8" || dist == "US_B16" || dist == "E2B_B16" || + dist == "E2B_B32"; +} + +static bool hasEvenOddEquivalentLanes(Value value) { + if (value.getDefiningOp()) + return true; + + auto load = value.getDefiningOp(); + if (!load || value != load.getResult()) + return false; + + std::optional dist = load.getDist(); + return 
dist && isPairEquivalentLoadDist(*dist); +} + +static bool isNarrowToWideVcvt(VcvtOp op) { + auto inputType = dyn_cast(op.getInput().getType()); + auto resultType = dyn_cast(op.getResult().getType()); + if (!inputType || !resultType) + return false; + + unsigned inputBits = getPTOStorageElemBitWidth(inputType.getElementType()); + unsigned resultBits = getPTOStorageElemBitWidth(resultType.getElementType()); + return inputBits != 0 && resultBits != 0 && inputBits < resultBits; +} + +struct VPTONormalizeEquivalentVcvtPass + : public pto::impl::VPTONormalizeEquivalentVcvtBase< + VPTONormalizeEquivalentVcvtPass> { + void runOnOperation() override { + StringAttr even = StringAttr::get(&getContext(), "EVEN"); + + getOperation().walk([&](VcvtOp op) { + std::optional part = op.getPart(); + if (!part || !isOddPart(*part)) + return; + if (!isNarrowToWideVcvt(op)) + return; + if (!isAllTrueMask(op.getMask())) + return; + if (!hasEvenOddEquivalentLanes(op.getInput())) + return; + + op.setPartAttr(even); + }); + } +}; + +} // namespace + +std::unique_ptr mlir::pto::createVPTONormalizeEquivalentVcvtPass() { + return std::make_unique(); +} diff --git a/test/lit/CMakeLists.txt b/test/lit/CMakeLists.txt index 684ae0bf50..9fb0e63c5b 100644 --- a/test/lit/CMakeLists.txt +++ b/test/lit/CMakeLists.txt @@ -27,6 +27,7 @@ configure_lit_site_cfg( set(PTOIR_TEST_DEPENDS FileCheck count not pto-opt + pto-test-opt ) add_lit_testsuite(check-pto "Running the pto regression tests" diff --git a/test/lit/lit.cfg.py b/test/lit/lit.cfg.py index 9a81959f47..43cb6724e0 100644 --- a/test/lit/lit.cfg.py +++ b/test/lit/lit.cfg.py @@ -40,6 +40,8 @@ # test_exec_root: The root path where tests should be run. 
config.test_exec_root = os.path.join(config.ptoir_obj_root, 'test/lit') config.ptoir_tools_dir = os.path.join(config.ptoir_obj_root, 'tools/ptoas') +config.ptoir_test_tools_dir = os.path.join(config.ptoir_obj_root, + 'tools/pto-test-opt') config.substitutions.append(('%PATH%', config.environment['PATH'])) config.substitutions.append(('%shlibext', config.llvm_shlib_ext)) @@ -57,9 +59,11 @@ # Tweak the PATH to include the tools dir. llvm_config.with_environment('PATH', config.llvm_tools_dir, append_path=True) -tool_dirs = [config.ptoir_tools_dir, config.llvm_tools_dir] +tool_dirs = [config.ptoir_tools_dir, config.ptoir_test_tools_dir, + config.llvm_tools_dir] tools = [ 'ptoas', + 'pto-test-opt', ] llvm_config.add_tool_substitutions(tools, tool_dirs) diff --git a/test/lit/vmi/opt/README.md b/test/lit/vmi/opt/README.md new file mode 100644 index 0000000000..c169dcee08 --- /dev/null +++ b/test/lit/vmi/opt/README.md @@ -0,0 +1,18 @@ + + +# VMI Optimization Shape Guards + +This directory contains end-to-end VMI optimization capability tests. + +These tests intentionally check generated VPTO instruction shape for representative +kernels. They are not generic correctness tests: a failure means the VMI pipeline +has likely regressed an optimization contract and should not be updated away +without an explicit replacement optimization decision. diff --git a/test/lit/vmi/opt/compute_mrope_f16_vmi_opt.pto b/test/lit/vmi/opt/compute_mrope_f16_vmi_opt.pto new file mode 100644 index 0000000000..21c0f52c2f --- /dev/null +++ b/test/lit/vmi/opt/compute_mrope_f16_vmi_opt.pto @@ -0,0 +1,134 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-vmi --emit-vpto %s | FileCheck %s --implicit-check-not=pto.vdintlv --implicit-check-not=pto.vintlv --implicit-check-not='part = "ODD"' + +// This is an optimization capability guard for the VMI ComputeMropeF16 path. +// Do not weaken the checks when the output shape regresses. The intended shape is: +// - 64xf16 loads feeding f16->f32 ext lower as UNPK_B16 loads plus one EVEN vcvt; +// - no ODD vcvt, vintlv, or vdintlv is needed for those ext paths; +// - f32->f16 trunc feeding masked_store lowers through EVEN vcvt plus vpack. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @ComputeMropeF16( + %x_ub_u: !pto.ptr, + %y_ub_u: !pto.ptr, + %cs_ub_u: !pto.ptr, + %curTokens: i32, + %num_heads: i32, + %num_heads_max: i32, + %head_size: i32, + %rotary_dim: i32, + %headAlign_fp16: i32, + %is_neox: i1) + attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2_i32 = arith.constant 2 : i32 + + %x_ub = pto.castptr %x_ub_u : !pto.ptr -> !pto.ptr + %y_ub = pto.castptr %y_ub_u : !pto.ptr -> !pto.ptr + %cs_ub = pto.castptr %cs_ub_u : !pto.ptr -> !pto.ptr + + %cur_tokens = arith.index_cast %curTokens : i32 to index + %num_heads_idx = arith.index_cast %num_heads : i32 to index + %num_heads_max_idx = arith.index_cast %num_heads_max : i32 to index + %head_align_idx = arith.index_cast %headAlign_fp16 : i32 to index + %rotary_dim_idx = arith.index_cast %rotary_dim : i32 to index + + %half_rotary_i32 = arith.divui %rotary_dim, %c2_i32 : i32 + %suffix_len_i32 = arith.subi %head_size, %rotary_dim : i32 + %has_suffix = arith.cmpi sgt, %head_size, 
%rotary_dim : i32 + + %half_rotary = arith.index_cast %half_rotary_i32 : i32 to index + %suffix_len = arith.index_cast %suffix_len_i32 : i32 to index + %half_mask = pto.vmi.create_mask %half_rotary : index -> !pto.vmi.mask<64xpred> + %suffix_mask = pto.vmi.create_mask %suffix_len : index -> !pto.vmi.mask<64xpred> + + %head_stride = arith.muli %num_heads_max_idx, %head_align_idx : index + + pto.vecscope { + scf.if %is_neox { + scf.for %ti = %c0 to %cur_tokens step %c1 { + %cs_off = arith.muli %ti, %rotary_dim_idx : index + %token_base = arith.muli %ti, %head_stride : index + %cs_sin_base = arith.addi %cs_off, %half_rotary : index + + scf.for %h = %c0 to %num_heads_idx step %c1 { + %head_off = arith.muli %h, %head_align_idx : index + %x_base = arith.addi %token_base, %head_off : index + %x2_base = arith.addi %x_base, %half_rotary : index + + %x1 = pto.vmi.load %x_ub[%x_base] + : !pto.ptr -> !pto.vmi.vreg<64xf16> + %x2 = pto.vmi.load %x_ub[%x2_base] + : !pto.ptr -> !pto.vmi.vreg<64xf16> + %cos = pto.vmi.load %cs_ub[%cs_off] + : !pto.ptr -> !pto.vmi.vreg<64xf16> + %sin = pto.vmi.load %cs_ub[%cs_sin_base] + : !pto.ptr -> !pto.vmi.vreg<64xf16> + + %x1_f32 = pto.vmi.extf %x1 + : !pto.vmi.vreg<64xf16> -> !pto.vmi.vreg<64xf32> + %x2_f32 = pto.vmi.extf %x2 + : !pto.vmi.vreg<64xf16> -> !pto.vmi.vreg<64xf32> + %cos_f32 = pto.vmi.extf %cos + : !pto.vmi.vreg<64xf16> -> !pto.vmi.vreg<64xf32> + %sin_f32 = pto.vmi.extf %sin + : !pto.vmi.vreg<64xf16> -> !pto.vmi.vreg<64xf32> + + %x1_cos = pto.vmi.mulf %x1_f32, %cos_f32 + : !pto.vmi.vreg<64xf32>, !pto.vmi.vreg<64xf32> -> !pto.vmi.vreg<64xf32> + %x2_sin = pto.vmi.mulf %x2_f32, %sin_f32 + : !pto.vmi.vreg<64xf32>, !pto.vmi.vreg<64xf32> -> !pto.vmi.vreg<64xf32> + %out1_f32 = pto.vmi.subf %x1_cos, %x2_sin + : !pto.vmi.vreg<64xf32>, !pto.vmi.vreg<64xf32> -> !pto.vmi.vreg<64xf32> + %out1 = pto.vmi.truncf %out1_f32 + : !pto.vmi.vreg<64xf32> -> !pto.vmi.vreg<64xf16> + pto.vmi.masked_store %out1, %y_ub[%x_base], %half_mask + : 
!pto.vmi.vreg<64xf16>, !pto.ptr, !pto.vmi.mask<64xpred> + + %x2_cos = pto.vmi.mulf %x2_f32, %cos_f32 + : !pto.vmi.vreg<64xf32>, !pto.vmi.vreg<64xf32> -> !pto.vmi.vreg<64xf32> + %x1_sin = pto.vmi.mulf %x1_f32, %sin_f32 + : !pto.vmi.vreg<64xf32>, !pto.vmi.vreg<64xf32> -> !pto.vmi.vreg<64xf32> + %out2_f32 = pto.vmi.addf %x2_cos, %x1_sin + : !pto.vmi.vreg<64xf32>, !pto.vmi.vreg<64xf32> -> !pto.vmi.vreg<64xf32> + %out2 = pto.vmi.truncf %out2_f32 + : !pto.vmi.vreg<64xf32> -> !pto.vmi.vreg<64xf16> + pto.vmi.masked_store %out2, %y_ub[%x2_base], %half_mask + : !pto.vmi.vreg<64xf16>, !pto.ptr, !pto.vmi.mask<64xpred> + + scf.if %has_suffix { + %suffix_base = arith.addi %x_base, %rotary_dim_idx : index + %suffix = pto.vmi.load %x_ub[%suffix_base] + : !pto.ptr -> !pto.vmi.vreg<64xf16> + pto.vmi.masked_store %suffix, %y_ub[%suffix_base], %suffix_mask + : !pto.vmi.vreg<64xf16>, !pto.ptr, !pto.vmi.mask<64xpred> + } + } + } + } + } + return + } +} + +// CHECK-LABEL: func.func @ComputeMropeF16 +// CHECK-COUNT-4: pto.vlds {{.*}} {dist = "UNPK_B16"} : !pto.ptr -> !pto.vreg<128xf16> +// CHECK-COUNT-4: pto.vcvt {{.*}} {part = "EVEN"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vmul +// CHECK: pto.vsub +// CHECK: pto.vcvt {{.*}} {part = "EVEN", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +// CHECK: pto.vpack +// CHECK: pto.vsts {{.*}} : !pto.vreg<128xf16>, !pto.ptr, !pto.mask +// CHECK: pto.vadd +// CHECK: pto.vcvt {{.*}} {part = "EVEN", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +// CHECK: pto.vpack +// CHECK: pto.vsts {{.*}} : !pto.vreg<128xf16>, !pto.ptr, !pto.mask +// CHECK: pto.vlds {{.*}} : !pto.ptr -> !pto.vreg<128xf16> diff --git a/test/lit/vmi/opt/compute_single_row_vf_vmi_opt.pto b/test/lit/vmi/opt/compute_single_row_vf_vmi_opt.pto new file mode 100644 index 0000000000..c95d51e506 --- /dev/null +++ b/test/lit/vmi/opt/compute_single_row_vf_vmi_opt.pto @@ -0,0 +1,104 @@ +// Copyright (c) 
2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-vmi --emit-vpto %s | FileCheck %s --check-prefix=VPTO --implicit-check-not=pto.vmi. --implicit-check-not='!pto.vmi' + +// Optimization guard for the ComputeSingleRowVF block-quant path. +// The 128xf32 -> 128xf8 truncf should use a deinterleaved=2 f32 source and a +// lane_stride=2 fp8 result, lowering to P0/P2 only. This keeps the output +// physical register count unchanged while avoiding the extra f32 source parts +// required by deinterleaved=4. 
+ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @ComputeSingleRowVF_fp16_vmi_opt( + %inUb16_u: !pto.ptr, + %scaleUb: !pto.ptr, + %outUb8_u: !pto.ptr, + %fp8MaxValue: f32, + %minScale: f32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %one = arith.constant 1.000000e+00 : f32 + %pos_inf = arith.constant 0x7F800000 : f32 + + %inUb = pto.castptr %inUb16_u : !pto.ptr -> !pto.ptr + %outUbFp = pto.castptr %outUb8_u : !pto.ptr -> !pto.ptr + %recip_min_scale = arith.divf %one, %minScale : f32 + + pto.vecscope { + %fp8max1 = pto.vmi.broadcast %fp8MaxValue : f32 -> !pto.vmi.vreg<1xf32> + %limit_vec = pto.vmi.broadcast %recip_min_scale : f32 -> !pto.vmi.vreg<128xf32> + %pos_inf_vec = pto.vmi.broadcast %pos_inf : f32 -> !pto.vmi.vreg<128xf32> + + %mask128 = pto.vmi.create_mask %c128 : index -> !pto.vmi.mask<128xpred> + %x16 = pto.vmi.load %inUb[%c0] + : !pto.ptr -> !pto.vmi.vreg<128xf16> + %x = pto.vmi.extf %x16 + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %abs = pto.vmi.absf %x + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + %amax = pto.vmi.group_reduce_maxf %abs, %mask128 {num_groups = 1} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<1xf32> + %scale_raw = pto.vmi.divf %amax, %fp8max1 + : !pto.vmi.vreg<1xf32>, !pto.vmi.vreg<1xf32> + -> !pto.vmi.vreg<1xf32> + %scale_raw_vec = pto.vmi.group_broadcast %scale_raw {num_groups = 1} + : !pto.vmi.vreg<1xf32> -> !pto.vmi.vreg<128xf32> + %finite_mask = pto.vmi.cmpf "olt", %scale_raw_vec, %pos_inf_vec + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.mask<128xpred> + %scale_clamped_vec = pto.vmi.minf %scale_raw_vec, %limit_vec + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + %scale_vec = pto.vmi.select %finite_mask, %scale_clamped_vec, %scale_raw_vec + : !pto.vmi.mask<128xpred>, !pto.vmi.vreg<128xf32>, 
!pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + %scale = pto.vmi.group_reduce_maxf %scale_vec, %mask128 {num_groups = 1} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<1xf32> + + pto.vmi.group_store %scale, %scaleUb[%c0], %c1 {num_groups = 1} + : !pto.vmi.vreg<1xf32>, !pto.ptr + %q = pto.vmi.divf %x, %scale_vec + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + %q8 = pto.vmi.truncf %q + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf8E4M3FN> + pto.vmi.masked_store %q8, %outUbFp[%c0], %mask128 + : !pto.vmi.vreg<128xf8E4M3FN>, !pto.ptr, + !pto.vmi.mask<128xpred> + } + return + } +} + +// ASSIGN-LABEL: func.func @ComputeSingleRowVF_fp16_vmi_opt( +// ASSIGN: %[[Q:.*]] = pto.vmi.divf {{.*}} : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[Q8:.*]] = pto.vmi.truncf %[[Q]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf8E4M3FN, #pto.vmi.layout> +// ASSIGN: %[[STORE:.*]] = pto.vmi.ensure_layout %[[Q8]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf8E4M3FN, #pto.vmi.layout> -> !pto.vmi.vreg<128xf8E4M3FN, #pto.vmi.layout> +// ASSIGN: pto.vmi.masked_store %[[STORE]] + +// VPTO-LABEL: func.func @ComputeSingleRowVF_fp16_vmi_opt( +// VPTO: pto.vcvt {{.*}} {part = "EVEN"} +// VPTO: pto.vcvt {{.*}} {part = "ODD"} +// VPTO-NOT: part = "P1" +// VPTO-NOT: part = "P3" +// VPTO: pto.vcvt {{.*}} {part = "P0", rnd = "R", sat = "SAT"} +// VPTO-NOT: part = "P1" +// VPTO-NOT: part = "P3" +// VPTO: pto.vcvt {{.*}} {part = "P2", rnd = "R", sat = "SAT"} +// VPTO-NOT: part = "P1" +// VPTO-NOT: part = "P3" +// VPTO: pto.vor +// VPTO: pto.vpack +// VPTO: pto.vsts diff --git a/test/lit/vmi/opt/compute_y1_to_fp8_fp16_vmi_opt.pto b/test/lit/vmi/opt/compute_y1_to_fp8_fp16_vmi_opt.pto new file mode 100644 index 0000000000..448c4d524d --- /dev/null +++ b/test/lit/vmi/opt/compute_y1_to_fp8_fp16_vmi_opt.pto @@ -0,0 +1,144 @@ +// 
Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-vmi --emit-vpto %s | FileCheck %s --implicit-check-not=pto.vdintlv --implicit-check-not=pto.vintlv --implicit-check-not=pto.vpack + +// This is an optimization capability guard for the VMI ComputeY1ToFP8 FP16 path. +// Do not weaken the checks when the output shape regresses. The intended shape is: +// - one E2B_B16 scale load per kernel, outside the block loop; +// - DINTLV_B16 x loads inside the block loop; +// - four f32 streams quantized through P0/P1/P2/P3 and merged by vor; +// - contiguous fp8 store, with no extra register layout materialization. 
+ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @ComputeY1ToFP8_fp16_e4m3_vmi( + %dataLen: i16, + %blockCount: i16, + %xAddr: !pto.ptr, + %mxScale1ReciprocalAddr: !pto.ptr, + %y1Addr: !pto.ptr, + %ubBlockSize: i16, + %vlForHalfNumber: i16) + attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + + %block_count = arith.index_cast %blockCount : i16 to index + %vl_half = arith.index_cast %vlForHalfNumber : i16 to index + %load_stride_y8 = arith.muli %vl_half, %c2 : index + + pto.vecscope { + %scale_f16 = pto.vmi.group_slot_load %mxScale1ReciprocalAddr[%c0], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<8xf16> + %scale_f16_vec = pto.vmi.group_broadcast %scale_f16 {num_groups = 8} + : !pto.vmi.vreg<8xf16> -> !pto.vmi.vreg<256xf16> + %scale_fp32 = pto.vmi.extf %scale_f16_vec + : !pto.vmi.vreg<256xf16> -> !pto.vmi.vreg<256xf32> + + scf.for %i = %c0 to %block_count step %c1 { + %x_off = arith.muli %i, %load_stride_y8 : index + %y_off = arith.muli %i, %load_stride_y8 : index + + %x_f16 = pto.vmi.load %xAddr[%x_off] + : !pto.ptr -> !pto.vmi.vreg<256xf16> + %x_fp32 = pto.vmi.extf %x_f16 + : !pto.vmi.vreg<256xf16> -> !pto.vmi.vreg<256xf32> + %res_fp32 = pto.vmi.mulf %x_fp32, %scale_fp32 + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + + %res_fp8 = pto.vmi.truncf %res_fp32 + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E4M3FN> + + pto.vmi.store %res_fp8, %y1Addr[%y_off] + : !pto.vmi.vreg<256xf8E4M3FN>, !pto.ptr + } + } + return + } + + func.func @ComputeY1ToFP8_fp16_e5m2_vmi( + %dataLen: i16, + %blockCount: i16, + %xAddr: !pto.ptr, + %mxScale1ReciprocalAddr: !pto.ptr, + %y1Addr: !pto.ptr, + %ubBlockSize: i16, + %vlForHalfNumber: i16) + attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + + %block_count = arith.index_cast %blockCount : i16 to index 
+ %vl_half = arith.index_cast %vlForHalfNumber : i16 to index + %load_stride_y8 = arith.muli %vl_half, %c2 : index + + pto.vecscope { + %scale_f16 = pto.vmi.group_slot_load %mxScale1ReciprocalAddr[%c0], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<8xf16> + %scale_f16_vec = pto.vmi.group_broadcast %scale_f16 {num_groups = 8} + : !pto.vmi.vreg<8xf16> -> !pto.vmi.vreg<256xf16> + %scale_fp32 = pto.vmi.extf %scale_f16_vec + : !pto.vmi.vreg<256xf16> -> !pto.vmi.vreg<256xf32> + + scf.for %i = %c0 to %block_count step %c1 { + %x_off = arith.muli %i, %load_stride_y8 : index + %y_off = arith.muli %i, %load_stride_y8 : index + + %x_f16 = pto.vmi.load %xAddr[%x_off] + : !pto.ptr -> !pto.vmi.vreg<256xf16> + %x_fp32 = pto.vmi.extf %x_f16 + : !pto.vmi.vreg<256xf16> -> !pto.vmi.vreg<256xf32> + %res_fp32 = pto.vmi.mulf %x_fp32, %scale_fp32 + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + + %res_fp8 = pto.vmi.truncf %res_fp32 + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E5M2> + + pto.vmi.store %res_fp8, %y1Addr[%y_off] + : !pto.vmi.vreg<256xf8E5M2>, !pto.ptr + } + } + return + } +} + +// CHECK-LABEL: func.func @ComputeY1ToFP8_fp16_e4m3_vmi +// CHECK: pto.vlds {{.*}} {dist = "E2B_B16"} : !pto.ptr -> !pto.vreg<128xf16> +// CHECK: pto.vcvt {{.*}} {part = "EVEN"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: scf.for +// CHECK: pto.vldsx2 {{.*}} "DINTLV_B16" : !pto.ptr, index -> !pto.vreg<128xf16>, !pto.vreg<128xf16> +// CHECK: pto.vcvt {{.*}} {part = "EVEN"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vcvt {{.*}} {part = "EVEN"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vcvt {{.*}} {part = "ODD"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vcvt {{.*}} {part = "ODD"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vcvt {{.*}} {part = "P0", rnd = "R", sat = "SAT"} +// CHECK: pto.vcvt {{.*}} {part = "P1", rnd = "R", sat = "SAT"} +// 
CHECK: pto.vcvt {{.*}} {part = "P2", rnd = "R", sat = "SAT"} +// CHECK: pto.vcvt {{.*}} {part = "P3", rnd = "R", sat = "SAT"} +// CHECK: pto.vor +// CHECK: pto.vor +// CHECK: pto.vor +// CHECK: pto.vsts {{.*}} : !pto.vreg<256xf8E4M3FN>, !pto.ptr, !pto.mask + +// CHECK-LABEL: func.func @ComputeY1ToFP8_fp16_e5m2_vmi +// CHECK: pto.vlds {{.*}} {dist = "E2B_B16"} : !pto.ptr -> !pto.vreg<128xf16> +// CHECK: pto.vcvt {{.*}} {part = "EVEN"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: scf.for +// CHECK: pto.vldsx2 {{.*}} "DINTLV_B16" : !pto.ptr, index -> !pto.vreg<128xf16>, !pto.vreg<128xf16> +// CHECK: pto.vcvt {{.*}} {part = "P0", rnd = "R", sat = "SAT"} +// CHECK: pto.vcvt {{.*}} {part = "P1", rnd = "R", sat = "SAT"} +// CHECK: pto.vcvt {{.*}} {part = "P2", rnd = "R", sat = "SAT"} +// CHECK: pto.vcvt {{.*}} {part = "P3", rnd = "R", sat = "SAT"} +// CHECK: pto.vor +// CHECK: pto.vor +// CHECK: pto.vor +// CHECK: pto.vsts {{.*}} : !pto.vreg<256xf8E5M2>, !pto.ptr, !pto.mask diff --git a/test/lit/vmi/vmi_absf_integer_invalid.pto b/test/lit/vmi/vmi_absf_integer_invalid.pto new file mode 100644 index 0000000000..2a3900e4e5 --- /dev/null +++ b/test/lit/vmi/vmi_absf_integer_invalid.pto @@ -0,0 +1,19 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_absf_integer_invalid(%value: !pto.vmi.vreg<128xi16>) { + %abs = pto.vmi.absf %value + : !pto.vmi.vreg<128xi16> -> !pto.vmi.vreg<128xi16> + return + } +} + +// CHECK: 'pto.vmi.absf' op requires floating-point-like VMI element type diff --git a/test/lit/vmi/vmi_absi_float_invalid.pto b/test/lit/vmi/vmi_absi_float_invalid.pto new file mode 100644 index 0000000000..0f2d556c1a --- /dev/null +++ b/test/lit/vmi/vmi_absi_float_invalid.pto @@ -0,0 +1,19 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_absi_float_invalid(%value: !pto.vmi.vreg<64xf32>) { + %abs = pto.vmi.absi %value + : !pto.vmi.vreg<64xf32> -> !pto.vmi.vreg<64xf32> + return + } +} + +// CHECK: 'pto.vmi.absi' op requires integer-like VMI element type diff --git a/test/lit/vmi/vmi_active_prefix_index_result_type_invalid.pto b/test/lit/vmi/vmi_active_prefix_index_result_type_invalid.pto new file mode 100644 index 0000000000..c675b2e6e9 --- /dev/null +++ b/test/lit/vmi/vmi_active_prefix_index_result_type_invalid.pto @@ -0,0 +1,21 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_active_prefix_index_result_type_invalid( + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) { + %idx = pto.vmi.active_prefix_index %mask + : !pto.vmi.mask<64xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + return + } +} + +// CHECK: 'pto.vmi.active_prefix_index' op requires signless integer result element type diff --git a/test/lit/vmi/vmi_addf_lane_mismatch_invalid.pto b/test/lit/vmi/vmi_addf_lane_mismatch_invalid.pto new file mode 100644 index 0000000000..bd6ed94bac --- /dev/null +++ b/test/lit/vmi/vmi_addf_lane_mismatch_invalid.pto @@ -0,0 +1,21 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_addf_lane_mismatch_invalid( + %a: !pto.vmi.vreg<128xf32>, + %b: !pto.vmi.vreg<64xf32>) { + %r = pto.vmi.addf %a, %b + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<64xf32> -> !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK: requires all VMI data values to have the same logical lane count diff --git a/test/lit/vmi/vmi_bitcast_total_bits_invalid.pto b/test/lit/vmi/vmi_bitcast_total_bits_invalid.pto new file mode 100644 index 0000000000..937889d014 --- /dev/null +++ b/test/lit/vmi/vmi_bitcast_total_bits_invalid.pto @@ -0,0 +1,19 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_bitcast_total_bits_invalid(%value: !pto.vmi.vreg<128xf32>) { + %cast = pto.vmi.bitcast %value + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xi16> + return + } +} + +// CHECK: 'pto.vmi.bitcast' op requires source and result to carry the same total number of bits diff --git a/test/lit/vmi/vmi_bitwise_float_invalid.pto b/test/lit/vmi/vmi_bitwise_float_invalid.pto new file mode 100644 index 0000000000..60f260d444 --- /dev/null +++ b/test/lit/vmi/vmi_bitwise_float_invalid.pto @@ -0,0 +1,64 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -split-input-file 2>&1 | FileCheck %s + +module { + func.func @vmi_andi_float_invalid( + %lhs: !pto.vmi.vreg<128xf32>, + %rhs: !pto.vmi.vreg<128xf32>) { + %out = pto.vmi.andi %lhs, %rhs + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK: 'pto.vmi.andi' op requires integer-like VMI element type + +// ----- + +module { + func.func @vmi_ori_float_invalid( + %lhs: !pto.vmi.vreg<128xf32>, + %rhs: !pto.vmi.vreg<128xf32>) { + %out = pto.vmi.ori %lhs, %rhs + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK: 'pto.vmi.ori' op requires integer-like VMI element type + +// ----- + +module { + func.func @vmi_xori_float_invalid( + %lhs: !pto.vmi.vreg<128xf32>, + %rhs: !pto.vmi.vreg<128xf32>) { + %out = pto.vmi.xori %lhs, %rhs + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK: 'pto.vmi.xori' op requires integer-like VMI element type + +// ----- + +module { + func.func @vmi_not_float_invalid(%source: !pto.vmi.vreg<128xf32>) { + %out = pto.vmi.not %source + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK: 'pto.vmi.not' op requires integer-like VMI element type diff --git a/test/lit/vmi/vmi_broadcast_type_mismatch_invalid.pto 
b/test/lit/vmi/vmi_broadcast_type_mismatch_invalid.pto new file mode 100644 index 0000000000..9ecdc9469f --- /dev/null +++ b/test/lit/vmi/vmi_broadcast_type_mismatch_invalid.pto @@ -0,0 +1,18 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_broadcast_type_mismatch_invalid(%value: f16) { + %result = pto.vmi.broadcast %value : f16 -> !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK: requires scalar or VMI vector input element type to match result element type diff --git a/test/lit/vmi/vmi_channel_merge_input_mismatch_invalid.pto b/test/lit/vmi/vmi_channel_merge_input_mismatch_invalid.pto new file mode 100644 index 0000000000..1dbc569c4c --- /dev/null +++ b/test/lit/vmi/vmi_channel_merge_input_mismatch_invalid.pto @@ -0,0 +1,21 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_channel_merge_input_mismatch_invalid( + %ch0: !pto.vmi.vreg<2xf32>, + %ch1: !pto.vmi.vreg<3xf32>) { + %merged = "pto.vmi.channel_merge"(%ch0, %ch1) + : (!pto.vmi.vreg<2xf32>, !pto.vmi.vreg<3xf32>) -> !pto.vmi.vreg<5xf32> + return + } +} + +// CHECK: requires all channel inputs to have the same lane count and element type diff --git a/test/lit/vmi/vmi_channel_merge_result_mismatch_invalid.pto b/test/lit/vmi/vmi_channel_merge_result_mismatch_invalid.pto new file mode 100644 index 0000000000..f5c7ad94b9 --- /dev/null +++ b/test/lit/vmi/vmi_channel_merge_result_mismatch_invalid.pto @@ -0,0 +1,21 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_channel_merge_result_mismatch_invalid( + %ch0: !pto.vmi.vreg<2xf32>, + %ch1: !pto.vmi.vreg<2xf32>) { + %merged = "pto.vmi.channel_merge"(%ch0, %ch1) + : (!pto.vmi.vreg<2xf32>, !pto.vmi.vreg<2xf32>) -> !pto.vmi.vreg<5xf32> + return + } +} + +// CHECK: requires result lane count and element type to match merged channels diff --git a/test/lit/vmi/vmi_channel_split_lane_count_invalid.pto b/test/lit/vmi/vmi_channel_split_lane_count_invalid.pto new file mode 100644 index 0000000000..bbf923b079 --- /dev/null +++ b/test/lit/vmi/vmi_channel_split_lane_count_invalid.pto @@ -0,0 +1,20 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_channel_split_lane_count_invalid( + %src: !pto.vmi.vreg<5xf32>) { + %ch0, %ch1 = "pto.vmi.channel_split"(%src) + : (!pto.vmi.vreg<5xf32>) -> (!pto.vmi.vreg<2xf32>, !pto.vmi.vreg<2xf32>) + return + } +} + +// CHECK: requires source lane count to equal result count times per-channel lane count diff --git a/test/lit/vmi/vmi_channel_split_result_count_invalid.pto b/test/lit/vmi/vmi_channel_split_result_count_invalid.pto new file mode 100644 index 0000000000..bbe2b434d6 --- /dev/null +++ b/test/lit/vmi/vmi_channel_split_result_count_invalid.pto @@ -0,0 +1,20 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_channel_split_result_count_invalid( + %src: !pto.vmi.vreg<4xf32>) { + %ch0 = "pto.vmi.channel_split"(%src) + : (!pto.vmi.vreg<4xf32>) -> !pto.vmi.vreg<4xf32> + return + } +} + +// CHECK: requires at least two channel results diff --git a/test/lit/vmi/vmi_compress_result_mismatch_invalid.pto b/test/lit/vmi/vmi_compress_result_mismatch_invalid.pto new file mode 100644 index 0000000000..7e7e6bb66f --- /dev/null +++ b/test/lit/vmi/vmi_compress_result_mismatch_invalid.pto @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_compress_result_mismatch_invalid( + %src: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) { + %out = pto.vmi.compress %src, %mask + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.mask<64xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<64xi32, #pto.vmi.layout> + return + } +} + +// CHECK: 'pto.vmi.compress' op requires all VMI data values to have the same element type diff --git a/test/lit/vmi/vmi_constant_attr_kind_invalid.pto b/test/lit/vmi/vmi_constant_attr_kind_invalid.pto new file mode 100644 index 0000000000..c1ff60fe3b --- /dev/null +++ b/test/lit/vmi/vmi_constant_attr_kind_invalid.pto @@ -0,0 +1,20 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s 2>&1 | FileCheck %s + +module { + func.func @vmi_constant_attr_kind_invalid() { + %value = "pto.vmi.constant"() { + value = 1.000000e+00 : f32 + } : () -> !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK: requires dense elements constant attribute diff --git a/test/lit/vmi/vmi_constant_element_count_invalid.pto b/test/lit/vmi/vmi_constant_element_count_invalid.pto new file mode 100644 index 0000000000..b5e80ce364 --- /dev/null +++ b/test/lit/vmi/vmi_constant_element_count_invalid.pto @@ -0,0 +1,20 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s 2>&1 | FileCheck %s + +module { + func.func @vmi_constant_element_count_invalid() { + %value = "pto.vmi.constant"() { + value = dense<1.000000e+00> : tensor<64xf32> + } : () -> !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK: requires dense constant element count to match result logical lane count diff --git a/test/lit/vmi/vmi_constant_element_type_invalid.pto b/test/lit/vmi/vmi_constant_element_type_invalid.pto new file mode 100644 index 0000000000..29a5f2d22a --- /dev/null +++ b/test/lit/vmi/vmi_constant_element_type_invalid.pto @@ -0,0 +1,20 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s 2>&1 | FileCheck %s + +module { + func.func @vmi_constant_element_type_invalid() { + %value = "pto.vmi.constant"() { + value = dense<1> : tensor<128xi32> + } : () -> !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK: requires dense constant element type to match result element type diff --git a/test/lit/vmi/vmi_constant_mask_attr_kind_invalid.pto b/test/lit/vmi/vmi_constant_mask_attr_kind_invalid.pto new file mode 100644 index 0000000000..537d007f03 --- /dev/null +++ b/test/lit/vmi/vmi_constant_mask_attr_kind_invalid.pto @@ -0,0 +1,20 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s 2>&1 | FileCheck %s +// Negative test: pto.vmi.constant_mask gets a plain bool attribute where a dense elements attribute is required. +module { + func.func @vmi_constant_mask_attr_kind_invalid() { + %mask = "pto.vmi.constant_mask"() { + value = true + } : () -> !pto.vmi.mask<128xpred> + return + } +} + +// CHECK: requires dense elements mask constant attribute diff --git a/test/lit/vmi/vmi_constant_mask_element_count_invalid.pto b/test/lit/vmi/vmi_constant_mask_element_count_invalid.pto new file mode 100644 index 0000000000..f39f4ab00a --- /dev/null +++ b/test/lit/vmi/vmi_constant_mask_element_count_invalid.pto @@ -0,0 +1,20 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s 2>&1 | FileCheck %s +// Negative test: the dense mask attribute has 64 elements but the result mask type has 128 lanes. +module { + func.func @vmi_constant_mask_element_count_invalid() { + %mask = "pto.vmi.constant_mask"() { + value = dense<true> : tensor<64xi1> + } : () -> !pto.vmi.mask<128xpred> + return + } +} + +// CHECK: requires dense mask constant element count to match result logical lane count diff --git a/test/lit/vmi/vmi_constant_mask_element_type_invalid.pto b/test/lit/vmi/vmi_constant_mask_element_type_invalid.pto new file mode 100644 index 0000000000..7f97a4afd6 --- /dev/null +++ b/test/lit/vmi/vmi_constant_mask_element_type_invalid.pto @@ -0,0 +1,20 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd.
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s 2>&1 | FileCheck %s +// Negative test: the dense mask attribute is i32-typed; the expected diagnostic demands i1. +module { + func.func @vmi_constant_mask_element_type_invalid() { + %mask = "pto.vmi.constant_mask"() { + value = dense<1> : tensor<128xi32> + } : () -> !pto.vmi.mask<128xpred> + return + } +} + +// CHECK: requires dense mask constant element type to be i1 diff --git a/test/lit/vmi/vmi_create_group_mask_invalid.pto b/test/lit/vmi/vmi_create_group_mask_invalid.pto new file mode 100644 index 0000000000..0c3aec3d65 --- /dev/null +++ b/test/lit/vmi/vmi_create_group_mask_invalid.pto @@ -0,0 +1,20 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License.
+ +// RUN: not pto-test-opt %s 2>&1 | FileCheck %s +// Negative test: num_groups (8) * group_size (16) is 128, but the result mask type has 127 lanes. +module { + func.func @vmi_create_group_mask_lane_count_invalid() { + %c12 = arith.constant 12 : index + // CHECK: pto.vmi.create_group_mask + // CHECK-SAME: requires result lane count to equal num_groups * group_size + %mask = pto.vmi.create_group_mask %c12 {num_groups = 8, group_size = 16} + : index -> !pto.vmi.mask<127xpred> + return + } +} diff --git a/test/lit/vmi/vmi_divf_integer_invalid.pto b/test/lit/vmi/vmi_divf_integer_invalid.pto new file mode 100644 index 0000000000..0c26d668b3 --- /dev/null +++ b/test/lit/vmi/vmi_divf_integer_invalid.pto @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s +// Negative test: pto.vmi.divf is applied to i32 vregs; the expected diagnostic demands a float-like type. +module { + func.func @vmi_divf_integer_invalid( + %lhs: !pto.vmi.vreg<128xi32>, + %rhs: !pto.vmi.vreg<128xi32>) { + %quotient = pto.vmi.divf %lhs, %rhs + : !pto.vmi.vreg<128xi32>, !pto.vmi.vreg<128xi32> + -> !pto.vmi.vreg<128xi32> + return + } +} + +// CHECK: 'pto.vmi.divf' op requires floating-point-like VMI element type diff --git a/test/lit/vmi/vmi_elementwise_kind_invalid.pto b/test/lit/vmi/vmi_elementwise_kind_invalid.pto new file mode 100644 index 0000000000..46e8255de8 --- /dev/null +++ b/test/lit/vmi/vmi_elementwise_kind_invalid.pto @@ -0,0 +1,63 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd.
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -split-input-file 2>&1 | FileCheck %s +// Negative tests (split input): subf/mulf on i32 vregs and subi/muli on f32 vregs must each be rejected. +module { + func.func @vmi_subf_integer_invalid( + %lhs: !pto.vmi.vreg<128xi32>, %rhs: !pto.vmi.vreg<128xi32>) { + %out = pto.vmi.subf %lhs, %rhs + : !pto.vmi.vreg<128xi32>, !pto.vmi.vreg<128xi32> + -> !pto.vmi.vreg<128xi32> + return + } +} + +// CHECK: 'pto.vmi.subf' op requires floating-point-like VMI element type + +// ----- + +module { + func.func @vmi_subi_float_invalid( + %lhs: !pto.vmi.vreg<128xf32>, %rhs: !pto.vmi.vreg<128xf32>) { + %out = pto.vmi.subi %lhs, %rhs + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK: 'pto.vmi.subi' op requires integer-like VMI element type + +// ----- + +module { + func.func @vmi_mulf_integer_invalid( + %lhs: !pto.vmi.vreg<128xi32>, %rhs: !pto.vmi.vreg<128xi32>) { + %out = pto.vmi.mulf %lhs, %rhs + : !pto.vmi.vreg<128xi32>, !pto.vmi.vreg<128xi32> + -> !pto.vmi.vreg<128xi32> + return + } +} + +// CHECK: 'pto.vmi.mulf' op requires floating-point-like VMI element type + +// ----- + +module { + func.func @vmi_muli_float_invalid( + %lhs: !pto.vmi.vreg<128xf32>, %rhs: !pto.vmi.vreg<128xf32>) { + %out = pto.vmi.muli %lhs, %rhs + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK: 'pto.vmi.muli' op requires integer-like VMI element type diff --git
a/test/lit/vmi/vmi_ensure_layout_surface_invalid.pto b/test/lit/vmi/vmi_ensure_layout_surface_invalid.pto new file mode 100644 index 0000000000..09a92692de --- /dev/null +++ b/test/lit/vmi/vmi_ensure_layout_surface_invalid.pto @@ -0,0 +1,47 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -split-input-file 2>&1 | FileCheck %s +// Negative tests (split input): ensure_layout / ensure_mask_granularity operand types the verifier must reject. +module { + func.func @vmi_ensure_layout_surface_invalid(%a: !pto.vmi.vreg<128xf32>) { + %r = pto.vmi.ensure_layout %a + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + return + } +} + +// CHECK: requires source and result to be layout-assigned + +// ----- + +module { + func.func @vmi_ensure_mask_granularity_surface_invalid( + %a: !pto.vmi.mask<128xpred>) { + %r = pto.vmi.ensure_mask_granularity %a + : !pto.vmi.mask<128xpred> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + return + } +} + +// CHECK: requires source and result to be layout-assigned + +// ----- + +module { + func.func @vmi_ensure_mask_granularity_layout_mismatch_invalid( + %a: !pto.vmi.mask<128xb16, #pto.vmi.layout>) { + %r = pto.vmi.ensure_mask_granularity %a + : !pto.vmi.mask<128xb16, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + return + } +} + +// CHECK: requires source and result mask layouts to match diff --git a/test/lit/vmi/vmi_extf_direction_invalid.pto b/test/lit/vmi/vmi_extf_direction_invalid.pto new file
mode 100644 index 0000000000..e00280a69d --- /dev/null +++ b/test/lit/vmi/vmi_extf_direction_invalid.pto @@ -0,0 +1,19 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s +// Negative test: pto.vmi.extf from f32 to f16 narrows, while the diagnostic demands a wider result type. +module { + func.func @vmi_extf_direction_invalid(%source: !pto.vmi.vreg<128xf32>) { + %result = pto.vmi.extf %source + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> + return + } +} + +// CHECK: requires result element type to be wider than source element type diff --git a/test/lit/vmi/vmi_extf_lane_mismatch_invalid.pto b/test/lit/vmi/vmi_extf_lane_mismatch_invalid.pto new file mode 100644 index 0000000000..d1b64fc15d --- /dev/null +++ b/test/lit/vmi/vmi_extf_lane_mismatch_invalid.pto @@ -0,0 +1,19 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License.
+ +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s +// Negative test: pto.vmi.extf source has 64 lanes but the result has 128 lanes. +module { + func.func @vmi_extf_lane_mismatch_invalid(%source: !pto.vmi.vreg<64xf16>) { + %result = pto.vmi.extf %source + : !pto.vmi.vreg<64xf16> -> !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK: requires source and result logical lane counts to match diff --git a/test/lit/vmi/vmi_fma_integer_invalid.pto b/test/lit/vmi/vmi_fma_integer_invalid.pto new file mode 100644 index 0000000000..e44d8879b3 --- /dev/null +++ b/test/lit/vmi/vmi_fma_integer_invalid.pto @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-layout-assignment 2>&1 | FileCheck %s +// Negative test: pto.vmi.fma is applied to i32 vregs; the expected diagnostic demands a float-like type. +module { + func.func @vmi_fma_integer_invalid( + %lhs: !pto.vmi.vreg<64xi32>, + %rhs: !pto.vmi.vreg<64xi32>, + %acc: !pto.vmi.vreg<64xi32>) -> !pto.vmi.vreg<64xi32> { + %out = pto.vmi.fma %lhs, %rhs, %acc + : !pto.vmi.vreg<64xi32>, !pto.vmi.vreg<64xi32>, + !pto.vmi.vreg<64xi32> -> !pto.vmi.vreg<64xi32> + return %out : !pto.vmi.vreg<64xi32> + } +} + +// CHECK: 'pto.vmi.fma' op requires floating-point-like VMI element type diff --git a/test/lit/vmi/vmi_gather_indices_invalid.pto b/test/lit/vmi/vmi_gather_indices_invalid.pto new file mode 100644 index 0000000000..4b37624430 --- /dev/null +++ b/test/lit/vmi/vmi_gather_indices_invalid.pto @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd.
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-layout-assignment 2>&1 | FileCheck %s +// Negative test: pto.vmi.gather is given f32 indices; the expected diagnostic demands integer indices. +module { + func.func @vmi_gather_indices_invalid( + %src: !pto.ptr, + %indices: !pto.vmi.vreg<64xf32>, + %mask: !pto.vmi.mask<64xpred>, + %passthru: !pto.vmi.vreg<64xf32>) { + %out = pto.vmi.gather %src[%indices], %mask, %passthru + : !pto.ptr, !pto.vmi.vreg<64xf32>, + !pto.vmi.mask<64xpred>, !pto.vmi.vreg<64xf32> + -> !pto.vmi.vreg<64xf32> + return + } +} + +// CHECK: 'pto.vmi.gather' op requires signless or unsigned 16-bit or 32-bit integer indices diff --git a/test/lit/vmi/vmi_group_reduce_addi_i16_invalid.pto b/test/lit/vmi/vmi_group_reduce_addi_i16_invalid.pto new file mode 100644 index 0000000000..33f3516efd --- /dev/null +++ b/test/lit/vmi/vmi_group_reduce_addi_i16_invalid.pto @@ -0,0 +1,24 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE.
See LICENSE in the root of the software repository +// for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-layout-assignment 2>&1 | FileCheck %s +// Negative test: group_reduce_addi directly on i16 lanes; the diagnostic asks for an i32 accumulator type. +module { + func.func @direct_i16_group_reduce_invalid( + %source: !pto.vmi.vreg<128xi16>, + %mask: !pto.vmi.mask<128xpred>) { + %sum = pto.vmi.group_reduce_addi %source, %mask {num_groups = 8} + : !pto.vmi.vreg<128xi16>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<8xi16> + return + } +} + +// CHECK: requires i32 accumulator element type; cast i8/i16 storage to i32 before grouped reduction because integer reduction widens narrow inputs diff --git a/test/lit/vmi/vmi_group_reduce_addi_i8_invalid.pto b/test/lit/vmi/vmi_group_reduce_addi_i8_invalid.pto new file mode 100644 index 0000000000..973a57450e --- /dev/null +++ b/test/lit/vmi/vmi_group_reduce_addi_i8_invalid.pto @@ -0,0 +1,24 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License.
+ +// RUN: not pto-test-opt %s -vmi-layout-assignment 2>&1 | FileCheck %s +// Negative test: group_reduce_addi directly on i8 lanes; the diagnostic asks for an i32 accumulator type. +module { + func.func @direct_i8_group_reduce_invalid( + %source: !pto.vmi.vreg<256xi8>, + %mask: !pto.vmi.mask<256xpred>) { + %sum = pto.vmi.group_reduce_addi %source, %mask {num_groups = 8} + : !pto.vmi.vreg<256xi8>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<8xi8> + return + } +} + +// CHECK: requires i32 accumulator element type; cast i8/i16 storage to i32 before grouped reduction because integer reduction widens narrow inputs diff --git a/test/lit/vmi/vmi_group_reduce_maxi_i8_invalid.pto b/test/lit/vmi/vmi_group_reduce_maxi_i8_invalid.pto new file mode 100644 index 0000000000..756a8b3527 --- /dev/null +++ b/test/lit/vmi/vmi_group_reduce_maxi_i8_invalid.pto @@ -0,0 +1,24 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License.
+ +// RUN: not pto-test-opt %s -vmi-layout-assignment 2>&1 | FileCheck %s +// Negative test: group_reduce_maxi directly on i8 lanes; the diagnostic asks for an i32 accumulator type. +module { + func.func @direct_i8_group_reduce_max_invalid( + %source: !pto.vmi.vreg<256xi8>, + %mask: !pto.vmi.mask<256xpred>) { + %max = pto.vmi.group_reduce_maxi %source, %mask {num_groups = 8} + : !pto.vmi.vreg<256xi8>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<8xi8> + return + } +} + +// CHECK: requires i32 accumulator element type; cast i8/i16 storage to i32 before grouped reduction because integer reduction widens narrow inputs diff --git a/test/lit/vmi/vmi_interleaved_memory_ops.pto b/test/lit/vmi/vmi_interleaved_memory_ops.pto new file mode 100644 index 0000000000..26aa6324ef --- /dev/null +++ b/test/lit/vmi/vmi_interleaved_memory_ops.pto @@ -0,0 +1,56 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License.
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER +// Positive test: deinterleave_load / interleave_store after layout assignment and after lowering to VPTO. +module { + func.func @vmi_deinterleave_load( + %src: !pto.ptr, + %offset: index) -> (!pto.vmi.vreg<64xf32>, !pto.vmi.vreg<64xf32>) { + %low, %high = pto.vmi.deinterleave_load %src[%offset] + : !pto.ptr -> !pto.vmi.vreg<64xf32>, !pto.vmi.vreg<64xf32> + return %low, %high : !pto.vmi.vreg<64xf32>, !pto.vmi.vreg<64xf32> + } + + func.func @vmi_interleave_store( + %low: !pto.vmi.vreg<64xf32>, + %high: !pto.vmi.vreg<64xf32>, + %dst: !pto.ptr, + %offset: index) { + pto.vmi.interleave_store %low, %high, %dst[%offset] + : !pto.vmi.vreg<64xf32>, !pto.vmi.vreg<64xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_deinterleave_load( +// ASSIGN-SAME: -> (!pto.vmi.vreg<64xf32, #pto.vmi.layout>, !pto.vmi.vreg<64xf32, #pto.vmi.layout>) +// ASSIGN: %[[LOW:.*]], %[[HIGH:.*]] = pto.vmi.deinterleave_load +// ASSIGN-SAME: !pto.ptr -> !pto.vmi.vreg<64xf32, #pto.vmi.layout>, !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// ASSIGN: return %[[LOW]], %[[HIGH]] : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, !pto.vmi.vreg<64xf32, #pto.vmi.layout> + +// LOWER-LABEL: func.func @vmi_deinterleave_load( +// LOWER: %[[LOW:.*]], %[[HIGH:.*]] = pto.vldsx2 %arg0[%arg1], "DINTLV_B32" +// LOWER: return %[[LOW]], %[[HIGH]] : !pto.vreg<64xf32>, !pto.vreg<64xf32> +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi.
+// LOWER-NOT: unrealized_conversion_cast + +// ASSIGN-LABEL: func.func @vmi_interleave_store( +// ASSIGN-SAME: %[[LOW_ARG:[^:]+]]: !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// ASSIGN-SAME: %[[HIGH_ARG:[^:]+]]: !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.interleave_store %[[LOW_ARG]], %[[HIGH_ARG]] +// ASSIGN-SAME: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, !pto.vmi.vreg<64xf32, #pto.vmi.layout>, !pto.ptr + +// LOWER-LABEL: func.func @vmi_interleave_store( +// LOWER: %[[MASK:.*]] = pto.pset_b32 "PAT_ALL" +// LOWER: pto.vstsx2 %arg0, %arg1, %arg2[%arg3], "INTLV_B32", %[[MASK]] +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_interleaved_memory_ops_invalid.pto b/test/lit/vmi/vmi_interleaved_memory_ops_invalid.pto new file mode 100644 index 0000000000..81aaa858ae --- /dev/null +++ b/test/lit/vmi/vmi_interleaved_memory_ops_invalid.pto @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s 2>&1 | FileCheck %s +// Negative test: interleave_store operands have 64 and 128 lanes; equal lane counts are required. +module { + func.func @vmi_interleave_store_mismatch( + %low: !pto.vmi.vreg<64xf32>, + %high: !pto.vmi.vreg<128xf32>, + %dst: !pto.ptr, + %offset: index) { + pto.vmi.interleave_store %low, %high, %dst[%offset] + : !pto.vmi.vreg<64xf32>, !pto.vmi.vreg<128xf32>, !pto.ptr + return + } +} + +// CHECK: 'pto.vmi.interleave_store' op requires all VMI data values to have the same logical lane count diff --git a/test/lit/vmi/vmi_iota_element_type_invalid.pto b/test/lit/vmi/vmi_iota_element_type_invalid.pto new file mode 100644 index 0000000000..448fba485f --- /dev/null +++ b/test/lit/vmi/vmi_iota_element_type_invalid.pto @@ -0,0 +1,19 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s +// Negative test: pto.vmi.iota with an i64 result element type, which the expected diagnostic excludes. +module { + func.func @vmi_iota_element_type_invalid(%base: i64) { + %value = pto.vmi.iota %base + : i64 -> !pto.vmi.vreg<64xi64> + return + } +} + +// CHECK: 'pto.vmi.iota' op requires result element type to be integer 8/16/32 or f16/f32 diff --git a/test/lit/vmi/vmi_iota_order_invalid.pto b/test/lit/vmi/vmi_iota_order_invalid.pto new file mode 100644 index 0000000000..93df56591c --- /dev/null +++ b/test/lit/vmi/vmi_iota_order_invalid.pto @@ -0,0 +1,19 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd.
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s +// Negative test: pto.vmi.iota with order = "DOWN"; the expected diagnostic allows only ASC or DESC. +module { + func.func @vmi_iota_order_invalid(%base: i32) { + %value = pto.vmi.iota %base {order = "DOWN"} + : i32 -> !pto.vmi.vreg<64xi32> + return + } +} + +// CHECK: 'pto.vmi.iota' op requires order to be ASC or DESC diff --git a/test/lit/vmi/vmi_lane_stride_dense_load_store.pto b/test/lit/vmi/vmi_lane_stride_dense_load_store.pto new file mode 100644 index 0000000000..2cc9cb2625 --- /dev/null +++ b/test/lit/vmi/vmi_lane_stride_dense_load_store.pto @@ -0,0 +1,197 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License.
+ +// RUN: pto-test-opt %s -vmi-layout-fold | FileCheck %s --check-prefix=FOLD +// RUN: pto-test-opt %s -vmi-layout-fold -vmi-to-vpto | FileCheck %s --check-prefix=LOWER +// Positive test: runs -vmi-layout-fold alone and followed by -vmi-to-vpto on load/store/ext/ensure_layout cases. +module { + func.func @vmi_layout_fold_load_lane_stride( + %src: !pto.ptr, %off: index) + -> !pto.vmi.vreg<64xf16, #pto.vmi.layout> { + %load = pto.vmi.load %src[%off] + : !pto.ptr + -> !pto.vmi.vreg<64xf16, #pto.vmi.layout> + %strided = pto.vmi.ensure_layout %load + : !pto.vmi.vreg<64xf16, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf16, #pto.vmi.layout> + return %strided + : !pto.vmi.vreg<64xf16, #pto.vmi.layout> + } + + func.func @vmi_layout_fold_store_lane_stride( + %value: !pto.vmi.vreg<64xf16, #pto.vmi.layout>, + %dst: !pto.ptr, %off: index) { + %compact = pto.vmi.ensure_layout %value + : !pto.vmi.vreg<64xf16, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf16, #pto.vmi.layout> + pto.vmi.store %compact, %dst[%off] + : !pto.vmi.vreg<64xf16, #pto.vmi.layout>, + !pto.ptr + return + } + + func.func @vmi_extf_lane_stride_even( + %value: !pto.vmi.vreg<64xf16, #pto.vmi.layout>) + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> { + %wide = pto.vmi.extf %value + : !pto.vmi.vreg<64xf16, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + return %wide : !pto.vmi.vreg<64xf32, #pto.vmi.layout> + } + + func.func @vmi_extui_lane_stride_even( + %value: !pto.vmi.vreg<64xui16, #pto.vmi.layout>) + -> !pto.vmi.vreg<64xui32, #pto.vmi.layout> { + %wide = pto.vmi.extui %value + : !pto.vmi.vreg<64xui16, #pto.vmi.layout> + -> !pto.vmi.vreg<64xui32, #pto.vmi.layout> + return %wide : !pto.vmi.vreg<64xui32, #pto.vmi.layout> + } + + func.func @vmi_layout_fold_store_lane_stride_b32( + %value: !pto.vmi.vreg<32xf32, #pto.vmi.layout>, + %dst: !pto.ptr, %off: index) { + %compact = pto.vmi.ensure_layout %value + : !pto.vmi.vreg<32xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<32xf32, #pto.vmi.layout> + pto.vmi.store %compact, %dst[%off] + : !pto.vmi.vreg<32xf32, #pto.vmi.layout>, + !pto.ptr + return + } + + func.func
@vmi_layout_fold_load_store_lane_stride4_u8( + %src: !pto.ptr, %dst: !pto.ptr, %off: index) { + %load = pto.vmi.load %src[%off] + : !pto.ptr + -> !pto.vmi.vreg<64xui8, #pto.vmi.layout> + %strided = pto.vmi.ensure_layout %load + : !pto.vmi.vreg<64xui8, #pto.vmi.layout> + -> !pto.vmi.vreg<64xui8, #pto.vmi.layout> + pto.vmi.store %strided, %dst[%off] + : !pto.vmi.vreg<64xui8, #pto.vmi.layout>, + !pto.ptr + return + } + + func.func @vmi_ensure_contiguous_to_lane_stride_f16( + %value: !pto.vmi.vreg<64xf16, #pto.vmi.layout>) + -> !pto.vmi.vreg<64xf16, #pto.vmi.layout> { + %strided = pto.vmi.ensure_layout %value + : !pto.vmi.vreg<64xf16, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf16, #pto.vmi.layout> + return %strided + : !pto.vmi.vreg<64xf16, #pto.vmi.layout> + } + + func.func @vmi_ensure_lane_stride_to_contiguous_f16( + %value: !pto.vmi.vreg<64xf16, #pto.vmi.layout>) + -> !pto.vmi.vreg<64xf16, #pto.vmi.layout> { + %compact = pto.vmi.ensure_layout %value + : !pto.vmi.vreg<64xf16, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf16, #pto.vmi.layout> + return %compact : !pto.vmi.vreg<64xf16, #pto.vmi.layout> + } + + func.func @vmi_ensure_contiguous_to_lane_stride4_ui8( + %value: !pto.vmi.vreg<64xui8, #pto.vmi.layout>) + -> !pto.vmi.vreg<64xui8, #pto.vmi.layout> { + %strided = pto.vmi.ensure_layout %value + : !pto.vmi.vreg<64xui8, #pto.vmi.layout> + -> !pto.vmi.vreg<64xui8, #pto.vmi.layout> + return %strided + : !pto.vmi.vreg<64xui8, #pto.vmi.layout> + } + + func.func @vmi_ensure_lane_stride4_to_contiguous_ui8( + %value: !pto.vmi.vreg<64xui8, #pto.vmi.layout>) + -> !pto.vmi.vreg<64xui8, #pto.vmi.layout> { + %compact = pto.vmi.ensure_layout %value + : !pto.vmi.vreg<64xui8, #pto.vmi.layout> + -> !pto.vmi.vreg<64xui8, #pto.vmi.layout> + return %compact : !pto.vmi.vreg<64xui8, #pto.vmi.layout> + } +} + +// FOLD-LABEL: func.func @vmi_layout_fold_load_lane_stride( +// FOLD: %[[LOAD:.*]] = pto.vmi.load +// FOLD-SAME: !pto.vmi.vreg<64xf16, #pto.vmi.layout> +// FOLD-NOT: pto.vmi.ensure_layout
+// FOLD: return %[[LOAD]] + +// FOLD-LABEL: func.func @vmi_layout_fold_store_lane_stride( +// FOLD-SAME: %[[VALUE:.*]]: !pto.vmi.vreg<64xf16, #pto.vmi.layout> +// FOLD-NOT: pto.vmi.ensure_layout +// FOLD: pto.vmi.store %[[VALUE]] +// FOLD-SAME: !pto.vmi.vreg<64xf16, #pto.vmi.layout> +// FOLD-NOT: pto.vmi.ensure_layout + +// FOLD-LABEL: func.func @vmi_extf_lane_stride_even( + +// LOWER-LABEL: func.func @vmi_layout_fold_load_lane_stride( +// LOWER: pto.vlds {{.*}} {dist = "UNPK_B16"} : !pto.ptr -> !pto.vreg<128xf16> +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. + +// LOWER-LABEL: func.func @vmi_layout_fold_store_lane_stride( +// LOWER: pto.vsts {{.*}} {dist = "PK_B32"} : !pto.vreg<128xf16>, !pto.ptr, !pto.mask +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. + +// LOWER-LABEL: func.func @vmi_extf_lane_stride_even( +// LOWER: pto.vcvt {{.*}} {part = "EVEN"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +// LOWER-NOT: part = "ODD" +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. + +// LOWER-LABEL: func.func @vmi_extui_lane_stride_even( +// LOWER: pto.vcvt {{.*}} {part = "EVEN"} : !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<64xui32> +// LOWER-NOT: part = "ODD" +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. + +// LOWER-LABEL: func.func @vmi_layout_fold_store_lane_stride_b32( +// LOWER: pto.vsts {{.*}} {dist = "PK_B64"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. + +// LOWER-LABEL: func.func @vmi_layout_fold_load_store_lane_stride4_u8( +// LOWER: pto.vlds {{.*}} {dist = "UNPK4"} : !pto.ptr -> !pto.vreg<256xui8> +// LOWER: pto.vsts {{.*}} {dist = "PK4_B32"} : !pto.vreg<256xui8>, !pto.ptr, !pto.mask +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. 
+ +// LOWER-LABEL: func.func @vmi_ensure_contiguous_to_lane_stride_f16( +// LOWER: pto.vbitcast {{.*}} : !pto.vreg<128xf16> -> !pto.vreg<128xui16> +// LOWER: pto.vzunpack {{.*}} : !pto.vreg<128xui16> -> !pto.vreg<64xui32> +// LOWER: pto.vbitcast {{.*}} : !pto.vreg<64xui32> -> !pto.vreg<128xf16> +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. + +// LOWER-LABEL: func.func @vmi_ensure_lane_stride_to_contiguous_f16( +// LOWER: pto.vbitcast {{.*}} : !pto.vreg<128xf16> -> !pto.vreg<64xui32> +// LOWER: pto.vpack {{.*}} "LOWER" : !pto.vreg<64xui32> -> !pto.vreg<128xui16> +// LOWER: pto.vbitcast {{.*}} : !pto.vreg<128xui16> -> !pto.vreg<128xf16> +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. + +// LOWER-LABEL: func.func @vmi_ensure_contiguous_to_lane_stride4_ui8( +// LOWER: pto.vzunpack {{.*}} : !pto.vreg<256xui8> -> !pto.vreg<128xui16> +// LOWER: pto.vzunpack {{.*}} : !pto.vreg<128xui16> -> !pto.vreg<64xui32> +// LOWER: pto.vbitcast {{.*}} : !pto.vreg<64xui32> -> !pto.vreg<256xui8> +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. + +// LOWER-LABEL: func.func @vmi_ensure_lane_stride4_to_contiguous_ui8( +// LOWER: pto.vbitcast {{.*}} : !pto.vreg<256xui8> -> !pto.vreg<64xui32> +// LOWER: pto.vpack {{.*}} "LOWER" : !pto.vreg<64xui32> -> !pto.vreg<128xui16> +// LOWER: pto.vpack {{.*}} "LOWER" : !pto.vreg<128xui16> -> !pto.vreg<256xui8> +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_lane_stride_masked_store.pto b/test/lit/vmi/vmi_lane_stride_masked_store.pto new file mode 100644 index 0000000000..25d1c661fe --- /dev/null +++ b/test/lit/vmi/vmi_lane_stride_masked_store.pto @@ -0,0 +1,85 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. 
THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-fold | FileCheck %s --check-prefix=FOLD +// RUN: pto-test-opt %s -vmi-layout-fold -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_lane_stride2_masked_store_f16( + %value: !pto.vmi.vreg<64xf16, #pto.vmi.layout>, + %mask: !pto.vmi.mask<64xb16, #pto.vmi.layout>, + %dst: !pto.ptr, %offset: index) { + %value_c = pto.vmi.ensure_layout %value + : !pto.vmi.vreg<64xf16, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf16, #pto.vmi.layout> + %mask_c = pto.vmi.ensure_mask_layout %mask + : !pto.vmi.mask<64xb16, #pto.vmi.layout> + -> !pto.vmi.mask<64xb16, #pto.vmi.layout> + pto.vmi.masked_store %value_c, %dst[%offset], %mask_c + : !pto.vmi.vreg<64xf16, #pto.vmi.layout>, + !pto.ptr, + !pto.vmi.mask<64xb16, #pto.vmi.layout> + return + } + + func.func @vmi_lane_stride4_masked_store_ui8( + %value: !pto.vmi.vreg<64xui8, #pto.vmi.layout>, + %mask: !pto.vmi.mask<64xb8, #pto.vmi.layout>, + %dst: !pto.ptr, %offset: index) { + %value_c = pto.vmi.ensure_layout %value + : !pto.vmi.vreg<64xui8, #pto.vmi.layout> + -> !pto.vmi.vreg<64xui8, #pto.vmi.layout> + %mask_c = pto.vmi.ensure_mask_layout %mask + : !pto.vmi.mask<64xb8, #pto.vmi.layout> + -> !pto.vmi.mask<64xb8, #pto.vmi.layout> + pto.vmi.masked_store %value_c, %dst[%offset], %mask_c + : !pto.vmi.vreg<64xui8, #pto.vmi.layout>, + !pto.ptr, + !pto.vmi.mask<64xb8, #pto.vmi.layout> + return + } +} + +// FOLD-LABEL: func.func @vmi_lane_stride2_masked_store_f16( +// FOLD-SAME: %[[VALUE:.*]]: !pto.vmi.vreg<64xf16, #pto.vmi.layout> +// FOLD-SAME: %[[MASK:.*]]: !pto.vmi.mask<64xb16, #pto.vmi.layout> +// FOLD-NOT: pto.vmi.ensure_layout +// FOLD-NOT: pto.vmi.ensure_mask_layout +// FOLD: 
pto.vmi.masked_store %[[VALUE]] +// FOLD-SAME: %[[MASK]] +// FOLD-SAME: !pto.vmi.vreg<64xf16, #pto.vmi.layout> +// FOLD-SAME: !pto.vmi.mask<64xb16, #pto.vmi.layout> + +// FOLD-LABEL: func.func @vmi_lane_stride4_masked_store_ui8( +// FOLD-SAME: %[[VALUE:.*]]: !pto.vmi.vreg<64xui8, #pto.vmi.layout> +// FOLD-SAME: %[[MASK:.*]]: !pto.vmi.mask<64xb8, #pto.vmi.layout> +// FOLD-NOT: pto.vmi.ensure_layout +// FOLD-NOT: pto.vmi.ensure_mask_layout +// FOLD: pto.vmi.masked_store %[[VALUE]] +// FOLD-SAME: %[[MASK]] +// FOLD-SAME: !pto.vmi.vreg<64xui8, #pto.vmi.layout> +// FOLD-SAME: !pto.vmi.mask<64xb8, #pto.vmi.layout> + +// LOWER-LABEL: func.func @vmi_lane_stride2_masked_store_f16( +// LOWER-SAME: %[[VALUE:[^,]+]]: !pto.vreg<128xf16> +// LOWER-SAME: %[[MASK:[^,]+]]: !pto.mask +// LOWER: %[[COMPACT:.*]] = pto.punpack %[[MASK]], "LOWER" : !pto.mask -> !pto.mask +// LOWER: pto.vsts %[[VALUE]], {{.*}} {dist = "PK_B32"} : !pto.vreg<128xf16>, !pto.ptr, !pto.mask +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. + +// LOWER-LABEL: func.func @vmi_lane_stride4_masked_store_ui8( +// LOWER-SAME: %[[VALUE:[^,]+]]: !pto.vreg<256xui8> +// LOWER-SAME: %[[MASK:[^,]+]]: !pto.mask +// LOWER: %[[MID:.*]] = pto.punpack %[[MASK]], "LOWER" : !pto.mask -> !pto.mask +// LOWER: %[[COMPACT:.*]] = pto.punpack %[[MID]], "LOWER" : !pto.mask -> !pto.mask +// LOWER: pto.vsts %[[VALUE]], {{.*}} {dist = "PK4_B32"} : !pto.vreg<256xui8>, !pto.ptr, !pto.mask +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_layout_assignment_active_prefix_index.pto b/test/lit/vmi/vmi_layout_assignment_active_prefix_index.pto new file mode 100644 index 0000000000..5dabf59203 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_active_prefix_index.pto @@ -0,0 +1,26 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). 
+// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_active_prefix_index(%mask: !pto.vmi.mask<64xpred>) + -> !pto.vmi.vreg<64xi32> { + %idx = pto.vmi.active_prefix_index %mask + : !pto.vmi.mask<64xpred> -> !pto.vmi.vreg<64xi32> + return %idx : !pto.vmi.vreg<64xi32> + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_active_prefix_index( +// CHECK-SAME: %[[MASK:.*]]: !pto.vmi.mask<64xb32, #pto.vmi.layout>) +// CHECK-SAME: -> !pto.vmi.vreg<64xi32, #pto.vmi.layout> +// CHECK: %[[IDX:.*]] = pto.vmi.active_prefix_index %[[MASK]] +// CHECK-SAME: !pto.vmi.mask<64xb32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.vreg<64xi32, #pto.vmi.layout> +// CHECK: return %[[IDX]] diff --git a/test/lit/vmi/vmi_layout_assignment_broadcast_dense_group_users.pto b/test/lit/vmi/vmi_layout_assignment_broadcast_dense_group_users.pto new file mode 100644 index 0000000000..4c8c8e3142 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_broadcast_dense_group_users.pto @@ -0,0 +1,74 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_broadcast_dense_group_users( + %base: !pto.ptr, + %copy_out: !pto.ptr, + %sum_out: !pto.ptr, + %off: index, + %scale: f32) { + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %scale_v = pto.vmi.broadcast %scale + : f32 -> !pto.vmi.vreg<256xf32> + %x = pto.vmi.load %base[%off] + : !pto.ptr -> !pto.vmi.vreg<256xf32> + %copy = pto.vmi.addf %x, %scale_v + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + pto.vmi.store %copy, %copy_out[%off] + : !pto.vmi.vreg<256xf32>, !pto.ptr + %mask = pto.vmi.create_group_mask %c32 + {num_groups = 8, group_size = 32} + : index -> !pto.vmi.mask<256xpred> + %prod = pto.vmi.mulf %x, %scale_v + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %sum = pto.vmi.group_reduce_addf %prod, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<8xf32> + pto.vmi.group_store %sum, %sum_out[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_broadcast_dense_group_users( +// ASSIGN: %[[SCALE:.*]] = pto.vmi.broadcast +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[X:.*]] = pto.vmi.load +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[COPY:.*]] = pto.vmi.addf %[[X]], %[[SCALE]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[COPY_DENSE:.*]] = 
pto.vmi.ensure_layout %[[COPY]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.store %[[COPY_DENSE]] +// ASSIGN: %[[MASK0:.*]] = pto.vmi.create_group_mask +// ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: %[[PROD:.*]] = pto.vmi.mulf %[[X]], %[[SCALE]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[MASK:.*]] = pto.vmi.ensure_mask_layout %[[MASK0]] +// ASSIGN-SAME: !pto.vmi.mask<256xb32, #pto.vmi.layout> -> !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[PROD]], %[[MASK]] +// ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_store %[[SUM]] + +// LOWER-LABEL: func.func @vmi_layout_assignment_broadcast_dense_group_users( +// LOWER-COUNT-4: pto.vdup +// LOWER-COUNT-4: pto.vmul +// LOWER-COUNT-4: pto.vcgadd +// LOWER: pto.vadd +// LOWER: pto.vsts +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_broadcast_remat.pto b/test/lit/vmi/vmi_layout_assignment_broadcast_remat.pto new file mode 100644 index 0000000000..eefe95d973 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_broadcast_remat.pto @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-layout-rematerialize -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_broadcast_remat( + %scalar: f32, + %src: !pto.vmi.vreg<128xf16>, + %dst: !pto.ptr, + %offset: index) -> !pto.vmi.vreg<128xf32> { + %broadcast = pto.vmi.broadcast %scalar + : f32 -> !pto.vmi.vreg<128xf32> + %wide = pto.vmi.extf %src + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %sum = pto.vmi.addf %broadcast, %wide + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + pto.vmi.store %broadcast, %dst[%offset] + : !pto.vmi.vreg<128xf32>, !pto.ptr + return %sum : !pto.vmi.vreg<128xf32> + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_broadcast_remat( +// ASSIGN-SAME: %[[SCALAR:.*]]: f32 +// ASSIGN-SAME: %[[SRC:.*]]: !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN: %[[BCAST_DEINT:.*]] = pto.vmi.broadcast %[[SCALAR]] +// ASSIGN-SAME: f32 -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[WIDE:.*]] = pto.vmi.extf %[[SRC]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf16, #pto.vmi.layout> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.addf %[[BCAST_DEINT]], %[[WIDE]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[BCAST_CONTIG:.*]] = pto.vmi.ensure_layout %[[BCAST_DEINT]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.store %[[BCAST_CONTIG]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +// LOWER-LABEL: func.func @vmi_layout_assignment_broadcast_remat( +// LOWER-COUNT-4: pto.vdup %arg0 +// LOWER-NOT: pto.vintlv +// LOWER: pto.vsts +// LOWER: pto.vsts +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. 
diff --git a/test/lit/vmi/vmi_layout_assignment_call_argument_boundary.pto b/test/lit/vmi/vmi_layout_assignment_call_argument_boundary.pto new file mode 100644 index 0000000000..c7049c75a2 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_call_argument_boundary.pto @@ -0,0 +1,73 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func private @consume(%x: !pto.vmi.vreg<256xf32>, + %mask: !pto.vmi.mask<256xpred>, + %out: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %sum = pto.vmi.group_reduce_addf %x, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<8xf32> + pto.vmi.group_store %sum, %out[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xf32>, !pto.ptr + return + } + + func.func @caller(%base: !pto.ptr, + %out: !pto.ptr, + %off: index) { + %c32 = arith.constant 32 : index + %x = pto.vmi.load %base[%off] + : !pto.ptr -> !pto.vmi.vreg<256xf32> + %mask = pto.vmi.create_group_mask %c32 + {num_groups = 8, group_size = 32} + : index -> !pto.vmi.mask<256xpred> + call @consume(%x, %mask, %out, %off) + : (!pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred>, + !pto.ptr, index) -> () + return + } +} + +// ASSIGN-LABEL: 
func.func private @consume( +// ASSIGN-SAME: !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN-SAME: !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: %[[X_SPLIT:.*]] = pto.vmi.ensure_layout +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[MASK_SPLIT:.*]] = pto.vmi.ensure_mask_layout +// ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X_SPLIT]], %[[MASK_SPLIT]] +// ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_store %[[SUM]] +// ASSIGN-LABEL: func.func @caller( +// ASSIGN: %[[X:.*]] = pto.vmi.load +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[MASK:.*]] = pto.vmi.create_group_mask +// ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: call @consume(%[[X]], %[[MASK]] +// ASSIGN-SAME: !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN-SAME: !pto.vmi.mask<256xb32, #pto.vmi.layout> + +// LOWER-LABEL: func.func private @consume( +// LOWER-SAME: !pto.vreg<64xf32> +// LOWER-SAME: !pto.mask +// LOWER: pto.vdintlv +// LOWER: pto.pdintlv_b32 +// LOWER-COUNT-4: pto.vcgadd +// LOWER: pto.vsts +// LOWER-LABEL: func.func @caller( +// LOWER: call @consume( +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_call_boundary.pto b/test/lit/vmi/vmi_layout_assignment_call_boundary.pto new file mode 100644 index 0000000000..b7245ad00b --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_call_boundary.pto @@ -0,0 +1,45 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func private @callee(%x: !pto.vmi.vreg<128xf32>) + -> !pto.vmi.vreg<128xf32> { + %sum = pto.vmi.addf %x, %x + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return %sum : !pto.vmi.vreg<128xf32> + } + + func.func @caller(%a: !pto.vmi.vreg<128xf16>) { + %ea = pto.vmi.extf %a + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %r = call @callee(%ea) + : (!pto.vmi.vreg<128xf32>) -> !pto.vmi.vreg<128xf32> + %sum = pto.vmi.addf %r, %r + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK-LABEL: func.func private @callee( +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: pto.vmi.addf +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: return +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK-LABEL: func.func @caller( +// CHECK: %[[EA:.*]] = pto.vmi.extf +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: %[[R:.*]] = call @callee(%[[EA]]) +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: pto.vmi.addf %[[R]], %[[R]] +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> diff --git a/test/lit/vmi/vmi_layout_assignment_cf_branch.pto b/test/lit/vmi/vmi_layout_assignment_cf_branch.pto new file mode 100644 index 0000000000..f96962a580 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_cf_branch.pto @@ -0,0 +1,56 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_cf_branch( + %cond: i1, + %a: !pto.vmi.vreg<128xf16>, + %b: !pto.vmi.vreg<128xf16>) { + cf.cond_br %cond, ^then(%a : !pto.vmi.vreg<128xf16>), + ^else(%b : !pto.vmi.vreg<128xf16>) + + ^then(%then_arg: !pto.vmi.vreg<128xf16>): + %then_value = pto.vmi.extf %then_arg + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %then_mask = pto.vmi.cmpf "olt", %then_value, %then_value + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.mask<128xpred> + cf.br ^join(%then_value, %then_mask + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred>) + + ^else(%else_arg: !pto.vmi.vreg<128xf16>): + %else_value = pto.vmi.extf %else_arg + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %else_mask = pto.vmi.cmpf "olt", %else_value, %else_value + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.mask<128xpred> + cf.br ^join(%else_value, %else_mask + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred>) + + ^join(%value: !pto.vmi.vreg<128xf32>, %mask: !pto.vmi.mask<128xpred>): + %selected = pto.vmi.select %mask, %value, %value + : !pto.vmi.mask<128xpred>, !pto.vmi.vreg<128xf32>, + !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_cf_branch( +// CHECK: cf.cond_br +// CHECK-SAME: !pto.vmi.vreg<128xf16, 
#pto.vmi.layout> +// CHECK: ^{{.*}}(%{{.*}}: !pto.vmi.vreg<128xf16, #pto.vmi.layout>): +// CHECK: pto.vmi.extf +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: cf.br +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK: ^{{.*}}(%[[VALUE:.*]]: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, %[[MASK:.*]]: !pto.vmi.mask<128xb32, #pto.vmi.layout>): +// CHECK: pto.vmi.select %[[MASK]], %[[VALUE]], %[[VALUE]] +// CHECK-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> diff --git a/test/lit/vmi/vmi_layout_assignment_cf_switch.pto b/test/lit/vmi/vmi_layout_assignment_cf_switch.pto new file mode 100644 index 0000000000..6376a5502c --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_cf_switch.pto @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_cf_switch( + %flag: i32, + %a: !pto.vmi.vreg<128xf32>, + %b: !pto.vmi.vreg<128xf32>) -> !pto.vmi.vreg<128xf32> { + cf.switch %flag : i32, [ + default: ^join(%a : !pto.vmi.vreg<128xf32>), + 0: ^join(%b : !pto.vmi.vreg<128xf32>) + ] + + ^join(%value: !pto.vmi.vreg<128xf32>): + return %value : !pto.vmi.vreg<128xf32> + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_cf_switch( +// ASSIGN-SAME: %[[FLAG:[^:]+]]: i32 +// ASSIGN-SAME: %[[A:[^:]+]]: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN-SAME: %[[B:[^:]+]]: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: cf.switch %[[FLAG]] : i32, [ +// ASSIGN: default: ^bb1(%[[A]] : !pto.vmi.vreg<128xf32, #pto.vmi.layout>), +// ASSIGN: 0: ^bb1(%[[B]] : !pto.vmi.vreg<128xf32, #pto.vmi.layout>) +// ASSIGN: ^bb1(%[[VALUE:.*]]: !pto.vmi.vreg<128xf32, #pto.vmi.layout>): +// ASSIGN: return %[[VALUE]] : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +// LOWER-LABEL: func.func @vmi_layout_assignment_cf_switch( +// LOWER-SAME: %[[FLAG:[^:]+]]: i32 +// LOWER-SAME: %[[A0:[^:]+]]: !pto.vreg<64xf32> +// LOWER-SAME: %[[A1:[^:]+]]: !pto.vreg<64xf32> +// LOWER-SAME: %[[B0:[^:]+]]: !pto.vreg<64xf32> +// LOWER-SAME: %[[B1:[^:]+]]: !pto.vreg<64xf32> +// LOWER-SAME: -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) +// LOWER: cf.switch %[[FLAG]] : i32, [ +// LOWER: default: ^bb1(%[[A0]], %[[A1]] : !pto.vreg<64xf32>, !pto.vreg<64xf32>), +// LOWER: 0: ^bb1(%[[B0]], %[[B1]] : !pto.vreg<64xf32>, !pto.vreg<64xf32>) +// LOWER: ^bb1(%[[VALUE0:.*]]: !pto.vreg<64xf32>, %[[VALUE1:.*]]: !pto.vreg<64xf32>): +// LOWER: return %[[VALUE0]], %[[VALUE1]] : !pto.vreg<64xf32>, !pto.vreg<64xf32> +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. 
diff --git a/test/lit/vmi/vmi_layout_assignment_channel_merge_count_unsupported_invalid.pto b/test/lit/vmi/vmi_layout_assignment_channel_merge_count_unsupported_invalid.pto new file mode 100644 index 0000000000..351b3f62f8 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_channel_merge_count_unsupported_invalid.pto @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-layout-assignment 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_assignment_channel_merge_count_unsupported_invalid( + %ch0: !pto.vmi.vreg<64xf32>, + %ch1: !pto.vmi.vreg<64xf32>, + %ch2: !pto.vmi.vreg<64xf32>) { + %merged = "pto.vmi.channel_merge"(%ch0, %ch1, %ch2) + : (!pto.vmi.vreg<64xf32>, !pto.vmi.vreg<64xf32>, !pto.vmi.vreg<64xf32>) + -> !pto.vmi.vreg<192xf32> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.channel_merge supports only 2 or 4 channels diff --git a/test/lit/vmi/vmi_layout_assignment_channel_split_count_unsupported_invalid.pto b/test/lit/vmi/vmi_layout_assignment_channel_split_count_unsupported_invalid.pto new file mode 100644 index 0000000000..572845c1a4 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_channel_split_count_unsupported_invalid.pto @@ -0,0 +1,21 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-layout-assignment 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_assignment_channel_split_count_unsupported_invalid( + %src: !pto.vmi.vreg<192xf32>) { + %ch0, %ch1, %ch2 = "pto.vmi.channel_split"(%src) + : (!pto.vmi.vreg<192xf32>) + -> (!pto.vmi.vreg<64xf32>, !pto.vmi.vreg<64xf32>, !pto.vmi.vreg<64xf32>) + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.channel_split supports only 2 or 4 channels diff --git a/test/lit/vmi/vmi_layout_assignment_compress.pto b/test/lit/vmi/vmi_layout_assignment_compress.pto new file mode 100644 index 0000000000..dee109ce28 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_compress.pto @@ -0,0 +1,30 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_compress( + %src: !pto.vmi.vreg<64xf32>, + %mask: !pto.vmi.mask<64xpred>) -> !pto.vmi.vreg<64xf32> { + %out = pto.vmi.compress %src, %mask + : !pto.vmi.vreg<64xf32>, !pto.vmi.mask<64xpred> + -> !pto.vmi.vreg<64xf32> + return %out : !pto.vmi.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_compress( +// CHECK-SAME: %[[SRC:.*]]: !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK-SAME: %[[MASK:.*]]: !pto.vmi.mask<64xb32, #pto.vmi.layout>) +// CHECK-SAME: -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK: %[[OUT:.*]] = pto.vmi.compress %[[SRC]], %[[MASK]] +// CHECK-SAME: !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.mask<64xb32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK: return %[[OUT]] diff --git a/test/lit/vmi/vmi_layout_assignment_compress_store.pto b/test/lit/vmi/vmi_layout_assignment_compress_store.pto new file mode 100644 index 0000000000..93266bdf42 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_compress_store.pto @@ -0,0 +1,31 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_compress_store( + %value: !pto.vmi.vreg<64xf32>, + %dst: !pto.ptr, + %offset: index, + %mask: !pto.vmi.mask<64xpred>) { + pto.vmi.compress_store %value, %dst[%offset], %mask + : !pto.vmi.vreg<64xf32>, !pto.ptr, !pto.vmi.mask<64xpred> + return + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_compress_store( +// CHECK-SAME: %[[VALUE:.*]]: !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK-SAME: %[[DST:.*]]: !pto.ptr +// CHECK-SAME: %[[OFFSET:.*]]: index +// CHECK-SAME: %[[MASK:.*]]: !pto.vmi.mask<64xb32, #pto.vmi.layout> +// CHECK: pto.vmi.compress_store %[[VALUE]], %[[DST]][%[[OFFSET]]], %[[MASK]] +// CHECK-SAME: !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK-SAME: !pto.ptr +// CHECK-SAME: !pto.vmi.mask<64xb32, #pto.vmi.layout> diff --git a/test/lit/vmi/vmi_layout_assignment_constant_remat.pto b/test/lit/vmi/vmi_layout_assignment_constant_remat.pto new file mode 100644 index 0000000000..a426621c15 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_constant_remat.pto @@ -0,0 +1,52 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-layout-rematerialize -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_constant_remat( + %src: !pto.vmi.vreg<128xf16>, + %dst: !pto.ptr, + %offset: index) -> !pto.vmi.vreg<128xf32> { + %constant = "pto.vmi.constant"() { + value = dense<1.000000e+00> : tensor<128xf32> + } : () -> !pto.vmi.vreg<128xf32> + %wide = pto.vmi.extf %src + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %sum = pto.vmi.addf %constant, %wide + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + pto.vmi.store %constant, %dst[%offset] + : !pto.vmi.vreg<128xf32>, !pto.ptr + return %sum : !pto.vmi.vreg<128xf32> + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_constant_remat( +// ASSIGN-SAME: %[[SRC:.*]]: !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN: %[[CONST_DEINT:.*]] = "pto.vmi.constant"() +// ASSIGN-SAME: dense<1.000000e+00> : tensor<128xf32> +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[WIDE:.*]] = pto.vmi.extf %[[SRC]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf16, #pto.vmi.layout> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.addf %[[CONST_DEINT]], %[[WIDE]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[CONST_CONTIG:.*]] = pto.vmi.ensure_layout %[[CONST_DEINT]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.store %[[CONST_CONTIG]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +// LOWER-LABEL: func.func @vmi_layout_assignment_constant_remat( +// LOWER: arith.constant 1.000000e+00 : f32 +// LOWER-COUNT-4: pto.vdup +// LOWER-NOT: pto.vintlv +// LOWER: pto.vsts +// LOWER: pto.vsts +// LOWER-NOT: pto.vmi. 
+// LOWER-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_layout_assignment_create_group_mask_s16.pto b/test/lit/vmi/vmi_layout_assignment_create_group_mask_s16.pto new file mode 100644 index 0000000000..868624d330 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_create_group_mask_s16.pto @@ -0,0 +1,53 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_create_group_mask_s16( + %base: !pto.ptr, %dst: !pto.ptr, %off: index) { + %c1 = arith.constant 1 : index + %c12 = arith.constant 12 : index + %c16 = arith.constant 16 : index + %x = pto.vmi.group_load %base[%off], %c16 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<128xf32> + %mask = pto.vmi.create_group_mask %c12 {num_groups = 8, group_size = 16} + : index -> !pto.vmi.mask<128xpred> + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<8xf32> + pto.vmi.group_store %sum, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_create_group_mask_s16( +// ASSIGN: %[[X:.*]] = pto.vmi.group_load +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, 
#pto.vmi.layout> +// ASSIGN: %[[MASK0:.*]] = pto.vmi.create_group_mask +// ASSIGN-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// ASSIGN: %[[MASK:.*]] = pto.vmi.ensure_mask_layout %[[MASK0]] +// ASSIGN-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X]], %[[MASK]] +// ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> + +// LOWER-LABEL: func.func @vmi_layout_assignment_create_group_mask_s16( +// LOWER: pto.pset_b32 "PAT_ALL" +// LOWER: pto.plt_b32 +// LOWER: pto.pnot +// LOWER: pto.pand +// LOWER: pto.por +// LOWER-COUNT-2: pto.vcgadd +// LOWER: pto.vadd +// LOWER: pto.vsts +// LOWER-NOT: PAT_M4 +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_create_group_mask_s32_dynamic.pto b/test/lit/vmi/vmi_layout_assignment_create_group_mask_s32_dynamic.pto new file mode 100644 index 0000000000..447f4591b7 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_create_group_mask_s32_dynamic.pto @@ -0,0 +1,60 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-layout-rematerialize -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_create_group_mask_s32_dynamic( + %base: !pto.ptr, + %sum_out: !pto.ptr, + %off: index, + %active_cols: index) { + %c0_f32 = arith.constant 0.000000e+00 : f32 + %c1 = arith.constant 1 : index + %mask = pto.vmi.create_group_mask %active_cols + {num_groups = 8, group_size = 32} + : index -> !pto.vmi.mask<256xpred> + %zero = pto.vmi.broadcast %c0_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %x = pto.vmi.masked_load %base[%off], %mask, %zero + : !pto.ptr, !pto.vmi.mask<256xpred>, + !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %sum = pto.vmi.group_reduce_addf %x, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<8xf32> + pto.vmi.group_store %sum, %sum_out[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_create_group_mask_s32_dynamic( +// ASSIGN-SAME: %[[ACTIVE:arg[0-9]+]]: index) +// ASSIGN: %[[MASK0:.*]] = pto.vmi.create_group_mask %[[ACTIVE]] +// ASSIGN-SAME: !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: %[[MASK1:.*]] = pto.vmi.ensure_mask_layout %[[MASK0]] +// ASSIGN-SAME: !pto.vmi.mask<256xb32, #pto.vmi.layout> -> !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_reduce_addf + +// LOWER-LABEL: func.func @vmi_layout_assignment_create_group_mask_s32_dynamic( +// LOWER: arith.index_cast +// LOWER: arith.maxsi +// LOWER: arith.minui +// LOWER: pto.vci +// LOWER: pto.vshrs +// LOWER: pto.vshls +// LOWER: pto.vsub +// LOWER-COUNT-8: pto.vcmps +// LOWER-COUNT-4: pto.pdintlv_b32 +// LOWER-COUNT-4: pto.vcgadd +// LOWER: pto.vsts +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. 
+// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_dense_f16_f32_store.pto b/test/lit/vmi/vmi_layout_assignment_dense_f16_f32_store.pto new file mode 100644 index 0000000000..7238c10a42 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_dense_f16_f32_store.pto @@ -0,0 +1,186 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_dense_f16_to_f32_store( + %src: !pto.ptr, + %dst: !pto.ptr, + %off: index) { + %x16 = pto.vmi.load %src[%off] + : !pto.ptr -> !pto.vmi.vreg<128xf16> + %x32 = pto.vmi.extf %x16 + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + pto.vmi.store %x32, %dst[%off] + : !pto.vmi.vreg<128xf32>, !pto.ptr + return + } + + func.func @vmi_layout_assignment_multichunk_f16_to_f32_store( + %src: !pto.ptr, + %dst: !pto.ptr, + %off: index) { + %x16 = pto.vmi.load %src[%off] + : !pto.ptr -> !pto.vmi.vreg<256xf16> + %x32 = pto.vmi.extf %x16 + : !pto.vmi.vreg<256xf16> -> !pto.vmi.vreg<256xf32> + pto.vmi.store %x32, %dst[%off] + : !pto.vmi.vreg<256xf32>, !pto.ptr + return + } + + func.func @vmi_layout_assignment_compact_f16_to_f32_store( + %src: !pto.ptr, + %dst: !pto.ptr, + %off: index) { + %x16 = 
pto.vmi.load %src[%off] + : !pto.ptr -> !pto.vmi.vreg<64xf16> + %x32 = pto.vmi.extf %x16 + : !pto.vmi.vreg<64xf16> -> !pto.vmi.vreg<64xf32> + pto.vmi.store %x32, %dst[%off] + : !pto.vmi.vreg<64xf32>, !pto.ptr + return + } + + func.func @vmi_layout_assignment_compact_f16_to_f32_masked_store( + %src: !pto.ptr, + %dst: !pto.ptr, + %off: index) { + %c64 = arith.constant 64 : index + %mask = pto.vmi.create_mask %c64 : index -> !pto.vmi.mask<64xpred> + %x16 = pto.vmi.load %src[%off] + : !pto.ptr -> !pto.vmi.vreg<64xf16> + %x32 = pto.vmi.extf %x16 + : !pto.vmi.vreg<64xf16> -> !pto.vmi.vreg<64xf32> + %y16 = pto.vmi.truncf %x32 + : !pto.vmi.vreg<64xf32> -> !pto.vmi.vreg<64xf16> + pto.vmi.masked_store %y16, %dst[%off], %mask + : !pto.vmi.vreg<64xf16>, !pto.ptr, !pto.vmi.mask<64xpred> + return + } + + func.func @vmi_layout_assignment_dense_f32_to_f16_store( + %src: !pto.ptr, + %dst: !pto.ptr, + %off: index) { + %x32 = pto.vmi.load %src[%off] + : !pto.ptr -> !pto.vmi.vreg<128xf32> + %x16 = pto.vmi.truncf %x32 + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> + pto.vmi.store %x16, %dst[%off] + : !pto.vmi.vreg<128xf16>, !pto.ptr + return + } + + func.func @vmi_layout_assignment_multichunk_f32_to_f16_store( + %src: !pto.ptr, + %dst: !pto.ptr, + %off: index) { + %x32 = pto.vmi.load %src[%off] + : !pto.ptr -> !pto.vmi.vreg<256xf32> + %x16 = pto.vmi.truncf %x32 + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf16> + pto.vmi.store %x16, %dst[%off] + : !pto.vmi.vreg<256xf16>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_dense_f16_to_f32_store( +// ASSIGN: %[[X16:.*]] = pto.vmi.load +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN: %[[X32:.*]] = pto.vmi.extf %[[X16]] +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[DENSE:.*]] = pto.vmi.ensure_layout %[[X32]] +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.store %[[DENSE]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, 
#pto.vmi.layout>, !pto.ptr + +// LOWER-LABEL: func.func @vmi_layout_assignment_dense_f16_to_f32_store( +// LOWER: pto.vlds +// LOWER: pto.vcvt {{.*}} {part = "EVEN"} +// LOWER: pto.vcvt {{.*}} {part = "ODD"} +// LOWER: pto.vintlv +// LOWER-COUNT-2: pto.vsts +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_multichunk_f16_to_f32_store( +// ASSIGN: %[[X16:.*]] = pto.vmi.load +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf16, #pto.vmi.layout> +// ASSIGN: %[[X32:.*]] = pto.vmi.extf %[[X16]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[DENSE:.*]] = pto.vmi.ensure_layout %[[X32]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.store %[[DENSE]] +// ASSIGN-SAME: !pto.vmi.vreg<256xf32, #pto.vmi.layout>, !pto.ptr + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_compact_f16_to_f32_store( +// ASSIGN: %[[X16:.*]] = pto.vmi.load +// ASSIGN-SAME: -> !pto.vmi.vreg<64xf16, #pto.vmi.layout> +// ASSIGN: %[[X32:.*]] = pto.vmi.extf %[[X16]] +// ASSIGN-SAME: -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// ASSIGN-NOT: pto.vmi.ensure_layout +// ASSIGN: pto.vmi.store %[[X32]] +// ASSIGN-SAME: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, !pto.ptr + +// LOWER-LABEL: func.func @vmi_layout_assignment_compact_f16_to_f32_store( +// LOWER: pto.vlds {{.*}} {dist = "UNPK_B16"} : !pto.ptr -> !pto.vreg<128xf16> +// LOWER: pto.vcvt {{.*}} {part = "EVEN"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +// LOWER-NOT: part = "ODD" +// LOWER: pto.vsts +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. 
+// LOWER-NOT: unrealized_conversion_cast + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_compact_f16_to_f32_masked_store( +// ASSIGN: %[[MASK0:.*]] = pto.vmi.create_mask +// ASSIGN-SAME: -> !pto.vmi.mask<64xb32, #pto.vmi.layout> +// ASSIGN: %[[X16:.*]] = pto.vmi.load +// ASSIGN-SAME: -> !pto.vmi.vreg<64xf16, #pto.vmi.layout> +// ASSIGN: %[[X32:.*]] = pto.vmi.extf %[[X16]] +// ASSIGN-SAME: -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// ASSIGN: %[[Y16:.*]] = pto.vmi.truncf %[[X32]] +// ASSIGN-SAME: -> !pto.vmi.vreg<64xf16, #pto.vmi.layout> +// ASSIGN: %[[Y16C:.*]] = pto.vmi.ensure_layout %[[Y16]] +// ASSIGN-SAME: -> !pto.vmi.vreg<64xf16, #pto.vmi.layout> +// ASSIGN: %[[MASK:.*]] = pto.vmi.ensure_mask_granularity %[[MASK0]] +// ASSIGN-SAME: -> !pto.vmi.mask<64xb16, #pto.vmi.layout> +// ASSIGN: pto.vmi.masked_store %[[Y16C]] +// ASSIGN-SAME: !pto.vmi.vreg<64xf16, #pto.vmi.layout>, !pto.ptr, !pto.vmi.mask<64xb16, #pto.vmi.layout> + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_dense_f32_to_f16_store( +// ASSIGN: %[[X32:.*]] = pto.vmi.load +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN-NOT: pto.vmi.ensure_layout +// ASSIGN: %[[X16:.*]] = pto.vmi.truncf %[[X32]] +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN: pto.vmi.store %[[X16]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf16, #pto.vmi.layout>, !pto.ptr + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_multichunk_f32_to_f16_store( +// ASSIGN: %[[X32:.*]] = pto.vmi.load +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN-NOT: pto.vmi.ensure_layout +// ASSIGN: %[[X16:.*]] = pto.vmi.truncf %[[X32]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf16, #pto.vmi.layout> +// ASSIGN: pto.vmi.store %[[X16]] +// ASSIGN-SAME: !pto.vmi.vreg<256xf16, #pto.vmi.layout>, !pto.ptr + +// LOWER-LABEL: func.func @vmi_layout_assignment_dense_f32_to_f16_store( +// LOWER: pto.vldsx2 {{.*}} "DINTLV_B32" +// LOWER: pto.vcvt {{.*}} {part = "EVEN", rnd = "R", sat = "SAT"} +// LOWER: 
pto.vcvt {{.*}} {part = "ODD", rnd = "R", sat = "SAT"} +// LOWER: pto.vor +// LOWER: pto.vsts +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_dense_group_reduce_multi_consumer.pto b/test/lit/vmi/vmi_layout_assignment_dense_group_reduce_multi_consumer.pto new file mode 100644 index 0000000000..a7be1a67a4 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_dense_group_reduce_multi_consumer.pto @@ -0,0 +1,59 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_dense_group_reduce_multi_consumer( + %src: !pto.ptr, + %sum_out: !pto.ptr, + %copy_out: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %x = pto.vmi.load %src[%off] : !pto.ptr -> !pto.vmi.vreg<256xf32> + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<8xf32> + pto.vmi.group_store %sum, %sum_out[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xf32>, !pto.ptr + pto.vmi.store %x, %copy_out[%off] + : !pto.vmi.vreg<256xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_dense_group_reduce_multi_consumer( +// ASSIGN: %[[X:.*]] = pto.vmi.load +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[MASK0:.*]] = pto.vmi.create_mask +// ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: %[[X_SPLIT:.*]] = pto.vmi.ensure_layout %[[X]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[MASK_SPLIT:.*]] = pto.vmi.ensure_mask_layout %[[MASK0]] +// ASSIGN-SAME: !pto.vmi.mask<256xb32, #pto.vmi.layout> -> !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X_SPLIT]], %[[MASK_SPLIT]] +// ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_store %[[SUM]] +// ASSIGN-SAME: !pto.vmi.vreg<8xf32, #pto.vmi.layout>, !pto.ptr +// ASSIGN: pto.vmi.store %[[X]] +// ASSIGN-SAME: !pto.vmi.vreg<256xf32, #pto.vmi.layout>, !pto.ptr + +// LOWER-LABEL: func.func @vmi_layout_assignment_dense_group_reduce_multi_consumer( +// LOWER-COUNT-4: pto.vlds +// LOWER: pto.vdintlv +// LOWER: pto.vdintlv +// LOWER: 
pto.vdintlv +// LOWER: pto.vdintlv +// LOWER-COUNT-4: pto.vcgadd +// LOWER: pto.vsts +// LOWER-COUNT-4: pto.vsts +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_layout_assignment_dense_store_group_slots_invalid.pto b/test/lit/vmi/vmi_layout_assignment_dense_store_group_slots_invalid.pto new file mode 100644 index 0000000000..a92d0c52f2 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_dense_store_group_slots_invalid.pto @@ -0,0 +1,27 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_assignment_dense_store_group_slots_invalid( + %source: !pto.vmi.vreg<64xf32>, + %mask: !pto.vmi.mask<64xpred>, + %dst: !pto.ptr, + %off: index) { + %sum = pto.vmi.group_reduce_addf %source, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<64xf32>, !pto.vmi.mask<64xpred> + -> !pto.vmi.vreg<8xf32> + // CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.store operand #0 has type '!pto.vmi.vreg<8xf32, #pto.vmi.layout>' but requires '!pto.vmi.vreg<8xf32, #pto.vmi.layout>'; pto.vmi.ensure_layout cannot materialize this conversion + // CHECK: failed helper conversion '!pto.vmi.vreg<8xf32, #pto.vmi.layout>' -> '!pto.vmi.vreg<8xf32, #pto.vmi.layout>' (unsupported source/result layout pair) + pto.vmi.store %sum, %dst[%off] + : !pto.vmi.vreg<8xf32>, !pto.ptr + return + } +} diff --git a/test/lit/vmi/vmi_layout_assignment_dhist.pto b/test/lit/vmi/vmi_layout_assignment_dhist.pto new file mode 100644 index 0000000000..89d9aee9e6 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_dhist.pto @@ -0,0 +1,37 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_dhist( + %acc: !pto.vmi.vreg<256xui16>, + %source: !pto.vmi.vreg<300xui8>, + %mask: !pto.vmi.mask<300xpred>) + -> !pto.vmi.vreg<256xui16> { + %hist = pto.vmi.dhist %acc, %source, %mask + : !pto.vmi.vreg<256xui16>, !pto.vmi.vreg<300xui8>, + !pto.vmi.mask<300xpred> -> !pto.vmi.vreg<256xui16> + return %hist : !pto.vmi.vreg<256xui16> + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_dhist( +// CHECK-SAME: %[[ACC:.*]]: !pto.vmi.vreg<256xui16, #pto.vmi.layout> +// CHECK-SAME: %[[SRC:.*]]: !pto.vmi.vreg<300xui8, #pto.vmi.layout> +// CHECK-SAME: %[[MASK:.*]]: !pto.vmi.mask<300xb32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<256xui16, #pto.vmi.layout> +// CHECK: %[[MASK_B8:.*]] = pto.vmi.ensure_mask_granularity %[[MASK]] +// CHECK-SAME: !pto.vmi.mask<300xb32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.mask<300xb8, #pto.vmi.layout> +// CHECK: %[[HIST:.*]] = pto.vmi.dhist %[[ACC]], %[[SRC]], %[[MASK_B8]] +// CHECK-SAME: !pto.vmi.vreg<256xui16, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.vreg<300xui8, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.mask<300xb8, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.vreg<256xui16, #pto.vmi.layout> +// CHECK: return %[[HIST]] diff --git a/test/lit/vmi/vmi_layout_assignment_expand_load.pto b/test/lit/vmi/vmi_layout_assignment_expand_load.pto new file mode 100644 index 0000000000..501b26b369 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_expand_load.pto @@ -0,0 +1,35 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_expand_load( + %src: !pto.ptr, + %offset: index, + %mask: !pto.vmi.mask<64xpred>, + %passthru: !pto.vmi.vreg<64xf32>) -> !pto.vmi.vreg<64xf32> { + %out = pto.vmi.expand_load %src[%offset], %mask, %passthru + : !pto.ptr, !pto.vmi.mask<64xpred>, + !pto.vmi.vreg<64xf32> -> !pto.vmi.vreg<64xf32> + return %out : !pto.vmi.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_expand_load( +// CHECK-SAME: %[[SRC:.*]]: !pto.ptr +// CHECK-SAME: %[[OFFSET:.*]]: index +// CHECK-SAME: %[[MASK:.*]]: !pto.vmi.mask<64xb32, #pto.vmi.layout> +// CHECK-SAME: %[[PASSTHRU:.*]]: !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK: %[[OUT:.*]] = pto.vmi.expand_load %[[SRC]][%[[OFFSET]]], %[[MASK]], %[[PASSTHRU]] +// CHECK-SAME: !pto.ptr +// CHECK-SAME: !pto.vmi.mask<64xb32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK: return %[[OUT]] diff --git a/test/lit/vmi/vmi_layout_assignment_external_call_invalid.pto b/test/lit/vmi/vmi_layout_assignment_external_call_invalid.pto new file mode 100644 index 0000000000..101f0f9254 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_external_call_invalid.pto @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. 
You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-layout-assignment 2>&1 | FileCheck %s + +module { + func.func private @external(!pto.vmi.vreg<128xf32>) + -> !pto.vmi.vreg<128xf32> + + func.func @caller(%x: !pto.vmi.vreg<128xf32>) { + %r = call @external(%x) + : (!pto.vmi.vreg<128xf32>) -> !pto.vmi.vreg<128xf32> + %sum = pto.vmi.addf %r, %r + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK: VMI-LAYOUT-CONTRACT: VMI typed function declaration requires an explicit external ABI materialization plan diff --git a/test/lit/vmi/vmi_layout_assignment_external_decl_invalid.pto b/test/lit/vmi/vmi_layout_assignment_external_decl_invalid.pto new file mode 100644 index 0000000000..ffb994287a --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_external_decl_invalid.pto @@ -0,0 +1,15 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -vmi-layout-assignment 2>&1 | FileCheck %s + +module { + func.func private @external_vmi(!pto.vmi.vreg<128xf32>) +} + +// CHECK: VMI-LAYOUT-CONTRACT: VMI typed function declaration requires an explicit external ABI materialization plan diff --git a/test/lit/vmi/vmi_layout_assignment_external_decl_preserve.pto b/test/lit/vmi/vmi_layout_assignment_external_decl_preserve.pto new file mode 100644 index 0000000000..384d0d1171 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_external_decl_preserve.pto @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func private @external_i32(i32) -> i32 + + func.func @vmi_layout_assignment_external_decl_preserve( + %input: !pto.vmi.vreg<128xf32>) -> !pto.vmi.vreg<128xf32> { + return %input : !pto.vmi.vreg<128xf32> + } +} + +// CHECK-LABEL: module +// CHECK: func.func private @external_i32(i32) -> i32 +// CHECK-LABEL: func.func @vmi_layout_assignment_external_decl_preserve( +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> diff --git a/test/lit/vmi/vmi_layout_assignment_f32_f8_store_reduce.pto b/test/lit/vmi/vmi_layout_assignment_f32_f8_store_reduce.pto new file mode 100644 index 0000000000..6ef79fb5f1 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_f32_f8_store_reduce.pto @@ -0,0 +1,64 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_f32_f8_store_reduce( + %src: !pto.ptr, + %sum: !pto.ptr, + %out8: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %x32 = pto.vmi.load %src[%off] + : !pto.ptr -> !pto.vmi.vreg<256xf32> + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %sumv = pto.vmi.group_reduce_addf %x32, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<8xf32> + pto.vmi.group_store %sumv, %sum[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xf32>, !pto.ptr + %x8 = pto.vmi.truncf %x32 + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E4M3FN> + pto.vmi.store %x8, %out8[%off] + : !pto.vmi.vreg<256xf8E4M3FN>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_f32_f8_store_reduce( +// ASSIGN: %[[X32:.*]] = pto.vmi.load +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[MASK0:.*]] = pto.vmi.create_mask +// ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: %[[MASK:.*]] = pto.vmi.ensure_mask_layout %[[MASK0]] +// ASSIGN-SAME: !pto.vmi.mask<256xb32, #pto.vmi.layout> -> !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X32]], %[[MASK]] +// ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_store %[[SUM]] +// ASSIGN: %[[X8:.*]] = pto.vmi.truncf %[[X32]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> +// ASSIGN: pto.vmi.store %[[X8]] + +// LOWER-LABEL: func.func @vmi_layout_assignment_f32_f8_store_reduce( +// LOWER-COUNT-2: pto.vldsx2 {{.*}} "DINTLV_B32" +// LOWER-COUNT-2: pto.vdintlv +// LOWER-COUNT-4: pto.vcgadd +// LOWER-COUNT-3: pto.vadd +// LOWER: pto.vsts +// LOWER: pto.vcvt {{.*}} {part = 
"P0", rnd = "R", sat = "SAT"} +// LOWER: pto.vcvt {{.*}} {part = "P1", rnd = "R", sat = "SAT"} +// LOWER: pto.vcvt {{.*}} {part = "P2", rnd = "R", sat = "SAT"} +// LOWER: pto.vcvt {{.*}} {part = "P3", rnd = "R", sat = "SAT"} +// LOWER: pto.vor +// LOWER: pto.vsts +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_f8_compute_f8.pto b/test/lit/vmi/vmi_layout_assignment_f8_compute_f8.pto new file mode 100644 index 0000000000..7e1d1b293f --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_f8_compute_f8.pto @@ -0,0 +1,61 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_f8_compute_f8( + %src: !pto.ptr, + %scale: f32, + %dst: !pto.ptr, + %off: index) { + %x8 = pto.vmi.load %src[%off] + : !pto.ptr -> !pto.vmi.vreg<256xf8E4M3FN> + %x32 = pto.vmi.extf %x8 + : !pto.vmi.vreg<256xf8E4M3FN> -> !pto.vmi.vreg<256xf32> + %scale_vec = pto.vmi.broadcast %scale + : f32 -> !pto.vmi.vreg<256xf32> + %y32 = pto.vmi.mulf %x32, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %y8 = pto.vmi.truncf %y32 + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E4M3FN> + pto.vmi.store %y8, %dst[%off] + : !pto.vmi.vreg<256xf8E4M3FN>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_f8_compute_f8( +// ASSIGN: %[[X8:.*]] = pto.vmi.load +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> +// ASSIGN: %[[X32:.*]] = pto.vmi.extf %[[X8]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[SCALE:.*]] = pto.vmi.broadcast +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[Y32:.*]] = pto.vmi.mulf %[[X32]], %[[SCALE]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[Y8:.*]] = pto.vmi.truncf %[[Y32]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> +// ASSIGN: pto.vmi.store %[[Y8]] + +// LOWER-LABEL: func.func @vmi_layout_assignment_f8_compute_f8( +// LOWER: pto.vlds +// LOWER-COUNT-4: pto.vcvt {{.*}} {part = "P{{[0-3]}}"} +// LOWER-COUNT-4: pto.vdup +// LOWER-COUNT-4: pto.vmul +// LOWER: pto.vcvt {{.*}} {part = "P0", rnd = "R", sat = "SAT"} +// LOWER: pto.vcvt {{.*}} {part = "P1", rnd = "R", sat = "SAT"} +// LOWER: pto.vcvt {{.*}} {part = "P2", rnd = "R", sat = "SAT"} +// LOWER: pto.vcvt {{.*}} {part = "P3", rnd = "R", sat = "SAT"} +// LOWER: pto.vor +// LOWER: pto.vsts +// 
LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_fma.pto b/test/lit/vmi/vmi_layout_assignment_fma.pto new file mode 100644 index 0000000000..c40b09c471 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_fma.pto @@ -0,0 +1,29 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_fma( + %lhs: !pto.vmi.vreg<64xf32>, + %rhs: !pto.vmi.vreg<64xf32>, + %acc: !pto.vmi.vreg<64xf32>) -> !pto.vmi.vreg<64xf32> { + %out = pto.vmi.fma %lhs, %rhs, %acc + : !pto.vmi.vreg<64xf32>, !pto.vmi.vreg<64xf32>, + !pto.vmi.vreg<64xf32> -> !pto.vmi.vreg<64xf32> + return %out : !pto.vmi.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_fma( +// CHECK-SAME: %arg0: !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK-SAME: %arg1: !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK-SAME: %arg2: !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK: %[[OUT:.*]] = pto.vmi.fma %arg0, %arg1, %arg2 +// CHECK: return %[[OUT]] diff --git a/test/lit/vmi/vmi_layout_assignment_gather.pto b/test/lit/vmi/vmi_layout_assignment_gather.pto new file mode 100644 index 0000000000..a63919bf6f --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_gather.pto @@ -0,0 
+1,35 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_gather( + %src: !pto.ptr, + %indices: !pto.vmi.vreg<64xi32>, + %mask: !pto.vmi.mask<64xpred>, + %passthru: !pto.vmi.vreg<64xf32>) -> !pto.vmi.vreg<64xf32> { + %out = pto.vmi.gather %src[%indices], %mask, %passthru + : !pto.ptr, !pto.vmi.vreg<64xi32>, + !pto.vmi.mask<64xpred>, !pto.vmi.vreg<64xf32> + -> !pto.vmi.vreg<64xf32> + return %out : !pto.vmi.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_gather( +// CHECK-SAME: %arg1: !pto.vmi.vreg<64xi32, #pto.vmi.layout> +// CHECK-SAME: %arg2: !pto.vmi.mask<64xb32, #pto.vmi.layout> +// CHECK-SAME: %arg3: !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK: %[[OUT:.*]] = pto.vmi.gather %arg0[%arg1], %arg2, %arg3 +// CHECK-SAME: !pto.vmi.vreg<64xi32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.mask<64xb32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK: return %[[OUT]] diff --git a/test/lit/vmi/vmi_layout_assignment_group_broadcast_load_e2b_b16.pto b/test/lit/vmi/vmi_layout_assignment_group_broadcast_load_e2b_b16.pto new file mode 100644 index 0000000000..2f881b2c71 --- /dev/null +++ 
b/test/lit/vmi/vmi_layout_assignment_group_broadcast_load_e2b_b16.pto @@ -0,0 +1,39 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_layout_assignment_group_broadcast_load_e2b_b16( + %src: !pto.ptr, %off: index) -> !pto.vmi.vreg<256xbf16> { + %c1 = arith.constant 1 : index + %out = pto.vmi.group_broadcast_load %src[%off], %c1 + {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<256xbf16> + return %out : !pto.vmi.vreg<256xbf16> + } + + func.func @vmi_layout_assignment_group_broadcast_load_e2b_b32( + %src: !pto.ptr, %off: index) -> !pto.vmi.vreg<64xf32> { + %c1 = arith.constant 1 : index + %out = pto.vmi.group_broadcast_load %src[%off], %c1 + {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<64xf32> + return %out : !pto.vmi.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_group_broadcast_load_e2b_b16 +// CHECK-SAME: (%[[SRC:.*]]: !pto.ptr, %[[OFF:.*]]: index) +// CHECK: %[[E2B:.*]] = pto.vlds %[[SRC]][%[[OFF]]] {dist = "E2B_B16"} : !pto.ptr -> !pto.vreg<128xbf16> +// CHECK: return %[[E2B]], %[[E2B]] : !pto.vreg<128xbf16>, !pto.vreg<128xbf16> + +// CHECK-LABEL: func.func @vmi_layout_assignment_group_broadcast_load_e2b_b32 +// CHECK-SAME: (%[[SRC32:.*]]: !pto.ptr, %[[OFF32:.*]]: index) +// CHECK: %[[E2B32:.*]] = pto.vlds %[[SRC32]][%[[OFF32]]] {dist = 
"E2B_B32"} : !pto.ptr -> !pto.vreg<64xf32> +// CHECK: return %[[E2B32]] : !pto.vreg<64xf32> diff --git a/test/lit/vmi/vmi_layout_assignment_group_broadcast_multi_consumer.pto b/test/lit/vmi/vmi_layout_assignment_group_broadcast_multi_consumer.pto new file mode 100644 index 0000000000..400c10093a --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_broadcast_multi_consumer.pto @@ -0,0 +1,79 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_group_broadcast_multi_consumer( + %src: !pto.ptr, + %sum_out: !pto.ptr, + %dense_out: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %x = pto.vmi.load %src[%off] : !pto.ptr -> !pto.vmi.vreg<128xf32> + %mask = pto.vmi.create_mask %c128 : index -> !pto.vmi.mask<128xpred> + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<8xf32> + %b_for_mul = pto.vmi.group_broadcast %sum {num_groups = 8} + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<128xf32> + %y = pto.vmi.mulf %x, %b_for_mul + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + %ysum = pto.vmi.group_reduce_addf %y, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<8xf32> + pto.vmi.group_store %ysum, %sum_out[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xf32>, !pto.ptr + %b_for_cast = pto.vmi.group_broadcast %sum {num_groups = 8} + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<128xf32> + %h = pto.vmi.truncf %b_for_cast + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> + pto.vmi.store %h, %dense_out[%off] + : !pto.vmi.vreg<128xf16>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_group_broadcast_multi_consumer( +// ASSIGN: %[[X:.*]] = pto.vmi.load +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X]] +// ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> +// ASSIGN: %[[B_MUL:.*]] = pto.vmi.group_broadcast %[[SUM]] +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[Y:.*]] = pto.vmi.mulf %[[X]], %[[B_MUL]] +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, 
#pto.vmi.layout> +// ASSIGN: %[[YSUM:.*]] = pto.vmi.group_reduce_addf %[[Y]] +// ASSIGN: pto.vmi.group_store %[[YSUM]] +// ASSIGN: %[[B_CAST:.*]] = pto.vmi.group_broadcast %[[SUM]] +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[B_CAST_SPLIT:.*]] = pto.vmi.ensure_layout %[[B_CAST]] +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[H:.*]] = pto.vmi.truncf %[[B_CAST_SPLIT]] +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN: pto.vmi.store %[[H]] + +// LOWER-LABEL: func.func @vmi_layout_assignment_group_broadcast_multi_consumer( +// LOWER: pto.vcgadd +// LOWER: pto.vadd +// LOWER: pto.vselr +// LOWER: pto.vselr +// LOWER: pto.vmul +// LOWER: pto.vmul +// LOWER: pto.vcgadd +// LOWER: pto.vsts +// LOWER: pto.vselr +// LOWER: pto.vselr +// LOWER: pto.vcvt +// LOWER: pto.vor +// LOWER: pto.vsts +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_layout_assignment_group_broadcast_slots8.pto b/test/lit/vmi/vmi_layout_assignment_group_broadcast_slots8.pto new file mode 100644 index 0000000000..8ac9e7dc06 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_broadcast_slots8.pto @@ -0,0 +1,26 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_group_broadcast_slots8( + %source: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<1024xf32> { + %out = pto.vmi.group_broadcast %source {num_groups = 128} + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<1024xf32> + return %out : !pto.vmi.vreg<1024xf32> + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_group_broadcast_slots8( +// CHECK-SAME: -> !pto.vmi.vreg<1024xf32, #pto.vmi.layout> +// CHECK: %[[OUT:.*]] = pto.vmi.group_broadcast +// CHECK-SAME: -> !pto.vmi.vreg<1024xf32, #pto.vmi.layout> +// CHECK: return %[[OUT]] diff --git a/test/lit/vmi/vmi_layout_assignment_group_load.pto b/test/lit/vmi/vmi_layout_assignment_group_load.pto new file mode 100644 index 0000000000..864683cb04 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_load.pto @@ -0,0 +1,26 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_group_load( + %source: !pto.ptr, + %row_stride: index) -> !pto.vmi.vreg<512xf32> { + %c0 = arith.constant 0 : index + %out = pto.vmi.group_load %source[%c0], %row_stride {num_groups = 2} + : !pto.ptr -> !pto.vmi.vreg<512xf32> + return %out : !pto.vmi.vreg<512xf32> + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_group_load( +// CHECK-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> +// CHECK: %[[OUT:.*]] = pto.vmi.group_load +// CHECK-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> +// CHECK: return %[[OUT]] diff --git a/test/lit/vmi/vmi_layout_assignment_group_load_block8_truncf.pto b/test/lit/vmi/vmi_layout_assignment_group_load_block8_truncf.pto new file mode 100644 index 0000000000..2953ab1989 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_load_block8_truncf.pto @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_layout_assignment_group_load_block8_truncf( + %src: !pto.ptr, + %sum_dst: !pto.ptr, + %dense_dst: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %stride24 = arith.constant 24 : index + %c128 = arith.constant 128 : index + %mask = pto.vmi.create_mask %c128 : index -> !pto.vmi.mask<128xpred> + %x = pto.vmi.group_load %src[%off], %stride24 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<128xf32> + %sum = pto.vmi.group_reduce_addf %x, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<8xf32> + pto.vmi.group_store %sum, %sum_dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xf32>, !pto.ptr + %h = pto.vmi.truncf %x + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> + pto.vmi.store %h, %dense_dst[%off] + : !pto.vmi.vreg<128xf16>, !pto.ptr + return + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_group_load_block8_truncf( +// CHECK: pto.vsldb +// CHECK: pto.vcgadd +// CHECK: pto.vintlv +// CHECK: pto.vdintlv +// CHECK: pto.vcvt {{.*}} {part = "EVEN", rnd = "R", sat = "SAT"} +// CHECK: pto.vcvt {{.*}} {part = "ODD", rnd = "R", sat = "SAT"} +// CHECK: pto.vor +// CHECK: pto.vsts +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_group_load_s16_compact_stride12_invalid.pto b/test/lit/vmi/vmi_layout_assignment_group_load_s16_compact_stride12_invalid.pto new file mode 100644 index 0000000000..8b60309aa6 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_load_s16_compact_stride12_invalid.pto @@ -0,0 +1,31 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. 
You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-layout-assignment 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_assignment_group_load_s16_compact_stride12_invalid( + %base: !pto.ptr, + %mask: !pto.vmi.mask<128xpred>, + %dst: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %stride12 = arith.constant 12 : index + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_load group_size 16 requires constant positive row_stride divisible by 8 f32 elements for the block8 stride plan + // CHECK-SAME: stable gather fallback is not implemented + %x = pto.vmi.group_load %base[%off], %stride12 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<128xf32> + %sum = pto.vmi.group_reduce_addf %x, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<8xf32> + pto.vmi.group_store %sum, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xf32>, !pto.ptr + return + } +} diff --git a/test/lit/vmi/vmi_layout_assignment_group_load_s16_stride_store.pto b/test/lit/vmi/vmi_layout_assignment_group_load_s16_stride_store.pto new file mode 100644 index 0000000000..60a0e75884 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_load_s16_stride_store.pto @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_group_load_s16_stride_store( + %base: !pto.ptr, + %mask: !pto.vmi.mask<128xpred>, + %dst: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %stride24 = arith.constant 24 : index + %x = pto.vmi.group_load %base[%off], %stride24 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<128xf32> + %sum = pto.vmi.group_reduce_addf %x, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<8xf32> + pto.vmi.group_store %sum, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_group_load_s16_stride_store( +// ASSIGN: %[[X:.*]] = pto.vmi.group_load +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[MASK:.*]] = pto.vmi.ensure_mask_layout +// ASSIGN-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X]], %[[MASK]] +// ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_store %[[SUM]] + +// LOWER-LABEL: func.func @vmi_layout_assignment_group_load_s16_stride_store( +// LOWER-COUNT-2: pto.vsldb +// LOWER-COUNT-2: pto.vcgadd +// LOWER: pto.vadd +// LOWER: pto.vsts {{.*}} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. 
+// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_group_load_s16_unaligned_stride_invalid.pto b/test/lit/vmi/vmi_layout_assignment_group_load_s16_unaligned_stride_invalid.pto new file mode 100644 index 0000000000..61b30fd778 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_load_s16_unaligned_stride_invalid.pto @@ -0,0 +1,31 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -vmi-layout-assignment 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_assignment_group_load_s16_unaligned_stride_invalid( + %base: !pto.ptr, + %mask: !pto.vmi.mask<128xpred>, + %dst: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %stride20 = arith.constant 20 : index + %x = pto.vmi.group_load %base[%off], %stride20 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<128xf32> + %sum = pto.vmi.group_reduce_addf %x, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<8xf32> + pto.vmi.group_store %sum, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xf32>, !pto.ptr + return + } +} + +// CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_load group_size 16 requires constant positive row_stride divisible by 8 f32 elements for the block8 stride plan; stable gather fallback is not implemented diff --git a/test/lit/vmi/vmi_layout_assignment_group_load_s32_stride_broadcast_reduce.pto b/test/lit/vmi/vmi_layout_assignment_group_load_s32_stride_broadcast_reduce.pto new file mode 100644 index 0000000000..cb6011a650 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_load_s32_stride_broadcast_reduce.pto @@ -0,0 +1,67 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_group_load_s32_stride_broadcast_reduce( + %base: !pto.ptr, + %mask: !pto.vmi.mask<256xpred>, + %dst: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %stride40 = arith.constant 40 : index + %x = pto.vmi.group_load %base[%off], %stride40 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<256xf32> + %sum = pto.vmi.group_reduce_addf %x, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<8xf32> + %b = pto.vmi.group_broadcast %sum {num_groups = 8} + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<256xf32> + %y = pto.vmi.mulf %x, %b + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %ysum = pto.vmi.group_reduce_addf %y, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<8xf32> + pto.vmi.group_store %ysum, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_group_load_s32_stride_broadcast_reduce( +// ASSIGN: %[[X:.*]] = pto.vmi.group_load +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[MASK:.*]] = pto.vmi.ensure_mask_layout +// ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X]], %[[MASK]] +// ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> +// ASSIGN: %[[B:.*]] = pto.vmi.group_broadcast %[[SUM]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[Y:.*]] = pto.vmi.mulf %[[X]], %[[B]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[MASK2:.*]] = pto.vmi.ensure_mask_layout +// ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: %[[YSUM:.*]] = pto.vmi.group_reduce_addf 
%[[Y]], %[[MASK2]] +// ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_store %[[YSUM]] + +// LOWER-LABEL: func.func @vmi_layout_assignment_group_load_s32_stride_broadcast_reduce( +// LOWER-COUNT-4: pto.vsldb +// LOWER-COUNT-4: pto.vcgadd +// LOWER-COUNT-4: pto.vselr +// LOWER-COUNT-4: pto.vmul +// LOWER-COUNT-4: pto.vcgadd +// LOWER: pto.vsts {{.*}} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_group_load_s32_stride_store.pto b/test/lit/vmi/vmi_layout_assignment_group_load_s32_stride_store.pto new file mode 100644 index 0000000000..34de7b5064 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_load_s32_stride_store.pto @@ -0,0 +1,49 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_group_load_s32_stride_store( + %base: !pto.ptr, + %mask: !pto.vmi.mask<256xpred>, + %dst: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %stride40 = arith.constant 40 : index + %x = pto.vmi.group_load %base[%off], %stride40 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<256xf32> + %sum = pto.vmi.group_reduce_addf %x, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<8xf32> + pto.vmi.group_store %sum, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_group_load_s32_stride_store( +// ASSIGN: %[[X:.*]] = pto.vmi.group_load +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[MASK:.*]] = pto.vmi.ensure_mask_layout +// ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X]], %[[MASK]] +// ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_store %[[SUM]] +// ASSIGN-SAME: !pto.vmi.vreg<8xf32, #pto.vmi.layout>, !pto.ptr + +// LOWER-LABEL: func.func @vmi_layout_assignment_group_load_s32_stride_store( +// LOWER-COUNT-4: pto.vsldb +// LOWER-COUNT-4: pto.vcgadd +// LOWER-COUNT-3: pto.vadd +// LOWER: pto.vsts {{.*}} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. 
+// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_group_load_s32_unaligned_stride_invalid.pto b/test/lit/vmi/vmi_layout_assignment_group_load_s32_unaligned_stride_invalid.pto new file mode 100644 index 0000000000..43b566f895 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_load_s32_unaligned_stride_invalid.pto @@ -0,0 +1,31 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -vmi-layout-assignment 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_assignment_group_load_s32_unaligned_stride_invalid( + %base: !pto.ptr, + %mask: !pto.vmi.mask<256xpred>, + %dst: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %stride34 = arith.constant 34 : index + %x = pto.vmi.group_load %base[%off], %stride34 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<256xf32> + %sum = pto.vmi.group_reduce_addf %x, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<8xf32> + pto.vmi.group_store %sum, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xf32>, !pto.ptr + return + } +} + +// CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_load group_size 32 requires constant positive row_stride divisible by 8 f32 elements for the block8 stride plan; stable gather fallback is not implemented diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_maxf_quant.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_maxf_quant.pto new file mode 100644 index 0000000000..fddd344cf6 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_maxf_quant.pto @@ -0,0 +1,83 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_group_reduce_maxf_quant( + %src: !pto.ptr, + %scale_out: !pto.ptr, + %out8: !pto.ptr, + %off: index) { + %c8 = arith.constant 8 : index + %c256 = arith.constant 256 : index + %eps = arith.constant 1.000000e-04 : f32 + %fp8_max = arith.constant 4.480000e+02 : f32 + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %x = pto.vmi.load %src[%off] : !pto.ptr -> !pto.vmi.vreg<256xf32> + %abs = pto.vmi.absf %x : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %amax_raw = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 2} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<2xf32> + %eps2 = pto.vmi.broadcast %eps : f32 -> !pto.vmi.vreg<2xf32> + %amax = pto.vmi.maxf %amax_raw, %eps2 + : !pto.vmi.vreg<2xf32>, !pto.vmi.vreg<2xf32> + -> !pto.vmi.vreg<2xf32> + %fp8_max2 = pto.vmi.broadcast %fp8_max : f32 -> !pto.vmi.vreg<2xf32> + %scale = pto.vmi.divf %amax, %fp8_max2 + : !pto.vmi.vreg<2xf32>, !pto.vmi.vreg<2xf32> + -> !pto.vmi.vreg<2xf32> + pto.vmi.group_store %scale, %scale_out[%off], %c8 {num_groups = 2} + : !pto.vmi.vreg<2xf32>, !pto.ptr + %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 2} + : !pto.vmi.vreg<2xf32> -> !pto.vmi.vreg<256xf32> + %q = pto.vmi.divf %x, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %q8 = pto.vmi.truncf %q + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E4M3FN> + pto.vmi.store %q8, %out8[%off] + : !pto.vmi.vreg<256xf8E4M3FN>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_group_reduce_maxf_quant( +// ASSIGN: %[[X:.*]] = pto.vmi.load +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[ABS:.*]] = pto.vmi.absf %[[X]] +// ASSIGN: %[[AMAX_RAW:.*]] = 
pto.vmi.group_reduce_maxf %[[ABS]] +// ASSIGN-SAME: -> !pto.vmi.vreg<2xf32, #pto.vmi.layout> +// ASSIGN: %[[SCALE:.*]] = pto.vmi.divf +// ASSIGN-SAME: -> !pto.vmi.vreg<2xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_store %[[SCALE]] +// ASSIGN: %[[SCALE_VEC:.*]] = pto.vmi.group_broadcast %[[SCALE]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[Q:.*]] = pto.vmi.divf %[[X]], %[[SCALE_VEC]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[Q_SPLIT:.*]] = pto.vmi.ensure_layout %[[Q]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[Q8:.*]] = pto.vmi.truncf %[[Q_SPLIT]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> + +// LOWER-LABEL: func.func @vmi_layout_assignment_group_reduce_maxf_quant( +// LOWER: pto.vcgmax +// LOWER: pto.vmax +// LOWER: pto.vsel +// LOWER: pto.vdiv +// LOWER: pto.vdintlv +// LOWER: pto.vcvt {{.*}} {part = "P0", rnd = "R", sat = "SAT"} +// LOWER: pto.vcvt {{.*}} {part = "P1", rnd = "R", sat = "SAT"} +// LOWER: pto.vcvt {{.*}} {part = "P2", rnd = "R", sat = "SAT"} +// LOWER: pto.vcvt {{.*}} {part = "P3", rnd = "R", sat = "SAT"} +// LOWER: pto.vor +// LOWER: pto.vsts +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_partial_slots8.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_partial_slots8.pto new file mode 100644 index 0000000000..eb9296da40 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_partial_slots8.pto @@ -0,0 +1,61 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_group_reduce_f16_s64_g4( + %source: !pto.vmi.vreg<256xf16>, + %mask: !pto.vmi.mask<256xpred>) + -> !pto.vmi.vreg<4xf16> { + %out = pto.vmi.group_reduce_addf %source, %mask + {num_groups = 4, reassoc} + : !pto.vmi.vreg<256xf16>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<4xf16> + return %out : !pto.vmi.vreg<4xf16> + } + + func.func @vmi_layout_assignment_group_reduce_f16_s64_g12( + %source: !pto.vmi.vreg<768xf16>, + %mask: !pto.vmi.mask<768xpred>) + -> !pto.vmi.vreg<12xf16> { + %out = pto.vmi.group_reduce_addf %source, %mask + {num_groups = 12, reassoc} + : !pto.vmi.vreg<768xf16>, !pto.vmi.mask<768xpred> + -> !pto.vmi.vreg<12xf16> + return %out : !pto.vmi.vreg<12xf16> + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_group_reduce_f16_s64_g4( +// CHECK-SAME: %arg0: !pto.vmi.vreg<256xf16, #pto.vmi.layout> +// CHECK-SAME: %arg1: !pto.vmi.mask<256xb32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<4xf16, #pto.vmi.layout> +// CHECK: %[[SRC4:.*]] = pto.vmi.ensure_layout %arg0 +// CHECK-SAME: -> !pto.vmi.vreg<256xf16, #pto.vmi.layout> +// CHECK: %[[MASK4_LAYOUT:.*]] = pto.vmi.ensure_mask_layout %arg1 +// CHECK-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> +// CHECK: %[[MASK4:.*]] = pto.vmi.ensure_mask_granularity %[[MASK4_LAYOUT]] +// CHECK-SAME: -> !pto.vmi.mask<256xb16, #pto.vmi.layout> +// CHECK: %[[OUT4:.*]] = pto.vmi.group_reduce_addf %[[SRC4]], %[[MASK4]] +// CHECK-SAME: -> !pto.vmi.vreg<4xf16, #pto.vmi.layout> +// CHECK: return %[[OUT4]] + +// CHECK-LABEL: func.func @vmi_layout_assignment_group_reduce_f16_s64_g12( +// CHECK-SAME: %arg0: 
!pto.vmi.vreg<768xf16, #pto.vmi.layout> +// CHECK-SAME: %arg1: !pto.vmi.mask<768xb32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<12xf16, #pto.vmi.layout> +// CHECK: %[[SRC12:.*]] = pto.vmi.ensure_layout %arg0 +// CHECK-SAME: -> !pto.vmi.vreg<768xf16, #pto.vmi.layout> +// CHECK: %[[MASK12_LAYOUT:.*]] = pto.vmi.ensure_mask_layout %arg1 +// CHECK-SAME: -> !pto.vmi.mask<768xb32, #pto.vmi.layout> +// CHECK: %[[MASK12:.*]] = pto.vmi.ensure_mask_granularity %[[MASK12_LAYOUT]] +// CHECK-SAME: -> !pto.vmi.mask<768xb16, #pto.vmi.layout> +// CHECK: %[[OUT12:.*]] = pto.vmi.group_reduce_addf %[[SRC12]], %[[MASK12]] +// CHECK-SAME: -> !pto.vmi.vreg<12xf16, #pto.vmi.layout> +// CHECK: return %[[OUT12]] diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s12_invalid.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s12_invalid.pto new file mode 100644 index 0000000000..04dbe3952c --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s12_invalid.pto @@ -0,0 +1,31 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_assignment_group_reduce_s12_invalid( + %source: !pto.vmi.vreg<96xf32>, + %mask: !pto.vmi.mask<96xpred>, + %dst: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_reduce_addf has no registered group_slots layout support + // CHECK-SAME: stable group_reduce_add slots=8 support group sizes VLaneElems, 2*VLaneElems, or 4*VLaneElems + // CHECK-SAME: VMI types: operand#0=!pto.vmi.vreg<96xf32, #pto.vmi.layout> + // CHECK-SAME: operand#1=!pto.vmi.mask<96xb32, #pto.vmi.layout> + // CHECK-SAME: result#0=!pto.vmi.vreg<8xf32, #pto.vmi.layout> + %sum = pto.vmi.group_reduce_addf %source, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<96xf32>, !pto.vmi.mask<96xpred> + -> !pto.vmi.vreg<8xf32> + pto.vmi.group_store %sum, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xf32>, !pto.ptr + return + } +} diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s16_store.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s16_store.pto new file mode 100644 index 0000000000..d73d66570b --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s16_store.pto @@ -0,0 +1,52 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_group_reduce_s16_store( + %source: !pto.vmi.vreg<128xf32>, + %mask: !pto.vmi.mask<128xpred>, + %dst: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %sum = pto.vmi.group_reduce_addf %source, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<8xf32> + pto.vmi.group_store %sum, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_group_reduce_s16_store( +// ASSIGN-SAME: %[[SOURCE:.*]]: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN-SAME: %[[MASK:.*]]: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// ASSIGN: %[[SOURCE_SPLIT:.*]] = pto.vmi.ensure_layout %[[SOURCE]] +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[MASK_SPLIT:.*]] = pto.vmi.ensure_mask_layout %[[MASK]] +// ASSIGN-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[SOURCE_SPLIT]], %[[MASK_SPLIT]] +// ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_store %[[SUM]] +// ASSIGN-SAME: !pto.vmi.vreg<8xf32, #pto.vmi.layout>, !pto.ptr + +// LOWER-LABEL: func.func @vmi_layout_assignment_group_reduce_s16_store( +// LOWER: %[[LO:.*]], %[[HI:.*]] = pto.vdintlv +// LOWER: %[[MLO:.*]], %[[MHI:.*]] = pto.pdintlv_b32 +// LOWER: %[[VL8:.*]] = pto.pge_b32 "PAT_VL8" : !pto.mask +// LOWER: %[[SLO:.*]] = pto.vcgadd %[[LO]], %[[MLO]] +// LOWER: %[[SHI:.*]] = pto.vcgadd %[[HI]], %[[MHI]] +// LOWER: %[[SUM:.*]] = pto.vadd %[[SLO]], %[[SHI]], %[[VL8]] +// LOWER: %[[STORE_MASK:.*]] = pto.pge_b32 "PAT_VL8" : !pto.mask +// LOWER: pto.vsts %[[SUM]], %arg4[%arg5], %[[STORE_MASK]] +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. 
+// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s16_truncf_broadcast_store.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s16_truncf_broadcast_store.pto new file mode 100644 index 0000000000..2e53e6d7a3 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s16_truncf_broadcast_store.pto @@ -0,0 +1,59 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_group_reduce_s16_truncf_broadcast_store( + %source: !pto.vmi.vreg<128xf32>, + %mask: !pto.vmi.mask<128xpred>, + %dst: !pto.ptr, + %off: index) { + %sum32 = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<8xf32> + %rows32 = pto.vmi.group_broadcast %sum32 {num_groups = 8} + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<128xf32> + %rows16 = pto.vmi.truncf %rows32 + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> + pto.vmi.store %rows16, %dst[%off] + : !pto.vmi.vreg<128xf16>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_group_reduce_s16_truncf_broadcast_store( +// ASSIGN-SAME: %arg0: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN-SAME: %arg1: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// ASSIGN: %[[SOURCE:.*]] = pto.vmi.ensure_layout %arg0 +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[MASK:.*]] = pto.vmi.ensure_mask_layout %arg1 +// ASSIGN-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// ASSIGN: %[[SUM32:.*]] = pto.vmi.group_reduce_addf %[[SOURCE]], %[[MASK]] +// ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> +// ASSIGN: %[[B32:.*]] = pto.vmi.group_broadcast %[[SUM32]] +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[B32_SPLIT:.*]] = pto.vmi.ensure_layout %[[B32]] +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[B16:.*]] = pto.vmi.truncf %[[B32_SPLIT]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN: pto.vmi.store %[[B16]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf16, #pto.vmi.layout>, !pto.ptr + +// LOWER-LABEL: func.func 
@vmi_layout_assignment_group_reduce_s16_truncf_broadcast_store( +// LOWER: pto.vcgadd +// LOWER: pto.vcgadd +// LOWER: pto.vadd +// LOWER: pto.vselr +// LOWER: pto.vselr +// LOWER: pto.vcvt +// LOWER: pto.vor +// LOWER: pto.vsts +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s256.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s256.pto new file mode 100644 index 0000000000..59fee4b8b9 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s256.pto @@ -0,0 +1,28 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_group_reduce_s256( + %source: !pto.vmi.vreg<512xf32>, + %mask: !pto.vmi.mask<512xpred>) -> !pto.vmi.vreg<2xf32> { + %out = pto.vmi.group_reduce_addf %source, %mask {num_groups = 2, reassoc} + : !pto.vmi.vreg<512xf32>, !pto.vmi.mask<512xpred> + -> !pto.vmi.vreg<2xf32> + return %out : !pto.vmi.vreg<2xf32> + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_group_reduce_s256( +// CHECK-SAME: %arg0: !pto.vmi.vreg<512xf32, #pto.vmi.layout> +// CHECK-SAME: %arg1: !pto.vmi.mask<512xb32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<2xf32, #pto.vmi.layout> +// CHECK: %[[OUT:.*]] = pto.vmi.group_reduce_addf %arg0, %arg1 +// CHECK-SAME: -> !pto.vmi.vreg<2xf32, #pto.vmi.layout> +// CHECK: return %[[OUT]] diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_broadcast_reduce.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_broadcast_reduce.pto new file mode 100644 index 0000000000..d564ea6544 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_broadcast_reduce.pto @@ -0,0 +1,63 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_group_reduce_s32_broadcast_reduce( + %source: !pto.vmi.vreg<256xf32>, + %mask: !pto.vmi.mask<256xpred>, + %dst: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %sum = pto.vmi.group_reduce_addf %source, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<8xf32> + %broadcast = pto.vmi.group_broadcast %sum {num_groups = 8} + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<256xf32> + %scaled = pto.vmi.mulf %source, %broadcast + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %scaled_sum = pto.vmi.group_reduce_addf %scaled, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<8xf32> + pto.vmi.group_store %scaled_sum, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_group_reduce_s32_broadcast_reduce( +// ASSIGN-SAME: %[[SOURCE:arg[0-9]+]]: !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[SOURCE]] +// ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> +// ASSIGN: %[[BROADCAST:.*]] = pto.vmi.group_broadcast %[[SUM]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[SCALED:.*]] = pto.vmi.mulf %[[SOURCE]], %[[BROADCAST]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[SCALED_SUM:.*]] = pto.vmi.group_reduce_addf %[[SCALED]] + +// LOWER-LABEL: func.func @vmi_layout_assignment_group_reduce_s32_broadcast_reduce( +// LOWER-DAG: %[[C2:.*]] = arith.constant 2 : i32 +// LOWER-DAG: %[[C4:.*]] = arith.constant 4 : i32 +// LOWER-DAG: %[[C6:.*]] = arith.constant 6 : i32 +// LOWER: pto.vselr +// LOWER: pto.vdup %[[C2]] +// LOWER: 
pto.vselr +// LOWER: pto.vdup %[[C4]] +// LOWER: pto.vselr +// LOWER: pto.vdup %[[C6]] +// LOWER: pto.vselr +// LOWER-COUNT-4: pto.vmul +// LOWER: pto.vsts {{.*}}, %arg8[%arg9], {{.*}} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_multitile_store.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_multitile_store.pto new file mode 100644 index 0000000000..93d471291b --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_multitile_store.pto @@ -0,0 +1,52 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_group_reduce_s32_multitile_store( + %source: !pto.vmi.vreg<512xf32>, + %mask: !pto.vmi.mask<512xpred>, + %dst: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %sum = pto.vmi.group_reduce_addf %source, %mask + {num_groups = 16, reassoc} + : !pto.vmi.vreg<512xf32>, !pto.vmi.mask<512xpred> + -> !pto.vmi.vreg<16xf32> + pto.vmi.group_store %sum, %dst[%off], %c1 {num_groups = 16} + : !pto.vmi.vreg<16xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_group_reduce_s32_multitile_store( +// ASSIGN-SAME: %[[SOURCE:.*]]: !pto.vmi.vreg<512xf32, #pto.vmi.layout> +// ASSIGN-SAME: %[[MASK:.*]]: !pto.vmi.mask<512xb32, #pto.vmi.layout> +// ASSIGN: %[[SOURCE_SPLIT:.*]] = pto.vmi.ensure_layout %[[SOURCE]] +// ASSIGN-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> +// ASSIGN: %[[MASK_SPLIT:.*]] = pto.vmi.ensure_mask_layout %[[MASK]] +// ASSIGN-SAME: -> !pto.vmi.mask<512xb32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[SOURCE_SPLIT]], %[[MASK_SPLIT]] +// ASSIGN-SAME: -> !pto.vmi.vreg<16xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_store %[[SUM]] +// ASSIGN-SAME: !pto.vmi.vreg<16xf32, #pto.vmi.layout>, !pto.ptr + +// LOWER-LABEL: func.func @vmi_layout_assignment_group_reduce_s32_multitile_store( +// LOWER-COUNT-8: pto.vdintlv +// LOWER-COUNT-8: pto.pdintlv_b32 +// LOWER: %[[VL8:.*]] = pto.pge_b32 "PAT_VL8" : !pto.mask +// LOWER-COUNT-8: pto.vcgadd +// LOWER: %[[STORE_MASK0:.*]] = pto.pge_b32 "PAT_VL8" : !pto.mask +// LOWER: pto.vsts {{.*}}, %arg16[%arg17], %[[STORE_MASK0]] +// LOWER: %[[STORE_MASK1:.*]] = pto.pge_b32 "PAT_VL8" : !pto.mask +// LOWER: pto.vsts {{.*}}, %arg16[{{.*}}], %[[STORE_MASK1]] +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. 
+// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_store.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_store.pto new file mode 100644 index 0000000000..443b8d822c --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_store.pto @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_group_reduce_s32_store( + %source: !pto.vmi.vreg<256xf32>, + %mask: !pto.vmi.mask<256xpred>, + %dst: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %sum = pto.vmi.group_reduce_addf %source, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<8xf32> + pto.vmi.group_store %sum, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_group_reduce_s32_store( +// ASSIGN-SAME: %[[SOURCE:.*]]: !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN-SAME: %[[MASK:.*]]: !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: %[[SOURCE_SPLIT:.*]] = pto.vmi.ensure_layout %[[SOURCE]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: 
%[[MASK_SPLIT:.*]] = pto.vmi.ensure_mask_layout %[[MASK]] +// ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[SOURCE_SPLIT]], %[[MASK_SPLIT]] +// ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_store %[[SUM]] +// ASSIGN-SAME: !pto.vmi.vreg<8xf32, #pto.vmi.layout>, !pto.ptr + +// LOWER-LABEL: func.func @vmi_layout_assignment_group_reduce_s32_store( +// LOWER-COUNT-4: pto.vdintlv +// LOWER-COUNT-4: pto.pdintlv_b32 +// LOWER: %[[VL8:.*]] = pto.pge_b32 "PAT_VL8" : !pto.mask +// LOWER-COUNT-4: pto.vcgadd +// LOWER-COUNT-3: pto.vadd {{.*}}, {{.*}}, %[[VL8]] +// LOWER: %[[STORE_MASK:.*]] = pto.pge_b32 "PAT_VL8" : !pto.mask +// LOWER: pto.vsts {{.*}}, %arg8[%arg9], %[[STORE_MASK]] +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_full_tile.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_full_tile.pto new file mode 100644 index 0000000000..372b445342 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_full_tile.pto @@ -0,0 +1,85 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-layout-rematerialize -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_group_reduce_s32_tail_full_tile( + %src: memref<256xf32>, %dst: !pto.ptr, %off: index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c192 = arith.constant 192 : index + %x = pto.vmi.load %src[%c0] + : memref<256xf32> -> !pto.vmi.vreg<192xf32> + %mask = pto.vmi.create_mask %c192 : index -> !pto.vmi.mask<192xpred> + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 6, reassoc} + : !pto.vmi.vreg<192xf32>, !pto.vmi.mask<192xpred> + -> !pto.vmi.vreg<6xf32> + pto.vmi.group_store %sum, %dst[%off], %c1 {num_groups = 6} + : !pto.vmi.vreg<6xf32>, !pto.ptr + return + } + + func.func @vmi_layout_assignment_group_reduce_s32_tail_full_tile_ptr_masked( + %src: !pto.ptr, %dst: !pto.ptr, %off: index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c192 = arith.constant 192 : index + %x = pto.vmi.load %src[%c0] + : !pto.ptr -> !pto.vmi.vreg<256xf32> + %mask = pto.vmi.create_mask %c192 : index -> !pto.vmi.mask<256xpred> + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<8xf32> + pto.vmi.group_store %sum, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_group_reduce_s32_tail_full_tile( +// ASSIGN: %[[X:.*]] = pto.vmi.load +// ASSIGN-SAME: -> !pto.vmi.vreg<192xf32, #pto.vmi.layout> +// ASSIGN: %[[MASK0:.*]] = pto.vmi.create_mask +// ASSIGN-SAME: -> !pto.vmi.mask<192xb32, #pto.vmi.layout> +// ASSIGN: %[[MASK:.*]] = pto.vmi.ensure_mask_layout %[[MASK0]] +// ASSIGN-SAME: !pto.vmi.mask<192xb32, #pto.vmi.layout> -> !pto.vmi.mask<192xb32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf 
%[[X]], %[[MASK]] +// ASSIGN-SAME: -> !pto.vmi.vreg<6xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_store %[[SUM]] + +// LOWER-LABEL: func.func @vmi_layout_assignment_group_reduce_s32_tail_full_tile( +// LOWER-DAG: %[[C6:.*]] = arith.constant 6 : i32 +// LOWER-DAG: %[[C48:.*]] = arith.constant 48 : i32 +// LOWER-COUNT-4: pto.vlds +// LOWER-COUNT-3: pto.vdintlv +// LOWER-COUNT-4: pto.plt_b32 %[[C48]] : i32 -> !pto.mask, i32 +// LOWER: %[[SLOTS:.*]], %{{.*}} = pto.plt_b32 %[[C6]] : i32 -> !pto.mask, i32 +// LOWER-COUNT-4: pto.vcgadd +// LOWER: pto.vadd {{.*}}, {{.*}}, %[[SLOTS]] +// LOWER: %[[STORE:.*]], %{{.*}} = pto.plt_b32 %[[C6]] : i32 -> !pto.mask, i32 +// LOWER: pto.vsts {{.*}}, {{.*}}, %[[STORE]] +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_group_reduce_s32_tail_full_tile_ptr_masked( +// ASSIGN: %[[PX:.*]] = pto.vmi.load +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[PMASK0:.*]] = pto.vmi.create_mask %{{.*}} : index -> !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: %[[PMASK:.*]] = pto.vmi.ensure_mask_layout %[[PMASK0]] +// ASSIGN-SAME: !pto.vmi.mask<256xb32, #pto.vmi.layout> -> !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_reduce_addf %[[PX]], %[[PMASK]] +// ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> + +// LOWER-LABEL: func.func @vmi_layout_assignment_group_reduce_s32_tail_full_tile_ptr_masked( +// LOWER-COUNT-4: pto.vlds +// LOWER-COUNT-3: pto.vdintlv +// LOWER-COUNT-4: pto.vcgadd +// LOWER: pto.vsts diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_no_full_tile_invalid.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_no_full_tile_invalid.pto new file mode 100644 index 0000000000..de04988cb0 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_no_full_tile_invalid.pto @@ -0,0 +1,32 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_assignment_group_reduce_s32_tail_no_full_tile_invalid( + %source: !pto.vmi.vreg<192xf32>, + %mask: !pto.vmi.mask<192xpred>, + %dst: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + // CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.group_reduce_addf operand #0 has type + // CHECK-SAME: #pto.vmi.layout + // CHECK-SAME: requires + // CHECK-SAME: #pto.vmi.layout + // CHECK-SAME: pto.vmi.ensure_layout cannot materialize this conversion + // CHECK: requires source and result to have the same physical arity + %sum = pto.vmi.group_reduce_addf %source, %mask + {num_groups = 6, reassoc} + : !pto.vmi.vreg<192xf32>, !pto.vmi.mask<192xpred> + -> !pto.vmi.vreg<6xf32> + pto.vmi.group_store %sum, %dst[%off], %c1 {num_groups = 6} + : !pto.vmi.vreg<6xf32>, !pto.ptr + return + } +} diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s64.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s64.pto new file mode 100644 index 0000000000..1f10d8a6ee --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s64.pto @@ -0,0 +1,28 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). 
+// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_group_reduce_s64( + %source: !pto.vmi.vreg<512xf32>, + %mask: !pto.vmi.mask<512xpred>) -> !pto.vmi.vreg<8xf32> { + %out = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<512xf32>, !pto.vmi.mask<512xpred> + -> !pto.vmi.vreg<8xf32> + return %out : !pto.vmi.vreg<8xf32> + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_group_reduce_s64( +// CHECK-SAME: %arg0: !pto.vmi.vreg<512xf32, #pto.vmi.layout> +// CHECK-SAME: %arg1: !pto.vmi.mask<512xb32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> +// CHECK: %[[OUT:.*]] = pto.vmi.group_reduce_addf %arg0, %arg1 +// CHECK-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> +// CHECK: return %[[OUT]] diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_broadcast_reduce.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_broadcast_reduce.pto new file mode 100644 index 0000000000..0e298f6cf1 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_broadcast_reduce.pto @@ -0,0 +1,54 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_group_reduce_s64_broadcast_reduce( + %source: !pto.vmi.vreg<512xf32>, + %mask: !pto.vmi.mask<512xpred>, + %dst: !pto.ptr, + %off: index) { + %c8 = arith.constant 8 : index + %sum = pto.vmi.group_reduce_addf %source, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<512xf32>, !pto.vmi.mask<512xpred> + -> !pto.vmi.vreg<8xf32> + %broadcast = pto.vmi.group_broadcast %sum {num_groups = 8} + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<512xf32> + %scaled = pto.vmi.mulf %source, %broadcast + : !pto.vmi.vreg<512xf32>, !pto.vmi.vreg<512xf32> + -> !pto.vmi.vreg<512xf32> + %scaled_sum = pto.vmi.group_reduce_addf %scaled, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<512xf32>, !pto.vmi.mask<512xpred> + -> !pto.vmi.vreg<8xf32> + pto.vmi.group_store %scaled_sum, %dst[%off], %c8 {num_groups = 8} + : !pto.vmi.vreg<8xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_group_reduce_s64_broadcast_reduce( +// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf +// ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> +// ASSIGN: %[[BROADCAST:.*]] = pto.vmi.group_broadcast %[[SUM]] +// ASSIGN-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> +// ASSIGN: %[[SCALED_SUM:.*]] = pto.vmi.group_reduce_addf +// ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> + +// LOWER-LABEL: func.func @vmi_layout_assignment_group_reduce_s64_broadcast_reduce( +// LOWER-COUNT-8: pto.vcgadd +// LOWER-COUNT-8: pto.vdup {{.*}} {position = "LOWEST"} +// 
LOWER-COUNT-8: pto.vmul +// LOWER-COUNT-8: pto.vcgadd +// LOWER-COUNT-8: pto.vsts +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_tail_store.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_tail_store.pto new file mode 100644 index 0000000000..ec2efac143 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_tail_store.pto @@ -0,0 +1,41 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_group_reduce_s64_tail_store( + %src: !pto.ptr, %dst: !pto.ptr, %off: index) { + %c8 = arith.constant 8 : index + %c384 = arith.constant 384 : index + %mask = pto.vmi.create_mask %c384 : index -> !pto.vmi.mask<384xpred> + %x = pto.vmi.load %src[%off] : !pto.ptr -> !pto.vmi.vreg<384xf32> + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 6, reassoc} + : !pto.vmi.vreg<384xf32>, !pto.vmi.mask<384xpred> + -> !pto.vmi.vreg<6xf32> + pto.vmi.group_store %sum, %dst[%off], %c8 {num_groups = 6} + : !pto.vmi.vreg<6xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_group_reduce_s64_tail_store( +// ASSIGN: %[[X:.*]] = pto.vmi.load +// ASSIGN-SAME: -> !pto.vmi.vreg<384xf32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X]] +// ASSIGN-SAME: -> !pto.vmi.vreg<6xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_store %[[SUM]] + +// LOWER-LABEL: func.func @vmi_layout_assignment_group_reduce_s64_tail_store( +// LOWER-COUNT-6: pto.vlds +// LOWER-COUNT-6: pto.vcgadd +// LOWER-COUNT-6: pto.vsts +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_truncf.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_truncf.pto new file mode 100644 index 0000000000..ff0e67b9ad --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_truncf.pto @@ -0,0 +1,49 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. 
You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_group_reduce_s64_truncf( + %source: !pto.vmi.vreg<512xf32>, + %mask: !pto.vmi.mask<512xpred>, + %dst: !pto.ptr, + %off: index) { + %c16 = arith.constant 16 : index + %sum32 = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<512xf32>, !pto.vmi.mask<512xpred> + -> !pto.vmi.vreg<8xf32> + %sum16 = pto.vmi.truncf %sum32 + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<8xf16> + pto.vmi.group_store %sum16, %dst[%off], %c16 {num_groups = 8} + : !pto.vmi.vreg<8xf16>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_group_reduce_s64_truncf( +// ASSIGN-SAME: %arg0: !pto.vmi.vreg<512xf32, #pto.vmi.layout> +// ASSIGN-SAME: %arg1: !pto.vmi.mask<512xb32, #pto.vmi.layout> +// ASSIGN: %[[SUM32:.*]] = pto.vmi.group_reduce_addf %arg0, %arg1 +// ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> +// ASSIGN: %[[SUM16:.*]] = pto.vmi.truncf %[[SUM32]] +// ASSIGN-SAME: !pto.vmi.vreg<8xf32, #pto.vmi.layout> -> !pto.vmi.vreg<8xf16, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_store %[[SUM16]] +// ASSIGN-SAME: !pto.vmi.vreg<8xf16, #pto.vmi.layout>, !pto.ptr, !pto.mask -> !pto.vreg<128xf16> +// LOWER: pto.pge_b16 "PAT_VL1" +// LOWER: pto.vsts {{.*}} : !pto.vreg<128xf16>, !pto.ptr, !pto.mask +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. 
diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_slots8.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_slots8.pto new file mode 100644 index 0000000000..66f6c5fe47 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_slots8.pto @@ -0,0 +1,28 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_group_reduce_slots8( + %source: !pto.vmi.vreg<64xf32>, + %mask: !pto.vmi.mask<64xpred>) -> !pto.vmi.vreg<8xf32> { + %out = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<64xf32>, !pto.vmi.mask<64xpred> + -> !pto.vmi.vreg<8xf32> + return %out : !pto.vmi.vreg<8xf32> + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_group_reduce_slots8( +// CHECK-SAME: %arg0: !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK-SAME: %arg1: !pto.vmi.mask<64xb32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> +// CHECK: %[[OUT:.*]] = pto.vmi.group_reduce_addf %arg0, %arg1 +// CHECK-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> +// CHECK: return %[[OUT]] diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_slots8_store.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_slots8_store.pto new file mode 100644 index 0000000000..f5af0d2ed1 --- /dev/null +++ 
b/test/lit/vmi/vmi_layout_assignment_group_reduce_slots8_store.pto @@ -0,0 +1,43 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_group_reduce_slots8_store( + %source: !pto.vmi.vreg<64xf32>, + %mask: !pto.vmi.mask<64xpred>, + %dst: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %sum = pto.vmi.group_reduce_addf %source, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<64xf32>, !pto.vmi.mask<64xpred> + -> !pto.vmi.vreg<8xf32> + pto.vmi.group_store %sum, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_group_reduce_slots8_store( +// ASSIGN-SAME: %arg0: !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// ASSIGN-SAME: %arg1: !pto.vmi.mask<64xb32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %arg0, %arg1 +// ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_store %[[SUM]] +// ASSIGN-SAME: !pto.vmi.vreg<8xf32, #pto.vmi.layout>, !pto.ptr + +// LOWER-LABEL: func.func @vmi_layout_assignment_group_reduce_slots8_store( +// LOWER: %[[SUM:.*]] = pto.vcgadd %arg0, %arg1 : !pto.vreg<64xf32>, !pto.mask -> 
!pto.vreg<64xf32> +// LOWER: %[[STORE_MASK:.*]] = pto.pge_b32 "PAT_VL8" : !pto.mask +// LOWER: pto.vsts %[[SUM]], %arg2[%arg3], %[[STORE_MASK]] : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_typed.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_typed.pto new file mode 100644 index 0000000000..7d4e4b6bc5 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_typed.pto @@ -0,0 +1,56 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @typed_group_reduce_assignment( + %f16: !pto.vmi.vreg<256xf16>, + %mf16: !pto.vmi.mask<256xpred>, + %i16: !pto.vmi.vreg<128xi16>, + %mi16: !pto.vmi.mask<128xpred>, + %i32: !pto.vmi.vreg<128xi32>, + %mi32: !pto.vmi.mask<128xpred>) { + %sum_f16 = pto.vmi.group_reduce_addf %f16, %mf16 {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf16>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<8xf16> + %wide_i16 = pto.vmi.extsi %i16 + : !pto.vmi.vreg<128xi16> -> !pto.vmi.vreg<128xi32> + %sum_i16 = pto.vmi.group_reduce_addi %wide_i16, %mi16 {num_groups = 8} + : !pto.vmi.vreg<128xi32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<8xi32> + %sum_i32 = pto.vmi.group_reduce_addi %i32, %mi32 {num_groups = 8} + : !pto.vmi.vreg<128xi32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<8xi32> + return + } +} + +// CHECK-LABEL: func.func @typed_group_reduce_assignment( +// CHECK: %[[F16_SPLIT:.*]] = pto.vmi.ensure_layout +// CHECK-SAME: -> !pto.vmi.vreg<256xf16, #pto.vmi.layout> +// CHECK: %[[MF16_SPLIT:.*]] = pto.vmi.ensure_mask_layout +// CHECK-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> +// CHECK: %[[MF16_B16:.*]] = pto.vmi.ensure_mask_granularity %[[MF16_SPLIT]] +// CHECK-SAME: -> !pto.vmi.mask<256xb16, #pto.vmi.layout> +// CHECK: pto.vmi.group_reduce_addf %[[F16_SPLIT]], %[[MF16_B16]] +// CHECK-SAME: -> !pto.vmi.vreg<8xf16, #pto.vmi.layout> +// CHECK: %[[WIDE_I16:.*]] = pto.vmi.extsi %arg2 +// CHECK-SAME: -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> +// CHECK: %[[MI16_SPLIT:.*]] = pto.vmi.ensure_mask_layout +// CHECK-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK: pto.vmi.group_reduce_addi %[[WIDE_I16]], %[[MI16_SPLIT]] +// CHECK-SAME: -> !pto.vmi.vreg<8xi32, #pto.vmi.layout> +// CHECK: %[[I32_SPLIT:.*]] = pto.vmi.ensure_layout +// CHECK-SAME: -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> +// CHECK: %[[MI32_SPLIT:.*]] = pto.vmi.ensure_mask_layout +// CHECK-SAME: -> !pto.vmi.mask<128xb32, 
#pto.vmi.layout> +// CHECK: pto.vmi.group_reduce_addi %[[I32_SPLIT]], %[[MI32_SPLIT]] +// CHECK-SAME: -> !pto.vmi.vreg<8xi32, #pto.vmi.layout> diff --git a/test/lit/vmi/vmi_layout_assignment_group_slot_broadcast_load_e2b_b16.pto b/test/lit/vmi/vmi_layout_assignment_group_slot_broadcast_load_e2b_b16.pto new file mode 100644 index 0000000000..b8135eeb05 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_slot_broadcast_load_e2b_b16.pto @@ -0,0 +1,74 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-pre-assignment-combine -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_layout_assignment_group_slot_broadcast_load_e2b_b16( + %src: !pto.ptr, %off: index) -> !pto.vmi.vreg<256xf16> { + %c1 = arith.constant 1 : index + %slots = pto.vmi.group_slot_load %src[%off], %c1 {num_groups = 16} + : !pto.ptr -> !pto.vmi.vreg<16xf16> + %out = pto.vmi.group_broadcast %slots {num_groups = 16} + : !pto.vmi.vreg<16xf16> -> !pto.vmi.vreg<256xf16> + return %out : !pto.vmi.vreg<256xf16> + } + + func.func @vmi_layout_assignment_group_slot_broadcast_load_e2b_b16_deint2( + %src: !pto.ptr, %off: index) -> !pto.vmi.vreg<256xbf16> { + %c1 = arith.constant 1 : index + %slots = pto.vmi.group_slot_load %src[%off], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<8xbf16> + %out = pto.vmi.group_broadcast %slots {num_groups = 8} + : !pto.vmi.vreg<8xbf16> -> !pto.vmi.vreg<256xbf16> + return %out : !pto.vmi.vreg<256xbf16> + } + + func.func @vmi_layout_assignment_group_slot_broadcast_load_e2b_b32( + %src: !pto.ptr, %off: index) -> !pto.vmi.vreg<64xf32> { + %c1 = arith.constant 1 : index + %slots = pto.vmi.group_slot_load %src[%off], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<8xf32> + %out = pto.vmi.group_broadcast %slots {num_groups = 8} + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<64xf32> + return %out : !pto.vmi.vreg<64xf32> + } + + func.func @vmi_layout_assignment_group_slot_broadcast_load_e2b_b32_deint2( + %src: !pto.ptr, %off: index) -> !pto.vmi.vreg<128xf32> { + %c1 = arith.constant 1 : index + %slots = pto.vmi.group_slot_load %src[%off], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<8xf32> + %out = pto.vmi.group_broadcast %slots {num_groups = 8} + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<128xf32> + return %out : !pto.vmi.vreg<128xf32> + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_group_slot_broadcast_load_e2b_b16 +// CHECK-SAME: (%[[SRC:.*]]: !pto.ptr, %[[OFF:.*]]: index) +// CHECK: %[[C8:.*]] = 
arith.constant 8 : index +// CHECK: %[[E2B0:.*]] = pto.vlds %[[SRC]][%[[OFF]]] {dist = "E2B_B16"} : !pto.ptr -> !pto.vreg<128xf16> +// CHECK: %[[OFF8:.*]] = arith.addi %[[OFF]], %[[C8]] : index +// CHECK: %[[E2B1:.*]] = pto.vlds %[[SRC]][%[[OFF8]]] {dist = "E2B_B16"} : !pto.ptr -> !pto.vreg<128xf16> +// CHECK: return %[[E2B0]], %[[E2B1]] : !pto.vreg<128xf16>, !pto.vreg<128xf16> + +// CHECK-LABEL: func.func @vmi_layout_assignment_group_slot_broadcast_load_e2b_b16_deint2 +// CHECK-SAME: (%[[SRC2:.*]]: !pto.ptr, %[[OFF2:.*]]: index) +// CHECK: %[[E2B2:.*]] = pto.vlds %[[SRC2]][%[[OFF2]]] {dist = "E2B_B16"} : !pto.ptr -> !pto.vreg<128xbf16> +// CHECK: return %[[E2B2]], %[[E2B2]] : !pto.vreg<128xbf16>, !pto.vreg<128xbf16> + +// CHECK-LABEL: func.func @vmi_layout_assignment_group_slot_broadcast_load_e2b_b32 +// CHECK-SAME: (%[[SRC3:.*]]: !pto.ptr, %[[OFF3:.*]]: index) +// CHECK: %[[E2B3:.*]] = pto.vlds %[[SRC3]][%[[OFF3]]] {dist = "E2B_B32"} : !pto.ptr -> !pto.vreg<64xf32> +// CHECK: return %[[E2B3]] : !pto.vreg<64xf32> + +// CHECK-LABEL: func.func @vmi_layout_assignment_group_slot_broadcast_load_e2b_b32_deint2 +// CHECK-SAME: (%[[SRC4:.*]]: !pto.ptr, %[[OFF4:.*]]: index) +// CHECK: %[[E2B4:.*]] = pto.vlds %[[SRC4]][%[[OFF4]]] {dist = "E2B_B32"} : !pto.ptr -> !pto.vreg<64xf32> +// CHECK: return %[[E2B4]], %[[E2B4]] : !pto.vreg<64xf32>, !pto.vreg<64xf32> diff --git a/test/lit/vmi/vmi_layout_assignment_group_slot_broadcast_no_e2b.pto b/test/lit/vmi/vmi_layout_assignment_group_slot_broadcast_no_e2b.pto new file mode 100644 index 0000000000..da550785b5 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_slot_broadcast_no_e2b.pto @@ -0,0 +1,29 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. 
You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_layout_assignment_group_slot_broadcast_no_e2b( + %src: !pto.ptr, %off: index) -> !pto.vmi.vreg<256xf32> { + %c1 = arith.constant 1 : index + %slots = pto.vmi.group_slot_load %src[%off], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<8xf32> + %out = pto.vmi.group_broadcast %slots {num_groups = 8} + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<256xf32> + return %out : !pto.vmi.vreg<256xf32> + } + +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_group_slot_broadcast_no_e2b +// CHECK-NOT: E2B_B16 +// CHECK-NOT: E2B_B32 +// CHECK: pto.vselr +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_layout_assignment_group_slot_broadcast_partial_packet_invalid.pto b/test/lit/vmi/vmi_layout_assignment_group_slot_broadcast_partial_packet_invalid.pto new file mode 100644 index 0000000000..edcddb8d50 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_slot_broadcast_partial_packet_invalid.pto @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_assignment_group_slot_broadcast_partial_packet_invalid( + %src: !pto.ptr, %off: index) -> !pto.vmi.vreg<64xf16> { + %c1 = arith.constant 1 : index + %slots = pto.vmi.group_slot_load %src[%off], %c1 {num_groups = 4} + : !pto.ptr -> !pto.vmi.vreg<4xf16> + %out = pto.vmi.group_broadcast %slots {num_groups = 4} + : !pto.vmi.vreg<4xf16> -> !pto.vmi.vreg<64xf16> + return %out : !pto.vmi.vreg<64xf16> + } +} + +// CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_broadcast has no registered layout support +// CHECK-SAME: requires full result physical chunks +// CHECK-NOT: E2B_B16 diff --git a/test/lit/vmi/vmi_layout_assignment_group_slot_load.pto b/test/lit/vmi/vmi_layout_assignment_group_slot_load.pto new file mode 100644 index 0000000000..153408b78b --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_slot_load.pto @@ -0,0 +1,74 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_group_slot_load_slots8( + %src: !pto.ptr, %off: index) -> !pto.vmi.vreg<8xf32> { + %c1 = arith.constant 1 : index + %out = pto.vmi.group_slot_load %src[%off], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<8xf32> + return %out : !pto.vmi.vreg<8xf32> + } + + func.func @vmi_layout_assignment_group_slot_load_slots1( + %src: !pto.ptr, %off: index) + -> !pto.vmi.vreg<8xf32> { + %c8 = arith.constant 8 : index + %out = pto.vmi.group_slot_load %src[%off], %c8 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<8xf32> + return %out : !pto.vmi.vreg<8xf32> + } + + func.func @vmi_layout_assignment_group_slot_load_slots8_store( + %src: !pto.ptr, %dst: !pto.ptr, %off: index) { + %c1 = arith.constant 1 : index + %out = pto.vmi.group_slot_load %src[%off], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<8xf32> + pto.vmi.group_store %out, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xf32>, !pto.ptr + return + } + + func.func @vmi_layout_assignment_group_slot_load_extui( + %src: !pto.ptr, %off: index) -> !pto.vmi.vreg<8xui32> { + %c1 = arith.constant 1 : index + %narrow = pto.vmi.group_slot_load %src[%off], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<8xui8> + %wide = pto.vmi.extui %narrow + : !pto.vmi.vreg<8xui8> -> !pto.vmi.vreg<8xui32> + return %wide : !pto.vmi.vreg<8xui32> + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_group_slot_load_slots8( +// CHECK-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> +// CHECK: %[[OUT:.*]] = pto.vmi.group_slot_load +// CHECK-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> +// CHECK: return %[[OUT]] + +// CHECK-LABEL: func.func @vmi_layout_assignment_group_slot_load_slots1( +// CHECK-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> +// CHECK: %[[OUT:.*]] = pto.vmi.group_slot_load +// CHECK-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> +// CHECK: return %[[OUT]] + +// CHECK-LABEL: func.func 
@vmi_layout_assignment_group_slot_load_slots8_store( +// CHECK: %[[OUT:.*]] = pto.vmi.group_slot_load +// CHECK-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> +// CHECK: pto.vmi.group_store %[[OUT]] +// CHECK-SAME: !pto.vmi.vreg<8xf32, #pto.vmi.layout>, !pto.ptr + +// CHECK-LABEL: func.func @vmi_layout_assignment_group_slot_load_extui( +// CHECK-SAME: -> !pto.vmi.vreg<8xui32, #pto.vmi.layout> +// CHECK: %[[NARROW:.*]] = pto.vmi.group_slot_load +// CHECK-SAME: -> !pto.vmi.vreg<8xui8, #pto.vmi.layout> +// CHECK: %[[WIDE:.*]] = pto.vmi.extui %[[NARROW]] +// CHECK-SAME: !pto.vmi.vreg<8xui8, #pto.vmi.layout> -> !pto.vmi.vreg<8xui32, #pto.vmi.layout> +// CHECK: return %[[WIDE]] diff --git a/test/lit/vmi/vmi_layout_assignment_group_slot_load_dual_layout.pto b/test/lit/vmi/vmi_layout_assignment_group_slot_load_dual_layout.pto new file mode 100644 index 0000000000..075a9d58a3 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_slot_load_dual_layout.pto @@ -0,0 +1,73 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_group_slot_load_dual_layout( + %rhs_base: !pto.ptr, + %source16: !pto.vmi.vreg<128xf32>, + %mask16: !pto.vmi.mask<128xpred>, + %source64: !pto.vmi.vreg<512xf32>, + %mask64: !pto.vmi.mask<512xpred>, + %out16: !pto.ptr, + %out64: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %rhs16 = pto.vmi.group_slot_load %rhs_base[%off], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<8xf32> + %sum16 = pto.vmi.group_reduce_addf %source16, %mask16 + {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<8xf32> + %outv16 = pto.vmi.addf %sum16, %rhs16 + : !pto.vmi.vreg<8xf32>, !pto.vmi.vreg<8xf32> + -> !pto.vmi.vreg<8xf32> + pto.vmi.group_store %outv16, %out16[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xf32>, !pto.ptr + + %rhs64 = pto.vmi.group_slot_load %rhs_base[%off], %c8 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<8xf32> + %sum64 = pto.vmi.group_reduce_addf %source64, %mask64 + {num_groups = 8, reassoc} + : !pto.vmi.vreg<512xf32>, !pto.vmi.mask<512xpred> + -> !pto.vmi.vreg<8xf32> + %outv64 = pto.vmi.addf %sum64, %rhs64 + : !pto.vmi.vreg<8xf32>, !pto.vmi.vreg<8xf32> + -> !pto.vmi.vreg<8xf32> + pto.vmi.group_store %outv64, %out64[%off], %c8 {num_groups = 8} + : !pto.vmi.vreg<8xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_group_slot_load_dual_layout( +// ASSIGN: %[[RHS16:.*]] = pto.vmi.group_slot_load +// ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> +// ASSIGN: %[[SUM16:.*]] = pto.vmi.group_reduce_addf +// ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.addf %[[SUM16]], %[[RHS16]] +// ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> +// ASSIGN: %[[RHS64:.*]] = 
pto.vmi.group_slot_load +// ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> +// ASSIGN: %[[SUM64:.*]] = pto.vmi.group_reduce_addf +// ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.addf %[[SUM64]], %[[RHS64]] +// ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> + +// LOWER-LABEL: func.func @vmi_layout_assignment_group_slot_load_dual_layout( +// LOWER: pto.pge_b32 "PAT_VL8" +// LOWER: pto.vsldb +// LOWER: pto.vsts {{.*}}, %arg21[%arg23], {{.*}} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +// LOWER-COUNT-8: pto.vsldb +// LOWER-COUNT-8: pto.vsts +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_dynamic_stride_invalid.pto b/test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_dynamic_stride_invalid.pto new file mode 100644 index 0000000000..397afcb6da --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_dynamic_stride_invalid.pto @@ -0,0 +1,24 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_assignment_group_slot_load_slots1_dynamic_stride_invalid( + %src: !pto.ptr, %off: index, %stride: index) { + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_slot_load has no registered layout support + // CHECK-SAME: slots=1 group_slot_load currently lowers as one lane-0 vsldb per group + // CHECK-SAME: requires constant positive source_group_stride divisible by 8 elements + // CHECK-SAME: packed or unaligned scalar load lowering is not implemented + // CHECK: note: see current operation: %{{.*}} = "pto.vmi.group_slot_load" + // CHECK-SAME: !pto.vmi.vreg<8xf32, #pto.vmi.layout> + %out = pto.vmi.group_slot_load %src[%off], %stride {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<8xf32> + return + } +} diff --git a/test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_unaligned_stride_invalid.pto b/test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_unaligned_stride_invalid.pto new file mode 100644 index 0000000000..6dbb34ad97 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_unaligned_stride_invalid.pto @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_assignment_group_slot_load_slots1_unaligned_stride_invalid( + %src: !pto.ptr, %off: index) { + %c2 = arith.constant 2 : index + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_slot_load has no registered layout support + // CHECK-SAME: slots=1 group_slot_load currently lowers as one lane-0 vsldb per group + // CHECK-SAME: requires constant positive source_group_stride divisible by 8 elements + // CHECK-SAME: packed or unaligned scalar load lowering is not implemented + // CHECK: note: see current operation: %{{.*}} = "pto.vmi.group_slot_load" + // CHECK-SAME: !pto.vmi.vreg<8xf32, #pto.vmi.layout> + %out = pto.vmi.group_slot_load %src[%off], %c2 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<8xf32> + return + } +} diff --git a/test/lit/vmi/vmi_layout_assignment_group_slots_cf_join.pto b/test/lit/vmi/vmi_layout_assignment_group_slots_cf_join.pto new file mode 100644 index 0000000000..99e91bf98c --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_slots_cf_join.pto @@ -0,0 +1,59 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_layout_assignment_group_slots_cf_join( + %cond: i1, + %src: !pto.ptr, + %rhs: !pto.ptr, + %dst: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %mask = pto.vmi.create_mask %c128 : index -> !pto.vmi.mask<128xpred> + %sum = scf.if %cond -> !pto.vmi.vreg<8xf32> { + %x = pto.vmi.load %src[%off] + : !pto.ptr -> !pto.vmi.vreg<128xf32> + %a = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<8xf32> + scf.yield %a : !pto.vmi.vreg<8xf32> + } else { + %b = pto.vmi.group_slot_load %rhs[%off], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<8xf32> + scf.yield %b : !pto.vmi.vreg<8xf32> + } + %bias = pto.vmi.group_slot_load %rhs[%off], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<8xf32> + %out = pto.vmi.addf %sum, %bias + : !pto.vmi.vreg<8xf32>, !pto.vmi.vreg<8xf32> + -> !pto.vmi.vreg<8xf32> + pto.vmi.group_store %out, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xf32>, !pto.ptr + return + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_group_slots_cf_join( +// CHECK: %[[IF:.*]] = scf.if +// CHECK-SAME: -> (!pto.vreg<64xf32>) +// CHECK: pto.vldsx2 +// CHECK: pto.vcgadd +// CHECK: pto.vcgadd +// CHECK: scf.yield {{.*}} : !pto.vreg<64xf32> +// CHECK: else +// CHECK: pto.vsldb +// CHECK: scf.yield {{.*}} : !pto.vreg<64xf32> +// CHECK: %[[BIAS:.*]] = pto.vsldb +// CHECK: pto.vadd %[[IF]], %[[BIAS]] +// CHECK: pto.vsts +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_group_slots_fanout.pto b/test/lit/vmi/vmi_layout_assignment_group_slots_fanout.pto new file mode 100644 index 0000000000..ccf77511d6 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_slots_fanout.pto @@ -0,0 +1,65 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_group_slots_fanout( + %source: !pto.vmi.vreg<128xf32>, + %mask: !pto.vmi.mask<128xpred>, + %sum_dst: !pto.ptr, + %out: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %sum = pto.vmi.group_reduce_addf %source, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<8xf32> + pto.vmi.group_store %sum, %sum_dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xf32>, !pto.ptr + %broadcast = pto.vmi.group_broadcast %sum {num_groups = 8} + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<128xf32> + %scaled = pto.vmi.mulf %source, %broadcast + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + %scaled_sum = pto.vmi.group_reduce_addf %scaled, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, 
!pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<8xf32> + pto.vmi.group_store %scaled_sum, %out[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_group_slots_fanout( +// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf +// ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_store %[[SUM]] +// ASSIGN-SAME: !pto.vmi.vreg<8xf32, #pto.vmi.layout>, !pto.ptr +// ASSIGN: %[[BROADCAST:.*]] = pto.vmi.group_broadcast %[[SUM]] +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[SCALED:.*]] = pto.vmi.mulf +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[SCALED_SUM:.*]] = pto.vmi.group_reduce_addf %[[SCALED]] +// ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_store %[[SCALED_SUM]] + +// LOWER-LABEL: func.func @vmi_layout_assignment_group_slots_fanout( +// LOWER-DAG: %[[C4:.*]] = arith.constant 4 : i32 +// LOWER: %[[FIRST_SUM:.*]] = pto.vadd {{.*}}, {{.*}}, {{.*}} : !pto.vreg<64xf32> +// LOWER: pto.vsts %[[FIRST_SUM]], %arg4[%arg6], {{.*}} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +// LOWER: pto.vselr %[[FIRST_SUM]] +// LOWER: pto.vdup %[[C4]] +// LOWER: pto.vselr %[[FIRST_SUM]] +// LOWER-COUNT-2: pto.vmul +// LOWER: pto.vsts {{.*}}, %arg5[%arg6], {{.*}} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_group_slots_scf_for.pto b/test/lit/vmi/vmi_layout_assignment_group_slots_scf_for.pto new file mode 100644 index 0000000000..ec00c48c70 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_slots_scf_for.pto @@ -0,0 +1,76 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). 
+// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_group_slots_scf_for( + %init: !pto.ptr, + %base: !pto.ptr, + %out: !pto.ptr, + %off: index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c16 = arith.constant 16 : index + %acc0 = pto.vmi.group_slot_load %init[%off], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<8xf32> + %acc = scf.for %i = %c0 to %c2 step %c1 + iter_args(%arg = %acc0) -> (!pto.vmi.vreg<8xf32>) { + %x = pto.vmi.group_load %base[%off], %c16 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<128xf32> + %mask = pto.vmi.create_group_mask %c16 + {num_groups = 8, group_size = 16} + : index -> !pto.vmi.mask<128xpred> + %sum = pto.vmi.group_reduce_addf %x, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<8xf32> + %next = pto.vmi.addf %arg, %sum + : !pto.vmi.vreg<8xf32>, !pto.vmi.vreg<8xf32> + -> !pto.vmi.vreg<8xf32> + scf.yield %next : !pto.vmi.vreg<8xf32> + } + pto.vmi.group_store %acc, %out[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_group_slots_scf_for( +// ASSIGN: %[[ACC0:.*]] = pto.vmi.group_slot_load +// ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> +// ASSIGN: %[[ACC:.*]] = scf.for +// ASSIGN-SAME: iter_args(%[[ARG:.*]] = %[[ACC0]]) +// ASSIGN-SAME: -> 
(!pto.vmi.vreg<8xf32, #pto.vmi.layout>) +// ASSIGN: %[[X:.*]] = pto.vmi.group_load +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[MASK0:.*]] = pto.vmi.create_group_mask +// ASSIGN-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// ASSIGN: %[[MASK:.*]] = pto.vmi.ensure_mask_layout %[[MASK0]] +// ASSIGN-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X]], %[[MASK]] +// ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.addf %[[ARG]], %[[SUM]] +// ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> +// ASSIGN: scf.yield +// ASSIGN-SAME: !pto.vmi.vreg<8xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_store %[[ACC]] +// ASSIGN-SAME: !pto.vmi.vreg<8xf32, #pto.vmi.layout>, !pto.ptr + +// LOWER-LABEL: func.func @vmi_layout_assignment_group_slots_scf_for( +// LOWER: pto.vsldb +// LOWER: scf.for +// LOWER-COUNT-2: pto.vcgadd +// LOWER: pto.vadd +// LOWER: scf.yield +// LOWER: pto.vsts +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_group_store_slots1_unit_stride.pto b/test/lit/vmi/vmi_layout_assignment_group_store_slots1_unit_stride.pto new file mode 100644 index 0000000000..44625678ac --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_store_slots1_unit_stride.pto @@ -0,0 +1,34 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_layout_assignment_group_store_slots1_unit_stride( + %source: !pto.vmi.vreg<512xf32>, + %mask: !pto.vmi.mask<512xpred>, + %dst: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %sum = pto.vmi.group_reduce_addf %source, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<512xf32>, !pto.vmi.mask<512xpred> + -> !pto.vmi.vreg<8xf32> + pto.vmi.group_store %sum, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xf32>, !pto.ptr + return + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_group_store_slots1_unit_stride( +// CHECK-COUNT-8: pto.vcgadd +// CHECK: pto.pge_b32 "PAT_VL1" : !pto.mask +// CHECK-COUNT-8: pto.vsts {{.*}} {dist = "1PT_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_indirect_call_invalid.pto b/test/lit/vmi/vmi_layout_assignment_indirect_call_invalid.pto new file mode 100644 index 0000000000..4186b78dfa --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_indirect_call_invalid.pto @@ -0,0 +1,24 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -vmi-layout-assignment 2>&1 | FileCheck %s + +module { + func.func @caller( + %fn: (!pto.vmi.vreg<128xf32>) -> !pto.vmi.vreg<128xf32>, + %x: !pto.vmi.vreg<128xf32>) { + %r = func.call_indirect %fn(%x) + : (!pto.vmi.vreg<128xf32>) -> !pto.vmi.vreg<128xf32> + %sum = pto.vmi.addf %r, %r + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK: VMI-LAYOUT-CONTRACT: VMI typed call requires a direct internal callee with a body diff --git a/test/lit/vmi/vmi_layout_assignment_iota_remat.pto b/test/lit/vmi/vmi_layout_assignment_iota_remat.pto new file mode 100644 index 0000000000..d79cdfddba --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_iota_remat.pto @@ -0,0 +1,53 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-layout-rematerialize -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_iota_remat( + %base: f32, + %src: !pto.vmi.vreg<128xf16>, + %dst: !pto.ptr, + %offset: index) -> !pto.vmi.vreg<128xf32> { + %iota = pto.vmi.iota %base + : f32 -> !pto.vmi.vreg<128xf32> + %wide = pto.vmi.extf %src + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %sum = pto.vmi.addf %iota, %wide + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + pto.vmi.store %iota, %dst[%offset] + : !pto.vmi.vreg<128xf32>, !pto.ptr + return %sum : !pto.vmi.vreg<128xf32> + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_iota_remat( +// ASSIGN-SAME: %[[BASE:.*]]: f32 +// ASSIGN-SAME: %[[SRC:.*]]: !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN: %[[IOTA_DEINT:.*]] = pto.vmi.iota %[[BASE]] +// ASSIGN-SAME: f32 -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[WIDE:.*]] = pto.vmi.extf %[[SRC]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf16, #pto.vmi.layout> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.addf %[[IOTA_DEINT]], %[[WIDE]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[IOTA_CONTIG:.*]] = pto.vmi.ensure_layout %[[IOTA_DEINT]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.store %[[IOTA_CONTIG]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +// LOWER-LABEL: func.func @vmi_layout_assignment_iota_remat( +// LOWER: pto.vci +// LOWER: pto.vcvt +// LOWER: pto.vadd +// LOWER-NOT: pto.vintlv +// LOWER: pto.vsts +// LOWER: pto.vsts +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. 
diff --git a/test/lit/vmi/vmi_layout_assignment_load_truncf.pto b/test/lit/vmi/vmi_layout_assignment_load_truncf.pto new file mode 100644 index 0000000000..9d64cffec2 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_load_truncf.pto @@ -0,0 +1,75 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_load_truncf( + %src: !pto.ptr, + %offset: index) -> !pto.vmi.vreg<128xf16> { + %wide = pto.vmi.load %src[%offset] + : !pto.ptr -> !pto.vmi.vreg<128xf32> + %narrow = pto.vmi.truncf %wide + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> + return %narrow : !pto.vmi.vreg<128xf16> + } + + func.func @vmi_layout_assignment_load_truncf_multi_use( + %src: !pto.ptr, + %dst: !pto.ptr, + %offset: index) -> !pto.vmi.vreg<128xf16> { + %wide = pto.vmi.load %src[%offset] + : !pto.ptr -> !pto.vmi.vreg<128xf32> + pto.vmi.store %wide, %dst[%offset] + : !pto.vmi.vreg<128xf32>, !pto.ptr + %narrow = pto.vmi.truncf %wide + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> + return %narrow : !pto.vmi.vreg<128xf16> + } + +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_load_truncf( +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN: %[[WIDE:.*]] = 
pto.vmi.load +// ASSIGN-SAME: !pto.ptr -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN-NOT: pto.vmi.ensure_layout +// ASSIGN: %[[NARROW:.*]] = pto.vmi.truncf %[[WIDE]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN: return %[[NARROW]] : !pto.vmi.vreg<128xf16, #pto.vmi.layout> + +// LOWER-LABEL: func.func @vmi_layout_assignment_load_truncf( +// LOWER: %[[LOW:.*]], %[[HIGH:.*]] = pto.vldsx2 %arg0[%arg1], "DINTLV_B32" +// LOWER: %[[EVEN:.*]] = pto.vcvt %[[LOW]], {{.*}} {part = "EVEN", rnd = "R", sat = "SAT"} +// LOWER: %[[ODD:.*]] = pto.vcvt %[[HIGH]], {{.*}} {part = "ODD", rnd = "R", sat = "SAT"} +// LOWER: %[[NARROW:.*]] = pto.vor %[[EVEN]], %[[ODD]] +// LOWER: return %[[NARROW]] : !pto.vreg<128xf16> +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_load_truncf_multi_use( +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN: %[[WIDE:.*]] = pto.vmi.load +// ASSIGN-SAME: !pto.ptr -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.store %[[WIDE]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[SPLIT:.*]] = pto.vmi.ensure_layout %[[WIDE]] +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[NARROW:.*]] = pto.vmi.truncf %[[SPLIT]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN: return %[[NARROW]] : !pto.vmi.vreg<128xf16, #pto.vmi.layout> + +// LOWER-LABEL: func.func @vmi_layout_assignment_load_truncf_multi_use( +// LOWER: pto.vsts +// LOWER: pto.vcvt {{.*}} {part = "EVEN", rnd = "R", sat = "SAT"} +// LOWER: pto.vcvt {{.*}} {part = "ODD", rnd = "R", sat = "SAT"} +// LOWER: pto.vor +// LOWER: return {{.*}} : !pto.vreg<128xf16> +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. 
diff --git a/test/lit/vmi/vmi_layout_assignment_mask_granularity_conflict_invalid.pto b/test/lit/vmi/vmi_layout_assignment_mask_granularity_conflict_invalid.pto new file mode 100644 index 0000000000..f3942119a7 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_mask_granularity_conflict_invalid.pto @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-layout-assignment 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_assignment_mask_granularity_conflict_invalid( + %cond: i1, + %a16: !pto.vmi.vreg<128xf16>, + %b16: !pto.vmi.vreg<128xf16>, + %a32: !pto.vmi.vreg<128xf32>, + %b32: !pto.vmi.vreg<128xf32>) { + %mask = scf.if %cond -> !pto.vmi.mask<128xpred> { + %m16 = pto.vmi.cmpf "olt", %a16, %b16 + : !pto.vmi.vreg<128xf16>, !pto.vmi.vreg<128xf16> + -> !pto.vmi.mask<128xpred> + scf.yield %m16 : !pto.vmi.mask<128xpred> + } else { + %m32 = pto.vmi.cmpf "olt", %a32, %b32 + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.mask<128xpred> + scf.yield %m32 : !pto.vmi.mask<128xpred> + } + return + } +} + +// CHECK: VMI-LAYOUT-CONTRACT: conflicting mask granularities diff --git a/test/lit/vmi/vmi_layout_assignment_mask_granularity_f32_f16_store.pto b/test/lit/vmi/vmi_layout_assignment_mask_granularity_f32_f16_store.pto new file mode 100644 index 0000000000..44ae6a19c5 --- /dev/null +++ 
b/test/lit/vmi/vmi_layout_assignment_mask_granularity_f32_f16_store.pto @@ -0,0 +1,62 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_mask_granularity_f32_f16_store( + %src: !pto.ptr, + %out32: !pto.ptr, + %out16: !pto.ptr, + %off: index) { + %c96 = arith.constant 96 : index + %x = pto.vmi.load %src[%off] : !pto.ptr -> !pto.vmi.vreg<128xf32> + %mask = pto.vmi.create_mask %c96 : index -> !pto.vmi.mask<128xpred> + pto.vmi.masked_store %x, %out32[%off], %mask + : !pto.vmi.vreg<128xf32>, !pto.ptr, !pto.vmi.mask<128xpred> + %h = pto.vmi.truncf %x : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> + pto.vmi.masked_store %h, %out16[%off], %mask + : !pto.vmi.vreg<128xf16>, !pto.ptr, !pto.vmi.mask<128xpred> + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_mask_granularity_f32_f16_store( +// ASSIGN: %[[X:.*]] = pto.vmi.load +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[M32:.*]] = pto.vmi.create_mask +// ASSIGN-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// ASSIGN: pto.vmi.masked_store %[[X]] +// ASSIGN-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// ASSIGN: %[[X_SPLIT:.*]] = pto.vmi.ensure_layout 
%[[X]] +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[H:.*]] = pto.vmi.truncf %[[X_SPLIT]] +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN: %[[M16:.*]] = pto.vmi.ensure_mask_granularity %[[M32]] +// ASSIGN-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> -> !pto.vmi.mask<128xb16, #pto.vmi.layout> +// ASSIGN: pto.vmi.masked_store %[[H]] +// ASSIGN-SAME: !pto.vmi.mask<128xb16, #pto.vmi.layout> + +// LOWER-LABEL: func.func @vmi_layout_assignment_mask_granularity_f32_f16_store( +// LOWER: pto.vlds +// LOWER: pto.vlds +// LOWER: pto.pge_b32 "PAT_ALL" +// LOWER: pto.pge_b32 "PAT_VL32" +// LOWER: pto.vsts +// LOWER: pto.vsts +// LOWER: pto.vdintlv +// LOWER: pto.vcvt +// LOWER: pto.vcvt +// LOWER: pto.vor +// LOWER: pto.ppack +// LOWER: pto.ppack +// LOWER: pto.vsts +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_mask_remat.pto b/test/lit/vmi/vmi_layout_assignment_mask_remat.pto new file mode 100644 index 0000000000..8e799c0704 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_mask_remat.pto @@ -0,0 +1,88 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-layout-rematerialize | FileCheck %s --check-prefix=REMAT + +module { + func.func @vmi_layout_assignment_create_mask_remat( + %active: index, + %a16: !pto.vmi.vreg<128xf16>, + %b16: !pto.vmi.vreg<128xf16>, + %a32: !pto.vmi.vreg<128xf32>, + %b32: !pto.vmi.vreg<128xf32>) + -> (!pto.vmi.vreg<128xf16>, !pto.vmi.vreg<128xf32>) { + %mask = pto.vmi.create_mask %active : index -> !pto.vmi.mask<128xpred> + %sel16 = pto.vmi.select %mask, %a16, %b16 + : !pto.vmi.mask<128xpred>, !pto.vmi.vreg<128xf16>, !pto.vmi.vreg<128xf16> + -> !pto.vmi.vreg<128xf16> + %sel32 = pto.vmi.select %mask, %a32, %b32 + : !pto.vmi.mask<128xpred>, !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return %sel16, %sel32 + : !pto.vmi.vreg<128xf16>, !pto.vmi.vreg<128xf32> + } + + func.func @vmi_layout_assignment_constant_mask_remat( + %a16: !pto.vmi.vreg<128xf16>, + %b16: !pto.vmi.vreg<128xf16>, + %a32: !pto.vmi.vreg<128xf32>, + %b32: !pto.vmi.vreg<128xf32>) + -> (!pto.vmi.vreg<128xf16>, !pto.vmi.vreg<128xf32>) { + %mask = "pto.vmi.constant_mask"() { + value = dense : tensor<128xi1> + } : () -> !pto.vmi.mask<128xpred> + %sel16 = pto.vmi.select %mask, %a16, %b16 + : !pto.vmi.mask<128xpred>, !pto.vmi.vreg<128xf16>, !pto.vmi.vreg<128xf16> + -> !pto.vmi.vreg<128xf16> + %sel32 = pto.vmi.select %mask, %a32, %b32 + : !pto.vmi.mask<128xpred>, !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return %sel16, %sel32 + : !pto.vmi.vreg<128xf16>, !pto.vmi.vreg<128xf32> + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_create_mask_remat( +// ASSIGN-SAME: %[[ACTIVE:.*]]: index +// ASSIGN: %[[M32:.*]] = pto.vmi.create_mask %[[ACTIVE]] +// ASSIGN-SAME: index -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// ASSIGN: %[[M16:.*]] = pto.vmi.ensure_mask_granularity %[[M32]] +// ASSIGN-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> 
-> !pto.vmi.mask<128xb16, #pto.vmi.layout> +// ASSIGN: pto.vmi.select %[[M16]] +// ASSIGN-SAME: !pto.vmi.mask<128xb16, #pto.vmi.layout> +// ASSIGN: pto.vmi.select %[[M32]] +// ASSIGN-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_constant_mask_remat( +// ASSIGN: %[[CM32:.*]] = "pto.vmi.constant_mask"() +// ASSIGN-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// ASSIGN: %[[CM16:.*]] = pto.vmi.ensure_mask_granularity %[[CM32]] +// ASSIGN-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> -> !pto.vmi.mask<128xb16, #pto.vmi.layout> +// ASSIGN: pto.vmi.select %[[CM16]] +// ASSIGN-SAME: !pto.vmi.mask<128xb16, #pto.vmi.layout> +// ASSIGN: pto.vmi.select %[[CM32]] +// ASSIGN-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> + +// REMAT-LABEL: func.func @vmi_layout_assignment_create_mask_remat( +// REMAT-SAME: %[[ACTIVE:.*]]: index +// REMAT: %[[M32:.*]] = pto.vmi.create_mask %[[ACTIVE]] +// REMAT-SAME: index -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// REMAT: %[[M16:.*]] = pto.vmi.create_mask %[[ACTIVE]] +// REMAT-SAME: index -> !pto.vmi.mask<128xb16, #pto.vmi.layout> +// REMAT: pto.vmi.select %[[M16]] +// REMAT: pto.vmi.select %[[M32]] +// REMAT-LABEL: func.func @vmi_layout_assignment_constant_mask_remat( +// REMAT: %[[CM32:.*]] = "pto.vmi.constant_mask"() +// REMAT-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// REMAT: %[[CM16:.*]] = "pto.vmi.constant_mask"() +// REMAT-SAME: !pto.vmi.mask<128xb16, #pto.vmi.layout> +// REMAT: pto.vmi.select %[[CM16]] +// REMAT: pto.vmi.select %[[CM32]] +// REMAT-NOT: pto.vmi.ensure_mask_layout +// REMAT-NOT: pto.vmi.ensure_mask_granularity diff --git a/test/lit/vmi/vmi_layout_assignment_mask_select_store.pto b/test/lit/vmi/vmi_layout_assignment_mask_select_store.pto new file mode 100644 index 0000000000..62ef723511 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_mask_select_store.pto @@ -0,0 +1,64 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_mask_select_store( + %src: !pto.ptr, + %rhs: !pto.ptr, + %dense: !pto.ptr, + %masked: !pto.ptr, + %off: index) { + %c48 = arith.constant 48 : index + %x = pto.vmi.load %src[%off] : !pto.ptr -> !pto.vmi.vreg<64xf32> + %y = pto.vmi.load %rhs[%off] : !pto.ptr -> !pto.vmi.vreg<64xf32> + %mask = pto.vmi.create_mask %c48 : index -> !pto.vmi.mask<64xpred> + %sum = pto.vmi.addf %x, %y + : !pto.vmi.vreg<64xf32>, !pto.vmi.vreg<64xf32> + -> !pto.vmi.vreg<64xf32> + %passthrough = pto.vmi.select %mask, %sum, %x + : !pto.vmi.mask<64xpred>, !pto.vmi.vreg<64xf32>, + !pto.vmi.vreg<64xf32> -> !pto.vmi.vreg<64xf32> + pto.vmi.store %passthrough, %dense[%off] + : !pto.vmi.vreg<64xf32>, !pto.ptr + pto.vmi.masked_store %sum, %masked[%off], %mask + : !pto.vmi.vreg<64xf32>, !pto.ptr, !pto.vmi.mask<64xpred> + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_mask_select_store( +// ASSIGN: %[[X:.*]] = pto.vmi.load +// ASSIGN-SAME: -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// ASSIGN: %[[Y:.*]] = pto.vmi.load +// ASSIGN-SAME: -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// ASSIGN: %[[MASK:.*]] = pto.vmi.create_mask +// ASSIGN-SAME: -> !pto.vmi.mask<64xb32, 
#pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.addf %[[X]], %[[Y]] +// ASSIGN-SAME: -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// ASSIGN: %[[PASS:.*]] = pto.vmi.select %[[MASK]], %[[SUM]], %[[X]] +// ASSIGN-SAME: -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.store %[[PASS]] +// ASSIGN-SAME: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, !pto.ptr +// ASSIGN: pto.vmi.masked_store %[[SUM]] +// ASSIGN-SAME: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, !pto.ptr, !pto.vmi.mask<64xb32, #pto.vmi.layout> +// ASSIGN-NOT: pto.vmi.ensure_mask_layout +// ASSIGN-NOT: pto.vmi.ensure_mask_granularity + +// LOWER-LABEL: func.func @vmi_layout_assignment_mask_select_store( +// LOWER: pto.vlds +// LOWER: pto.vlds +// LOWER: pto.plt_b32 +// LOWER: pto.vadd +// LOWER: pto.vsel +// LOWER: pto.vsts +// LOWER: pto.vsts +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_layout_assignment_mask_use_ensure.pto b/test/lit/vmi/vmi_layout_assignment_mask_use_ensure.pto new file mode 100644 index 0000000000..fd487d017a --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_mask_use_ensure.pto @@ -0,0 +1,36 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_mask_use_ensure( + %m: !pto.vmi.mask<128xpred>, + %a16: !pto.vmi.vreg<128xf16>, + %b16: !pto.vmi.vreg<128xf16>, + %a32: !pto.vmi.vreg<128xf32>, + %b32: !pto.vmi.vreg<128xf32>) { + %sel16 = pto.vmi.select %m, %a16, %b16 + : !pto.vmi.mask<128xpred>, !pto.vmi.vreg<128xf16>, !pto.vmi.vreg<128xf16> + -> !pto.vmi.vreg<128xf16> + %sel32 = pto.vmi.select %m, %a32, %b32 + : !pto.vmi.mask<128xpred>, !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_mask_use_ensure( +// CHECK-SAME: %[[M:.*]]: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK: %[[M16:.*]] = pto.vmi.ensure_mask_granularity %[[M]] +// CHECK-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.mask<128xb16, #pto.vmi.layout> +// CHECK: pto.vmi.select %[[M16]] +// CHECK-SAME: !pto.vmi.mask<128xb16, #pto.vmi.layout> +// CHECK: pto.vmi.select %[[M]] +// CHECK-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> diff --git a/test/lit/vmi/vmi_layout_assignment_masked_load.pto b/test/lit/vmi/vmi_layout_assignment_masked_load.pto new file mode 100644 index 0000000000..286c92f6da --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_masked_load.pto @@ -0,0 +1,32 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_masked_load( + %src: !pto.ptr, + %offset: index, + %mask: !pto.vmi.mask<64xpred>, + %passthru: !pto.vmi.vreg<64xf32>) -> !pto.vmi.vreg<64xf32> { + %out = pto.vmi.masked_load %src[%offset], %mask, %passthru + : !pto.ptr, !pto.vmi.mask<64xpred>, !pto.vmi.vreg<64xf32> + -> !pto.vmi.vreg<64xf32> + return %out : !pto.vmi.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_masked_load( +// CHECK-SAME: %arg2: !pto.vmi.mask<64xb32, #pto.vmi.layout> +// CHECK-SAME: %arg3: !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK: %[[OUT:.*]] = pto.vmi.masked_load %arg0[%arg1], %arg2, %arg3 +// CHECK-SAME: !pto.vmi.mask<64xb32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK: return %[[OUT]] diff --git a/test/lit/vmi/vmi_layout_assignment_masked_load_dense_group_users.pto b/test/lit/vmi/vmi_layout_assignment_masked_load_dense_group_users.pto new file mode 100644 index 0000000000..e179f5ccc6 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_masked_load_dense_group_users.pto @@ -0,0 +1,65 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_masked_load_dense_group_users( + %base: !pto.ptr, + %copy_out: !pto.ptr, + %sum_out: !pto.ptr, + %off: index) { + %c0_f32 = arith.constant 0.000000e+00 : f32 + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %mask = pto.vmi.create_mask %c256 + : index -> !pto.vmi.mask<256xpred> + %zero = pto.vmi.broadcast %c0_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %x = pto.vmi.masked_load %base[%off], %mask, %zero + : !pto.ptr, !pto.vmi.mask<256xpred>, + !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + pto.vmi.store %x, %copy_out[%off] + : !pto.vmi.vreg<256xf32>, !pto.ptr + %sum = pto.vmi.group_reduce_addf %x, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<8xf32> + pto.vmi.group_store %sum, %sum_out[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_masked_load_dense_group_users( +// ASSIGN: %[[MASK:.*]] = pto.vmi.create_mask +// ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: %[[ZERO:.*]] = pto.vmi.broadcast +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[X:.*]] = pto.vmi.masked_load +// ASSIGN-SAME: !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.store %[[X]] +// ASSIGN: %[[X_SPLIT:.*]] = pto.vmi.ensure_layout %[[X]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[MASK_SPLIT:.*]] = pto.vmi.ensure_mask_layout %[[MASK]] +// ASSIGN-SAME: !pto.vmi.mask<256xb32, #pto.vmi.layout> -> !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X_SPLIT]], %[[MASK_SPLIT]] +// ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, 
#pto.vmi.layout> +// ASSIGN: pto.vmi.group_store %[[SUM]] + +// LOWER-LABEL: func.func @vmi_layout_assignment_masked_load_dense_group_users( +// LOWER-COUNT-4: pto.vsel +// LOWER-COUNT-4: pto.vsts +// LOWER: pto.vdintlv +// LOWER-COUNT-4: pto.vcgadd +// LOWER: pto.vadd +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_masked_load_group_tail_s32.pto b/test/lit/vmi/vmi_layout_assignment_masked_load_group_tail_s32.pto new file mode 100644 index 0000000000..debe3fd571 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_masked_load_group_tail_s32.pto @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_masked_load_group_tail_s32( + %base: !pto.ptr, + %sum_out: !pto.ptr, + %off: index) { + %c0_f32 = arith.constant 0.000000e+00 : f32 + %c1 = arith.constant 1 : index + %c25 = arith.constant 25 : index + %mask = pto.vmi.create_group_mask %c25 + {num_groups = 8, group_size = 32} + : index -> !pto.vmi.mask<256xpred> + %zero = pto.vmi.broadcast %c0_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %x = pto.vmi.masked_load %base[%off], %mask, %zero + : !pto.ptr, !pto.vmi.mask<256xpred>, + !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %sum = pto.vmi.group_reduce_addf %x, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<8xf32> + pto.vmi.group_store %sum, %sum_out[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xf32>, !pto.ptr + return + } +} + +// ASSIGN: pto.vmi.create_group_mask +// ASSIGN-SAME: !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: pto.vmi.ensure_layout +// ASSIGN-SAME: #pto.vmi.layout +// ASSIGN-SAME: #pto.vmi.layout +// ASSIGN: pto.vmi.ensure_mask_layout +// ASSIGN-SAME: !pto.vmi.mask<256xb32, #pto.vmi.layout> -> !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_reduce_addf +// LOWER: pto.pdintlv_b32 +// LOWER: pto.pdintlv_b32 +// LOWER: pto.pdintlv_b32 +// LOWER: pto.pdintlv_b32 +// LOWER: pto.vcgadd +// LOWER: pto.vsts diff --git a/test/lit/vmi/vmi_layout_assignment_multi_return.pto b/test/lit/vmi/vmi_layout_assignment_multi_return.pto new file mode 100644 index 0000000000..380b0d0ef9 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_multi_return.pto @@ -0,0 +1,39 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @multi_return( + %cond: i1, + %a: !pto.vmi.vreg<128xf16>, + %b: !pto.vmi.vreg<128xf16>) -> !pto.vmi.vreg<128xf32> { + cf.cond_br %cond, ^then, ^else + + ^then: + %ea = pto.vmi.extf %a + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + return %ea : !pto.vmi.vreg<128xf32> + + ^else: + %eb = pto.vmi.extf %b + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + return %eb : !pto.vmi.vreg<128xf32> + } +} + +// CHECK-LABEL: func.func @multi_return( +// CHECK-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: pto.vmi.extf +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: return +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: pto.vmi.extf +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: return +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> diff --git a/test/lit/vmi/vmi_layout_assignment_multi_return_conflict_invalid.pto b/test/lit/vmi/vmi_layout_assignment_multi_return_conflict_invalid.pto new file mode 100644 index 0000000000..4e9b2885fd --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_multi_return_conflict_invalid.pto @@ -0,0 +1,30 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-layout-assignment 2>&1 | FileCheck %s + +module { + func.func @multi_return_conflict( + %cond: i1, + %a: !pto.vmi.vreg<128xf16>, + %b: !pto.vmi.vreg<128xf8E4M3FN>) -> !pto.vmi.vreg<128xf32> { + cf.cond_br %cond, ^then, ^else + + ^then: + %ea = pto.vmi.extf %a + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + return %ea : !pto.vmi.vreg<128xf32> + + ^else: + %eb = pto.vmi.extf %b + : !pto.vmi.vreg<128xf8E4M3FN> -> !pto.vmi.vreg<128xf32> + return %eb : !pto.vmi.vreg<128xf32> + } +} + +// CHECK: VMI-LAYOUT-CONTRACT: conflicting natural layouts #pto.vmi.layout and #pto.vmi.layout diff --git a/test/lit/vmi/vmi_layout_assignment_non_load_s32_reduce.pto b/test/lit/vmi/vmi_layout_assignment_non_load_s32_reduce.pto new file mode 100644 index 0000000000..8bde94f611 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_non_load_s32_reduce.pto @@ -0,0 +1,61 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_non_load_s32_reduce( + %base: !pto.ptr, + %bias: f32, + %dst: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %a = pto.vmi.load %base[%off] : !pto.ptr -> !pto.vmi.vreg<256xf32> + %biasv = pto.vmi.broadcast %bias : f32 -> !pto.vmi.vreg<256xf32> + %x = pto.vmi.addf %a, %biasv + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %sum = pto.vmi.group_reduce_addf %x, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<8xf32> + pto.vmi.group_store %sum, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_non_load_s32_reduce( +// ASSIGN-SAME: %[[BASE:arg[0-9]+]]: !pto.ptr +// ASSIGN-SAME: %[[BIAS:arg[0-9]+]]: f32 +// ASSIGN: %[[A:.*]] = pto.vmi.load %[[BASE]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[BIASV:.*]] = pto.vmi.broadcast %[[BIAS]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[X:.*]] = pto.vmi.addf %[[A]], %[[BIASV]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[MASK:.*]] = pto.vmi.ensure_mask_layout +// ASSIGN-SAME: !pto.vmi.mask<256xb32, #pto.vmi.layout> -> !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X]], %[[MASK]] 
+// ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_store %[[SUM]] +// ASSIGN-SAME: !pto.vmi.vreg<8xf32, #pto.vmi.layout>, !pto.ptr + +// LOWER-LABEL: func.func @vmi_layout_assignment_non_load_s32_reduce( +// LOWER-COUNT-4: pto.vdup %arg1 +// LOWER-COUNT-4: pto.vadd {{.*}}, {{.*}}, {{.*}} : !pto.vreg<64xf32> +// LOWER: %[[VL8:.*]] = pto.pge_b32 "PAT_VL8" : !pto.mask +// LOWER-COUNT-4: pto.vcgadd +// LOWER-COUNT-3: pto.vadd {{.*}}, {{.*}}, %[[VL8]] +// LOWER: pto.vsts {{.*}}, %arg2[%arg3], {{.*}} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_packed_group_slots_truncf_invalid.pto b/test/lit/vmi/vmi_layout_assignment_packed_group_slots_truncf_invalid.pto new file mode 100644 index 0000000000..f1be94a798 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_packed_group_slots_truncf_invalid.pto @@ -0,0 +1,30 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_assignment_packed_group_slots_truncf_invalid( + %source: !pto.vmi.vreg<128xf32>, + %mask: !pto.vmi.mask<128xpred>, + %dst: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %sum = pto.vmi.group_reduce_addf %source, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<8xf32> + // CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.truncf operand #0 has type '!pto.vmi.vreg<8xf32, #pto.vmi.layout>' but requires '!pto.vmi.vreg<8xf32, #pto.vmi.layout>'; pto.vmi.ensure_layout cannot materialize this conversion + // CHECK: failed helper conversion '!pto.vmi.vreg<8xf32, #pto.vmi.layout>' -> '!pto.vmi.vreg<8xf32, #pto.vmi.layout>' (unsupported source/result layout pair) + %h = pto.vmi.truncf %sum + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<8xf16> + pto.vmi.group_store %h, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xf16>, !pto.ptr + return + } +} diff --git a/test/lit/vmi/vmi_layout_assignment_post_gate_type_attr_invalid.pto b/test/lit/vmi/vmi_layout_assignment_post_gate_type_attr_invalid.pto new file mode 100644 index 0000000000..968aeb1e05 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_post_gate_type_attr_invalid.pto @@ -0,0 +1,17 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -vmi-layout-assignment 2>&1 | FileCheck %s + +module attributes { + pto.hidden_vmi_type = !pto.vmi.vreg<128xf32> +} { +} + +// CHECK: VMI-PASS-INVARIANT +// CHECK: VMI or physical VPTO type appears in a non-signature attribute diff --git a/test/lit/vmi/vmi_layout_assignment_reduce_addf.pto b/test/lit/vmi/vmi_layout_assignment_reduce_addf.pto new file mode 100644 index 0000000000..de71e01d6a --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_reduce_addf.pto @@ -0,0 +1,30 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_reduce_addf( + %source: !pto.vmi.vreg<64xf32>, + %init: !pto.vmi.vreg<1xf32>, + %mask: !pto.vmi.mask<64xpred>) -> !pto.vmi.vreg<1xf32> { + %out = pto.vmi.reduce_addf %source, %init, %mask {reassoc} + : !pto.vmi.vreg<64xf32>, !pto.vmi.vreg<1xf32>, + !pto.vmi.mask<64xpred> -> !pto.vmi.vreg<1xf32> + return %out : !pto.vmi.vreg<1xf32> + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_reduce_addf( +// CHECK-SAME: %arg0: !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK-SAME: %arg1: !pto.vmi.vreg<1xf32, #pto.vmi.layout> +// CHECK-SAME: %arg2: !pto.vmi.mask<64xb32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<1xf32, #pto.vmi.layout> +// CHECK: %[[OUT:.*]] = pto.vmi.reduce_addf %arg0, %arg1, %arg2 +// CHECK-SAME: reassoc +// CHECK: return %[[OUT]] diff --git a/test/lit/vmi/vmi_layout_assignment_reduce_addi.pto b/test/lit/vmi/vmi_layout_assignment_reduce_addi.pto new file mode 100644 index 0000000000..82a516b114 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_reduce_addi.pto @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_reduce_addi( + %source: !pto.vmi.vreg<64xi32>, + %init: !pto.vmi.vreg<1xi32>, + %mask: !pto.vmi.mask<64xpred>) -> !pto.vmi.vreg<1xi32> { + %out = pto.vmi.reduce_addi %source, %init, %mask + : !pto.vmi.vreg<64xi32>, !pto.vmi.vreg<1xi32>, + !pto.vmi.mask<64xpred> -> !pto.vmi.vreg<1xi32> + return %out : !pto.vmi.vreg<1xi32> + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_reduce_addi( +// CHECK-SAME: %[[SOURCE:.*]]: !pto.vmi.vreg<64xi32, #pto.vmi.layout> +// CHECK-SAME: %[[INIT:.*]]: !pto.vmi.vreg<1xi32, #pto.vmi.layout> +// CHECK-SAME: %[[MASK:.*]]: !pto.vmi.mask<64xb32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<1xi32, #pto.vmi.layout> +// CHECK: %[[OUT:.*]] = pto.vmi.reduce_addi %[[SOURCE]], %[[INIT]], %[[MASK]] +// CHECK-SAME: !pto.vmi.vreg<64xi32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.vreg<1xi32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.mask<64xb32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.vreg<1xi32, #pto.vmi.layout> +// CHECK: return %[[OUT]] diff --git a/test/lit/vmi/vmi_layout_assignment_reduce_minmaxf.pto b/test/lit/vmi/vmi_layout_assignment_reduce_minmaxf.pto new file mode 100644 index 0000000000..51f8180ef0 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_reduce_minmaxf.pto @@ -0,0 +1,49 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_reduce_maxf( + %source: !pto.vmi.vreg<64xf32>, + %init: !pto.vmi.vreg<1xf32>, + %mask: !pto.vmi.mask<64xpred>) -> !pto.vmi.vreg<1xf32> { + %out = pto.vmi.reduce_maxf %source, %init, %mask + : !pto.vmi.vreg<64xf32>, !pto.vmi.vreg<1xf32>, + !pto.vmi.mask<64xpred> -> !pto.vmi.vreg<1xf32> + return %out : !pto.vmi.vreg<1xf32> + } + + func.func @vmi_layout_assignment_reduce_minf( + %source: !pto.vmi.vreg<128xf16>, + %init: !pto.vmi.vreg<1xf16>, + %mask: !pto.vmi.mask<128xpred>) -> !pto.vmi.vreg<1xf16> { + %out = pto.vmi.reduce_minf %source, %init, %mask + : !pto.vmi.vreg<128xf16>, !pto.vmi.vreg<1xf16>, + !pto.vmi.mask<128xpred> -> !pto.vmi.vreg<1xf16> + return %out : !pto.vmi.vreg<1xf16> + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_reduce_maxf( +// CHECK-SAME: %arg0: !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK-SAME: %arg1: !pto.vmi.vreg<1xf32, #pto.vmi.layout> +// CHECK-SAME: %arg2: !pto.vmi.mask<64xb32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<1xf32, #pto.vmi.layout> +// CHECK: %[[MAX:.*]] = pto.vmi.reduce_maxf %arg0, %arg1, %arg2 +// CHECK: return %[[MAX]] + +// CHECK-LABEL: func.func @vmi_layout_assignment_reduce_minf( +// CHECK-SAME: %arg0: !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// CHECK-SAME: %arg1: !pto.vmi.vreg<1xf16, #pto.vmi.layout> +// CHECK-SAME: %arg2: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<1xf16, #pto.vmi.layout> +// CHECK: %[[MASK:.*]] = pto.vmi.ensure_mask_granularity %arg2 +// CHECK-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> -> !pto.vmi.mask<128xb16, #pto.vmi.layout> +// CHECK: %[[MIN:.*]] = pto.vmi.reduce_minf %arg0, %arg1, %[[MASK]] +// CHECK: return %[[MIN]] diff --git a/test/lit/vmi/vmi_layout_assignment_scatter.pto b/test/lit/vmi/vmi_layout_assignment_scatter.pto new file mode 100644 
index 0000000000..b920cf4da4 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_scatter.pto @@ -0,0 +1,31 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_scatter( + %value: !pto.vmi.vreg<64xf32>, + %dst: !pto.ptr, + %indices: !pto.vmi.vreg<64xi32>, + %mask: !pto.vmi.mask<64xpred>) { + pto.vmi.scatter %value, %dst[%indices], %mask + : !pto.vmi.vreg<64xf32>, !pto.ptr, + !pto.vmi.vreg<64xi32>, !pto.vmi.mask<64xpred> + return + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_scatter( +// CHECK-SAME: %arg0: !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK-SAME: %arg2: !pto.vmi.vreg<64xi32, #pto.vmi.layout> +// CHECK-SAME: %arg3: !pto.vmi.mask<64xb32, #pto.vmi.layout> +// CHECK: pto.vmi.scatter %arg0, %arg1[%arg2], %arg3 +// CHECK-SAME: !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.vreg<64xi32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.mask<64xb32, #pto.vmi.layout> diff --git a/test/lit/vmi/vmi_layout_assignment_scf_execute_region.pto b/test/lit/vmi/vmi_layout_assignment_scf_execute_region.pto new file mode 100644 index 0000000000..3bd81dca8c --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_scf_execute_region.pto @@ -0,0 +1,38 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_scf_execute_region( + %input: !pto.vmi.vreg<128xf16>) -> !pto.vmi.vreg<128xf32> { + %wide = pto.vmi.extf %input + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %result = scf.execute_region -> !pto.vmi.vreg<128xf32> { + scf.yield %wide : !pto.vmi.vreg<128xf32> + } + return %result : !pto.vmi.vreg<128xf32> + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_scf_execute_region( +// ASSIGN-SAME: %[[INPUT:.*]]: !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[WIDE:.*]] = pto.vmi.extf %[[INPUT]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf16, #pto.vmi.layout> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[RESULT:.*]] = scf.execute_region -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: scf.yield %[[WIDE]] : !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: return %[[RESULT]] : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +// LOWER-LABEL: func.func @vmi_layout_assignment_scf_execute_region( +// LOWER: %[[RESULT:.*]]:2 = scf.execute_region -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) +// LOWER: scf.yield {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32> +// LOWER: return 
%[[RESULT]]#0, %[[RESULT]]#1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_layout_assignment_scf_for.pto b/test/lit/vmi/vmi_layout_assignment_scf_for.pto new file mode 100644 index 0000000000..b63563216b --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_scf_for.pto @@ -0,0 +1,43 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_scf_for(%a: !pto.vmi.vreg<128xf16>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %init = pto.vmi.extf %a + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %result = scf.for %i = %c0 to %c2 step %c1 + iter_args(%acc = %init) -> (!pto.vmi.vreg<128xf32>) { + %next = pto.vmi.addf %acc, %acc + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + scf.yield %next : !pto.vmi.vreg<128xf32> + } + %sum = pto.vmi.addf %result, %result + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_scf_for( +// CHECK: %[[INIT:.*]] = pto.vmi.extf +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: %[[RESULT:.*]] = scf.for +// CHECK-SAME: iter_args(%[[ACC:.*]] = %[[INIT]]) +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: pto.vmi.addf %[[ACC]], %[[ACC]] +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: scf.yield +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: pto.vmi.addf %[[RESULT]], %[[RESULT]] +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> diff --git a/test/lit/vmi/vmi_layout_assignment_scf_if.pto b/test/lit/vmi/vmi_layout_assignment_scf_if.pto new file mode 100644 index 0000000000..f86107920a --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_scf_if.pto @@ -0,0 +1,50 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_scf_if( + %cond: i1, + %a: !pto.vmi.vreg<128xf16>, + %b: !pto.vmi.vreg<128xf16>) { + %value, %mask = scf.if %cond + -> (!pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred>) { + %ea = pto.vmi.extf %a + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %cmpa = pto.vmi.cmpf "olt", %ea, %ea + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.mask<128xpred> + scf.yield %ea, %cmpa : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + } else { + %eb = pto.vmi.extf %b + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %cmpb = pto.vmi.cmpf "olt", %eb, %eb + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.mask<128xpred> + scf.yield %eb, %cmpb : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + } + %selected = pto.vmi.select %mask, %value, %value + : !pto.vmi.mask<128xpred>, !pto.vmi.vreg<128xf32>, + !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_scf_if( +// CHECK: scf.if +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK: pto.vmi.cmpf +// CHECK-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK: scf.yield +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK: pto.vmi.select +// CHECK-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> diff --git a/test/lit/vmi/vmi_layout_assignment_scf_index_switch.pto b/test/lit/vmi/vmi_layout_assignment_scf_index_switch.pto new file mode 100644 
index 0000000000..24ea65503e --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_scf_index_switch.pto @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_scf_index_switch( + %selector: index, + %input: !pto.vmi.vreg<128xf16>) -> !pto.vmi.vreg<128xf32> { + %wide = pto.vmi.extf %input + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %result = scf.index_switch %selector -> !pto.vmi.vreg<128xf32> + case 0 { + scf.yield %wide : !pto.vmi.vreg<128xf32> + } + default { + scf.yield %wide : !pto.vmi.vreg<128xf32> + } + return %result : !pto.vmi.vreg<128xf32> + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_scf_index_switch( +// ASSIGN-SAME: %[[SELECTOR:.*]]: index +// ASSIGN-SAME: %[[INPUT:.*]]: !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[WIDE:.*]] = pto.vmi.extf %[[INPUT]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf16, #pto.vmi.layout> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[RESULT:.*]] = scf.index_switch %[[SELECTOR]] -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: scf.yield %[[WIDE]] : !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: 
default +// ASSIGN: scf.yield %[[WIDE]] : !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: return %[[RESULT]] : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +// LOWER-LABEL: func.func @vmi_layout_assignment_scf_index_switch( +// LOWER: %[[RESULT:.*]]:2 = scf.index_switch {{.*}} -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +// LOWER: scf.yield {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32> +// LOWER: default +// LOWER: scf.yield {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32> +// LOWER: return %[[RESULT]]#0, %[[RESULT]]#1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_layout_assignment_scf_while.pto b/test/lit/vmi/vmi_layout_assignment_scf_while.pto new file mode 100644 index 0000000000..917bf1762f --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_scf_while.pto @@ -0,0 +1,47 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_scf_while( + %input: !pto.vmi.vreg<128xf16>, + %keep_going: i1) -> !pto.vmi.vreg<128xf32> { + %wide = pto.vmi.extf %input + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %result = scf.while (%value = %wide) + : (!pto.vmi.vreg<128xf32>) -> !pto.vmi.vreg<128xf32> { + scf.condition(%keep_going) %value : !pto.vmi.vreg<128xf32> + } do { + ^bb0(%value: !pto.vmi.vreg<128xf32>): + scf.yield %value : !pto.vmi.vreg<128xf32> + } + return %result : !pto.vmi.vreg<128xf32> + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_scf_while( +// ASSIGN-SAME: %[[INPUT:.*]]: !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[WIDE:.*]] = pto.vmi.extf %[[INPUT]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf16, #pto.vmi.layout> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[RESULT:.*]] = scf.while (%[[VALUE:.*]] = %[[WIDE]]) : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: scf.condition(%arg1) %[[VALUE]] : !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: ^bb0(%[[AFTER:.*]]: !pto.vmi.vreg<128xf32, #pto.vmi.layout>): +// ASSIGN: scf.yield %[[AFTER]] : !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: return %[[RESULT]] : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +// LOWER-LABEL: func.func @vmi_layout_assignment_scf_while( +// LOWER: %[[RESULT:.*]]:2 = scf.while +// LOWER-SAME: (!pto.vreg<64xf32>, !pto.vreg<64xf32>) -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) +// LOWER: scf.condition(%arg1) {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32> +// LOWER: scf.yield {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32> +// LOWER: return %[[RESULT]]#0, %[[RESULT]]#1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. 
diff --git a/test/lit/vmi/vmi_layout_assignment_store_ensure.pto b/test/lit/vmi/vmi_layout_assignment_store_ensure.pto new file mode 100644 index 0000000000..430fff7fda --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_store_ensure.pto @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_store_ensure( + %src: !pto.vmi.vreg<128xf16>, + %dst: !pto.ptr, + %offset: index) { + %wide = pto.vmi.extf %src + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %sum = pto.vmi.addf %wide, %wide + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + pto.vmi.store %sum, %dst[%offset] + : !pto.vmi.vreg<128xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_store_ensure( +// ASSIGN-SAME: %[[SRC:.*]]: !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN: %[[WIDE:.*]] = pto.vmi.extf %[[SRC]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.addf %[[WIDE]], %[[WIDE]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[DENSE:.*]] = pto.vmi.ensure_layout %[[SUM]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> 
!pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.store %[[DENSE]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +// LOWER-LABEL: func.func @vmi_layout_assignment_store_ensure( +// LOWER: %[[EVEN:.*]] = pto.vcvt +// LOWER: %[[ODD:.*]] = pto.vcvt +// LOWER: %[[SUM0:.*]] = pto.vadd %[[EVEN]], %[[EVEN]] +// LOWER: %[[SUM1:.*]] = pto.vadd %[[ODD]], %[[ODD]] +// LOWER: %[[D0:.*]], %[[D1:.*]] = pto.vintlv %[[SUM0]], %[[SUM1]] +// LOWER: pto.vsts %[[D0]] +// LOWER: pto.vsts %[[D1]] +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_layout_assignment_truncf_ensure.pto b/test/lit/vmi/vmi_layout_assignment_truncf_ensure.pto new file mode 100644 index 0000000000..8908036648 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_truncf_ensure.pto @@ -0,0 +1,39 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_truncf_ensure( + %wide: !pto.vmi.vreg<128xf32>) -> !pto.vmi.vreg<128xf16> { + %narrow = pto.vmi.truncf %wide + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> + return %narrow : !pto.vmi.vreg<128xf16> + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_truncf_ensure( +// ASSIGN-SAME: %[[WIDE:.*]]: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[SPLIT:.*]] = pto.vmi.ensure_layout %[[WIDE]] +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[NARROW:.*]] = pto.vmi.truncf %[[SPLIT]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN: return %[[NARROW]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf16, #pto.vmi.layout> + +// LOWER-LABEL: func.func @vmi_layout_assignment_truncf_ensure( +// LOWER-SAME: %[[D0:arg[0-9]+]]: !pto.vreg<64xf32> +// LOWER-SAME: %[[D1:arg[0-9]+]]: !pto.vreg<64xf32> +// LOWER: %[[LOW:.*]], %[[HIGH:.*]] = pto.vdintlv %[[D0]], %[[D1]] +// LOWER: %[[EVEN:.*]] = pto.vcvt %[[LOW]]{{.*}}part = "EVEN" +// LOWER: %[[ODD:.*]] = pto.vcvt %[[HIGH]]{{.*}}part = "ODD" +// LOWER: %[[NARROW:.*]] = pto.vor %[[EVEN]], %[[ODD]] +// LOWER: return %[[NARROW]] : !pto.vreg<128xf16> +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_layout_assignment_trunci_lane_stride.pto b/test/lit/vmi/vmi_layout_assignment_trunci_lane_stride.pto new file mode 100644 index 0000000000..4a36d5882a --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_trunci_lane_stride.pto @@ -0,0 +1,40 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). 
Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_group_slot_trunci_lane_stride( + %wide: !pto.vmi.vreg<8xi32, #pto.vmi.layout>, + %dst: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %narrow = pto.vmi.trunci %wide + : !pto.vmi.vreg<8xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<8xui8> + pto.vmi.group_store %narrow, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xui8>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_group_slot_trunci_lane_stride( +// ASSIGN: %[[NARROW:.*]] = pto.vmi.trunci +// ASSIGN-SAME: -> !pto.vmi.vreg<8xui8, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_store %[[NARROW]] + +// LOWER-LABEL: func.func @vmi_layout_assignment_group_slot_trunci_lane_stride( +// LOWER-NOT: pto.vcvt +// LOWER-NOT: pto.vpack +// LOWER: pto.vsts {{.*}} {dist = "PK4_B32"} : !pto.vreg<64xi32>, !pto.ptr, !pto.mask +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_widen.pto b/test/lit/vmi/vmi_layout_assignment_widen.pto new file mode 100644 index 0000000000..eceedcb711 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_widen.pto @@ -0,0 +1,39 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_widen( + %a: !pto.vmi.vreg<128xf16>, + %b: !pto.vmi.vreg<128xf16>) { + %ea = pto.vmi.extf %a : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %eb = pto.vmi.extf %b : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %sum = pto.vmi.addf %ea, %eb + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + %cmp = pto.vmi.cmpf "olt", %ea, %eb + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> -> !pto.vmi.mask<128xpred> + %sel = pto.vmi.select %cmp, %sum, %ea + : !pto.vmi.mask<128xpred>, !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_widen( +// CHECK-SAME: !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// CHECK: pto.vmi.extf +// CHECK-SAME: !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: pto.vmi.addf +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: pto.vmi.cmpf +// CHECK-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK: pto.vmi.select +// CHECK-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> diff --git a/test/lit/vmi/vmi_layout_assignment_widen_dense_reduce_multi_consumer.pto 
b/test/lit/vmi/vmi_layout_assignment_widen_dense_reduce_multi_consumer.pto new file mode 100644 index 0000000000..95f5becf6b --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_widen_dense_reduce_multi_consumer.pto @@ -0,0 +1,83 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_widen_dense_reduce_multi_consumer( + %src: !pto.ptr, + %k1: !pto.vmi.vreg<128xf32>, + %init0: !pto.vmi.vreg<1xf32>, + %init1: !pto.vmi.vreg<1xf32>, + %out0: !pto.ptr, + %out1: !pto.ptr, + %off: index) { + %c128 = arith.constant 128 : index + %c0 = arith.constant 0 : index + %a = pto.vmi.load %src[%off] + : !pto.ptr -> !pto.vmi.vreg<128xf16> + %w = pto.vmi.extf %a + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %t1 = pto.vmi.mulf %w, %k1 + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + %mask = pto.vmi.create_mask %c128 : index -> !pto.vmi.mask<128xpred> + %r0 = pto.vmi.reduce_addf %t1, %init0, %mask {reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<1xf32>, + !pto.vmi.mask<128xpred> -> !pto.vmi.vreg<1xf32> + pto.vmi.store %r0, %out0[%c0] + : !pto.vmi.vreg<1xf32>, !pto.ptr + %r = pto.vmi.reduce_addf %w, %init1, %mask {reassoc} + : 
!pto.vmi.vreg<128xf32>, !pto.vmi.vreg<1xf32>, + !pto.vmi.mask<128xpred> -> !pto.vmi.vreg<1xf32> + pto.vmi.store %r, %out1[%c0] + : !pto.vmi.vreg<1xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_widen_dense_reduce_multi_consumer( +// ASSIGN-SAME: %arg1: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN-SAME: %arg2: !pto.vmi.vreg<1xf32, #pto.vmi.layout> +// ASSIGN-SAME: %arg3: !pto.vmi.vreg<1xf32, #pto.vmi.layout> +// ASSIGN: %[[A:.*]] = pto.vmi.load +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN: %[[W:.*]] = pto.vmi.extf %[[A]] +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[T1:.*]] = pto.vmi.mulf %[[W]], %arg1 +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[MASK:.*]] = pto.vmi.create_mask +// ASSIGN-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// ASSIGN: %[[T1_DENSE:.*]] = pto.vmi.ensure_layout %[[T1]] +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[R0:.*]] = pto.vmi.reduce_addf %[[T1_DENSE]], %arg2, %[[MASK]] +// ASSIGN-SAME: -> !pto.vmi.vreg<1xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.store %[[R0]] +// ASSIGN: %[[W_DENSE:.*]] = pto.vmi.ensure_layout %[[W]] +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[R:.*]] = pto.vmi.reduce_addf %[[W_DENSE]], %arg3, %[[MASK]] +// ASSIGN-SAME: -> !pto.vmi.vreg<1xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.store %[[R]] + +// LOWER-LABEL: func.func @vmi_layout_assignment_widen_dense_reduce_multi_consumer( +// LOWER: pto.vlds +// LOWER: pto.vcvt +// LOWER: pto.vcvt +// LOWER: pto.vmul +// LOWER: pto.vintlv +// LOWER: pto.vcadd +// LOWER: pto.vadd +// LOWER: pto.vsts +// LOWER: pto.vintlv +// LOWER: pto.vcadd +// LOWER: pto.vadd +// LOWER: pto.vsts +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. 
+// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_widen_f16_store_reduce.pto b/test/lit/vmi/vmi_layout_assignment_widen_f16_store_reduce.pto new file mode 100644 index 0000000000..9bf53802d4 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_widen_f16_store_reduce.pto @@ -0,0 +1,63 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_widen_f16_store_reduce( + %src: !pto.ptr, + %sum: !pto.ptr, + %dense: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %x16 = pto.vmi.load %src[%off] : !pto.ptr -> !pto.vmi.vreg<128xf16> + %x32 = pto.vmi.extf %x16 : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %mask = pto.vmi.create_mask %c128 : index -> !pto.vmi.mask<128xpred> + %sumv = pto.vmi.group_reduce_addf %x32, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<8xf32> + pto.vmi.group_store %sumv, %sum[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xf32>, !pto.ptr + pto.vmi.store %x32, %dense[%off] + : !pto.vmi.vreg<128xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func 
@vmi_layout_assignment_widen_f16_store_reduce( +// ASSIGN: %[[X16:.*]] = pto.vmi.load +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN: %[[X32:.*]] = pto.vmi.extf %[[X16]] +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[MASK0:.*]] = pto.vmi.create_mask +// ASSIGN-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// ASSIGN: %[[MASK:.*]] = pto.vmi.ensure_mask_layout %[[MASK0]] +// ASSIGN-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X32]], %[[MASK]] +// ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_store %[[SUM]] +// ASSIGN: %[[X32_DENSE:.*]] = pto.vmi.ensure_layout %[[X32]] +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.store %[[X32_DENSE]] + +// LOWER-LABEL: func.func @vmi_layout_assignment_widen_f16_store_reduce( +// LOWER: pto.vlds +// LOWER: pto.vcvt +// LOWER: pto.vcvt +// LOWER: pto.vcgadd +// LOWER: pto.vcgadd +// LOWER: pto.vadd +// LOWER: pto.vsts +// LOWER: pto.vintlv +// LOWER: pto.vsts +// LOWER: pto.vsts +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_factor_invalid.pto b/test/lit/vmi/vmi_layout_factor_invalid.pto new file mode 100644 index 0000000000..b908700333 --- /dev/null +++ b/test/lit/vmi/vmi_layout_factor_invalid.pto @@ -0,0 +1,18 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_factor_invalid( + %arg0: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) { + return + } +} + +// CHECK: #pto.vmi.layout expected factor to be 2 or 4 diff --git a/test/lit/vmi/vmi_layout_fold_deint4.pto b/test/lit/vmi/vmi_layout_fold_deint4.pto new file mode 100644 index 0000000000..6cd26b29e4 --- /dev/null +++ b/test/lit/vmi/vmi_layout_fold_deint4.pto @@ -0,0 +1,90 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-fold | FileCheck %s --check-prefix=FOLD +// RUN: pto-test-opt %s -vmi-layout-fold -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_fold_store_deint4( + %value: !pto.vmi.vreg<256xf32, #pto.vmi.layout>, + %dst: !pto.ptr, + %offset: index) { + %value_c = pto.vmi.ensure_layout %value + : !pto.vmi.vreg<256xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + pto.vmi.store %value_c, %dst[%offset] + : !pto.vmi.vreg<256xf32, #pto.vmi.layout>, + !pto.ptr + return + } + + func.func @vmi_layout_fold_masked_store_deint4( + %value: !pto.vmi.vreg<256xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<256xb32, #pto.vmi.layout>, + %dst: !pto.ptr, + %offset: index) { + %value_c = pto.vmi.ensure_layout %value + : !pto.vmi.vreg<256xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + %mask_c = pto.vmi.ensure_mask_layout %mask + : !pto.vmi.mask<256xb32, #pto.vmi.layout> + -> !pto.vmi.mask<256xb32, #pto.vmi.layout> + pto.vmi.masked_store %value_c, %dst[%offset], %mask_c + : !pto.vmi.vreg<256xf32, #pto.vmi.layout>, + !pto.ptr, + !pto.vmi.mask<256xb32, #pto.vmi.layout> + return + } +} + +// FOLD-LABEL: func.func @vmi_layout_fold_store_deint4( +// FOLD-SAME: %[[VALUE:.*]]: !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// FOLD-NOT: pto.vmi.ensure_layout +// FOLD: pto.vmi.store %[[VALUE]] +// FOLD-SAME: !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// FOLD-NOT: pto.vmi.ensure_layout +// FOLD: return + +// FOLD-LABEL: func.func @vmi_layout_fold_masked_store_deint4( +// FOLD-SAME: %[[VALUE:.*]]: !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// FOLD-SAME: %[[MASK:.*]]: !pto.vmi.mask<256xb32, #pto.vmi.layout> +// FOLD-NOT: pto.vmi.ensure_layout +// FOLD-NOT: pto.vmi.ensure_mask_layout +// FOLD: pto.vmi.masked_store %[[VALUE]] +// FOLD-SAME: %[[MASK]] +// FOLD-SAME: !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// FOLD-SAME: !pto.vmi.mask<256xb32, #pto.vmi.layout> +// FOLD-NOT: pto.vmi.ensure_layout +// FOLD-NOT: 
pto.vmi.ensure_mask_layout +// FOLD: return + +// LOWER-LABEL: func.func @vmi_layout_fold_store_deint4( +// LOWER: pto.vintlv +// LOWER: pto.vintlv +// LOWER: pto.vintlv +// LOWER: pto.vintlv +// LOWER: pto.vsts +// LOWER: pto.vsts +// LOWER: pto.vsts +// LOWER: pto.vsts + +// LOWER-LABEL: func.func @vmi_layout_fold_masked_store_deint4( +// LOWER: pto.vintlv +// LOWER: pto.vintlv +// LOWER: pto.vintlv +// LOWER: pto.vintlv +// LOWER: pto.pintlv_b32 +// LOWER: pto.pintlv_b32 +// LOWER: pto.pintlv_b32 +// LOWER: pto.pintlv_b32 +// LOWER: pto.vsts +// LOWER: pto.vsts +// LOWER: pto.vsts +// LOWER: pto.vsts +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_layout_fold_load.pto b/test/lit/vmi/vmi_layout_fold_load.pto new file mode 100644 index 0000000000..804c522df1 --- /dev/null +++ b/test/lit/vmi/vmi_layout_fold_load.pto @@ -0,0 +1,62 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-fold | FileCheck %s --check-prefix=FOLD +// RUN: pto-test-opt %s -vmi-layout-fold -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_fold_load_all_ensures( + %src: !pto.ptr, %off: index) + -> (!pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout>) { + %load = pto.vmi.load %src[%off] + : !pto.ptr + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %split0 = pto.vmi.ensure_layout %load + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %split1 = pto.vmi.ensure_layout %load + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + return %split0, %split1 + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + } + + func.func @vmi_layout_fold_load_keeps_mixed_use( + %src: !pto.ptr, %off: index) + -> (!pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout>) { + %load = pto.vmi.load %src[%off] + : !pto.ptr + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %split = pto.vmi.ensure_layout %load + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + return %load, %split + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + } +} + +// FOLD-LABEL: func.func @vmi_layout_fold_load_all_ensures( +// FOLD: %[[LOAD:.*]] = pto.vmi.load +// FOLD-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// FOLD-NOT: pto.vmi.ensure_layout +// FOLD: return %[[LOAD]], %[[LOAD]] + +// FOLD-LABEL: func.func @vmi_layout_fold_load_keeps_mixed_use( +// FOLD: %[[MIXED_LOAD:.*]] = pto.vmi.load +// FOLD-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// FOLD: %[[SPLIT:.*]] = pto.vmi.ensure_layout %[[MIXED_LOAD]] +// FOLD: return %[[MIXED_LOAD]], %[[SPLIT]] + +// LOWER-LABEL: func.func @vmi_layout_fold_load_all_ensures( +// LOWER: %[[LOW:.*]], %[[HIGH:.*]] = pto.vldsx2 +// LOWER-SAME: "DINTLV_B32" +// LOWER: return 
%[[LOW]], %[[HIGH]], %[[LOW]], %[[HIGH]] diff --git a/test/lit/vmi/vmi_layout_fold_masked_store.pto b/test/lit/vmi/vmi_layout_fold_masked_store.pto new file mode 100644 index 0000000000..4fc8cbee83 --- /dev/null +++ b/test/lit/vmi/vmi_layout_fold_masked_store.pto @@ -0,0 +1,57 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-fold | FileCheck %s --check-prefix=FOLD +// RUN: pto-test-opt %s -vmi-layout-fold -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_fold_masked_store( + %value: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<128xb32, #pto.vmi.layout>, + %dst: !pto.ptr, + %offset: index) { + %value_c = pto.vmi.ensure_layout %value + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %mask_c = pto.vmi.ensure_mask_layout %mask + : !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + pto.vmi.masked_store %value_c, %dst[%offset], %mask_c + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.ptr, + !pto.vmi.mask<128xb32, #pto.vmi.layout> + return + } +} + +// FOLD-LABEL: func.func @vmi_layout_fold_masked_store( +// FOLD-SAME: %[[VALUE:.*]]: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// FOLD-SAME: %[[MASK:.*]]: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// FOLD-NOT: pto.vmi.ensure_layout +// FOLD-NOT: 
pto.vmi.ensure_mask_layout +// FOLD: pto.vmi.masked_store %[[VALUE]] +// FOLD-SAME: %[[MASK]] +// FOLD-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// FOLD-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// FOLD-NOT: pto.vmi.ensure_layout +// FOLD-NOT: pto.vmi.ensure_mask_layout +// FOLD: return + +// LOWER-LABEL: func.func @vmi_layout_fold_masked_store( +// LOWER-SAME: %[[V0:[^,]+]]: !pto.vreg<64xf32> +// LOWER-SAME: %[[V1:[^,]+]]: !pto.vreg<64xf32> +// LOWER-SAME: %[[M0:[^,]+]]: !pto.mask +// LOWER-SAME: %[[M1:[^,]+]]: !pto.mask +// LOWER: %[[LOW:.*]], %[[HIGH:.*]] = pto.vintlv %[[V0]], %[[V1]] +// LOWER: %[[ML:.*]], %[[MH:.*]] = pto.pintlv_b32 %[[M0]], %[[M1]] +// LOWER: pto.vsts %[[LOW]] +// LOWER-SAME: %[[ML]] +// LOWER: pto.vsts %[[HIGH]] +// LOWER-SAME: %[[MH]] +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_layout_fold_store.pto b/test/lit/vmi/vmi_layout_fold_store.pto new file mode 100644 index 0000000000..484b1c636b --- /dev/null +++ b/test/lit/vmi/vmi_layout_fold_store.pto @@ -0,0 +1,64 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-layout-fold | FileCheck %s --check-prefix=FOLD +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-layout-fold -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_fold_store( + %src: !pto.vmi.vreg<128xf16>, + %scale: f32, + %out1: !pto.ptr, + %out2: !pto.ptr, + %offset: index) { + %scale_v = pto.vmi.broadcast %scale + : f32 -> !pto.vmi.vreg<128xf32> + %wide = pto.vmi.extf %src + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %prod = pto.vmi.mulf %wide, %scale_v + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + pto.vmi.store %prod, %out1[%offset] + : !pto.vmi.vreg<128xf32>, !pto.ptr + pto.vmi.store %wide, %out2[%offset] + : !pto.vmi.vreg<128xf32>, !pto.ptr + return + } + +} + +// FOLD-LABEL: func.func @vmi_layout_fold_store( +// FOLD-SAME: %[[SRC:.*]]: !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// FOLD: %[[SCALE:.*]] = pto.vmi.broadcast +// FOLD-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// FOLD: %[[WIDE:.*]] = pto.vmi.extf %[[SRC]] +// FOLD-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// FOLD: %[[PROD:.*]] = pto.vmi.mulf %[[WIDE]], %[[SCALE]] +// FOLD-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// FOLD-NOT: pto.vmi.ensure_layout +// FOLD: pto.vmi.store %[[PROD]] +// FOLD-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// FOLD-NOT: pto.vmi.ensure_layout +// FOLD: pto.vmi.store %[[WIDE]] +// FOLD-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// FOLD-NOT: pto.vmi.ensure_layout +// FOLD: return + +// LOWER-LABEL: func.func @vmi_layout_fold_store( +// LOWER: %[[SCALE0:.*]] = pto.vdup +// LOWER: %[[SCALE1:.*]] = pto.vdup +// LOWER: %[[WIDE0:.*]] = pto.vcvt +// LOWER: %[[WIDE1:.*]] = pto.vcvt +// LOWER: %[[PROD0:.*]] = pto.vmul %[[WIDE0]], %[[SCALE0]] +// LOWER: %[[PROD1:.*]] = pto.vmul %[[WIDE1]], %[[SCALE1]] +// LOWER-NOT: pto.vintlv +// LOWER: pto.vstsx2 %[[PROD0]], %[[PROD1]] +// LOWER: pto.vstsx2 %[[WIDE0]], %[[WIDE1]] +// LOWER-NOT: 
pto.vmi. +// LOWER-NOT: !pto.vmi. + diff --git a/test/lit/vmi/vmi_layout_gate_bitcast_group_slots.pto b/test/lit/vmi/vmi_layout_gate_bitcast_group_slots.pto new file mode 100644 index 0000000000..8a69c96385 --- /dev/null +++ b/test/lit/vmi/vmi_layout_gate_bitcast_group_slots.pto @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -pto-validate-vmi-layout-ir | FileCheck %s + +module { + func.func @vmi_layout_gate_bitcast_group_slots( + %source: !pto.vmi.vreg<8xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<8xi32, #pto.vmi.layout> { + %out = pto.vmi.bitcast %source + : !pto.vmi.vreg<8xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<8xi32, #pto.vmi.layout> + return %out : !pto.vmi.vreg<8xi32, #pto.vmi.layout> + } +} + +// CHECK-LABEL: func.func @vmi_layout_gate_bitcast_group_slots( +// CHECK: pto.vmi.bitcast diff --git a/test/lit/vmi/vmi_layout_gate_bitcast_support_invalid.pto b/test/lit/vmi/vmi_layout_gate_bitcast_support_invalid.pto new file mode 100644 index 0000000000..2acec47cd2 --- /dev/null +++ b/test/lit/vmi/vmi_layout_gate_bitcast_support_invalid.pto @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. 
You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_gate_bitcast_support_invalid( + %source: !pto.vmi.vreg<65xf32, #pto.vmi.layout>) { + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.bitcast has no registered layout support + // CHECK-SAME: requires matching logical bit footprint in every physical chunk + // CHECK: note: see current operation: %{{.*}} = "pto.vmi.bitcast" + %out = pto.vmi.bitcast %source + : !pto.vmi.vreg<65xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<130xi16, #pto.vmi.layout> + return + } +} diff --git a/test/lit/vmi/vmi_layout_gate_extf_support_invalid.pto b/test/lit/vmi/vmi_layout_gate_extf_support_invalid.pto new file mode 100644 index 0000000000..fbabafa3e3 --- /dev/null +++ b/test/lit/vmi/vmi_layout_gate_extf_support_invalid.pto @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_gate_extf_support_invalid( + %source: !pto.vmi.vreg<128xf16, #pto.vmi.layout>) { + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.extf has no registered layout support + // CHECK-SAME: requires contiguous or deinterleaved source layout and deinterleaved f32 result layout with block_elems=1 + // CHECK: note: see current operation: %{{.*}} = "pto.vmi.extf" + %out = pto.vmi.extf %source + : !pto.vmi.vreg<128xf16, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + return + } +} diff --git a/test/lit/vmi/vmi_layout_gate_group_broadcast_support_invalid.pto b/test/lit/vmi/vmi_layout_gate_group_broadcast_support_invalid.pto new file mode 100644 index 0000000000..b2124db3c2 --- /dev/null +++ b/test/lit/vmi/vmi_layout_gate_group_broadcast_support_invalid.pto @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_gate_group_broadcast_support_invalid( + %source: !pto.vmi.vreg<8xf32, #pto.vmi.layout>) { + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_broadcast has no registered layout support + // CHECK-SAME: supports only slots=8 or slots=1 group_broadcast source layouts + // CHECK: note: see current operation: %{{.*}} = "pto.vmi.group_broadcast" + %out = pto.vmi.group_broadcast %source {num_groups = 8} + : !pto.vmi.vreg<8xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + return + } +} diff --git a/test/lit/vmi/vmi_layout_gate_group_load_support_invalid.pto b/test/lit/vmi/vmi_layout_gate_group_load_support_invalid.pto new file mode 100644 index 0000000000..a14ff20a0b --- /dev/null +++ b/test/lit/vmi/vmi_layout_gate_group_load_support_invalid.pto @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_gate_group_load_support_invalid( + %src: !pto.ptr, %off: index, %stride: index) { + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_load has no registered block8 layout support + // CHECK-SAME: block8 strided group_load requires constant positive row_stride divisible by 8 f32 elements + // CHECK: note: see current operation: %{{.*}} = "pto.vmi.group_load" + %out = pto.vmi.group_load %src[%off], %stride + {num_groups = 8} + : !pto.ptr + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + return + } +} diff --git a/test/lit/vmi/vmi_layout_gate_group_reduce_slots1_support_invalid.pto b/test/lit/vmi/vmi_layout_gate_group_reduce_slots1_support_invalid.pto new file mode 100644 index 0000000000..9bebc83b97 --- /dev/null +++ b/test/lit/vmi/vmi_layout_gate_group_reduce_slots1_support_invalid.pto @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_gate_group_reduce_slots1_support_invalid( + %source: !pto.vmi.vreg<256xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<256xb32, #pto.vmi.layout>) { + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_reduce_addf has no registered group_slots layout support + // CHECK-SAME: stable group_reduce_add slots=1 support group sizes that are multiples of one physical chunk + // CHECK: note: see current operation: %{{.*}} = "pto.vmi.group_reduce_addf" + %out = pto.vmi.group_reduce_addf %source, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32, #pto.vmi.layout>, + !pto.vmi.mask<256xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> + return + } +} diff --git a/test/lit/vmi/vmi_layout_gate_group_reduce_support_invalid.pto b/test/lit/vmi/vmi_layout_gate_group_reduce_support_invalid.pto new file mode 100644 index 0000000000..172600b145 --- /dev/null +++ b/test/lit/vmi/vmi_layout_gate_group_reduce_support_invalid.pto @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_gate_group_reduce_support_invalid( + %source: !pto.vmi.vreg<96xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<96xb32, #pto.vmi.layout>) { + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_reduce_addf has no registered group_slots layout support + // CHECK-SAME: stable group_reduce_add slots=8 support group sizes VLaneElems, 2*VLaneElems, or 4*VLaneElems + // CHECK: note: see current operation: %{{.*}} = "pto.vmi.group_reduce_addf" + %out = pto.vmi.group_reduce_addf %source, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<96xf32, #pto.vmi.layout>, + !pto.vmi.mask<96xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> + return + } +} diff --git a/test/lit/vmi/vmi_layout_gate_group_slot_load_support_invalid.pto b/test/lit/vmi/vmi_layout_gate_group_slot_load_support_invalid.pto new file mode 100644 index 0000000000..9c34f6a261 --- /dev/null +++ b/test/lit/vmi/vmi_layout_gate_group_slot_load_support_invalid.pto @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_gate_group_slot_load_support_invalid( + %src: !pto.ptr, %off: index, %stride: index) { + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_slot_load has no registered layout support + // CHECK-SAME: slots=8 group_slot_load requires constant unit source_group_stride + // CHECK: note: see current operation: %{{.*}} = "pto.vmi.group_slot_load" + %out = pto.vmi.group_slot_load %src[%off], %stride + {num_groups = 8} + : !pto.ptr + -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> + return + } +} diff --git a/test/lit/vmi/vmi_layout_gate_group_slots_unsupported_slots_invalid.pto b/test/lit/vmi/vmi_layout_gate_group_slots_unsupported_slots_invalid.pto new file mode 100644 index 0000000000..b6f309f693 --- /dev/null +++ b/test/lit/vmi/vmi_layout_gate_group_slots_unsupported_slots_invalid.pto @@ -0,0 +1,40 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -split-input-file -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_gate_group_store_slots2_invalid( + %value: !pto.vmi.vreg<8xf32, #pto.vmi.layout>, + %dst: !pto.ptr, %off: index, %row_stride: index) { + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_store has no registered group_slots layout support + // CHECK-SAME: group_slots group_store currently supports only slots=1 or unit-stride slots=8 + pto.vmi.group_store %value, %dst[%off], %row_stride + {num_groups = 8} + : !pto.vmi.vreg<8xf32, #pto.vmi.layout>, + !pto.ptr + return + } +} + +// ----- + +module { + func.func @vmi_layout_gate_group_reduce_slots2_invalid( + %source: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<128xb32, #pto.vmi.layout>) { + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_reduce_addf has no registered group_slots layout support + // CHECK-SAME: stable group_reduce_add layout support currently requires result layout slots=8 or slots=1 + %out = pto.vmi.group_reduce_addf %source, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> + return + } +} diff --git a/test/lit/vmi/vmi_layout_gate_group_store_support_invalid.pto b/test/lit/vmi/vmi_layout_gate_group_store_support_invalid.pto new file mode 100644 index 0000000000..2676b55a1f --- /dev/null +++ b/test/lit/vmi/vmi_layout_gate_group_store_support_invalid.pto @@ -0,0 +1,24 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_gate_group_store_support_invalid( + %value: !pto.vmi.vreg<8xf32, #pto.vmi.layout>, + %dst: !pto.ptr, %off: index, %row_stride: index) { + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_store has no registered group_slots layout support + // CHECK-SAME: slots=8 group_store currently requires constant unit row_stride + // CHECK: note: see current operation: "pto.vmi.group_store" + pto.vmi.group_store %value, %dst[%off], %row_stride + {num_groups = 8} + : !pto.vmi.vreg<8xf32, #pto.vmi.layout>, + !pto.ptr + return + } +} diff --git a/test/lit/vmi/vmi_layout_gate_helper_materialization_shape_invalid.pto b/test/lit/vmi/vmi_layout_gate_helper_materialization_shape_invalid.pto new file mode 100644 index 0000000000..4aa1f30cbb --- /dev/null +++ b/test/lit/vmi/vmi_layout_gate_helper_materialization_shape_invalid.pto @@ -0,0 +1,35 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -split-input-file -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_gate_ensure_layout_shape_invalid( + %value: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) { + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.ensure_layout has no registered materialization support + // CHECK-SAME: requires source and result to have the same physical arity + %dense = pto.vmi.ensure_layout %value + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + return + } +} + +// ----- + +module { + func.func @vmi_layout_gate_ensure_mask_layout_shape_invalid( + %mask: !pto.vmi.mask<128xb32, #pto.vmi.layout>) { + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.ensure_mask_layout has no registered materialization support + // CHECK-SAME: requires source and result to have the same physical arity + %dense = pto.vmi.ensure_mask_layout %mask + : !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + return + } +} diff --git a/test/lit/vmi/vmi_layout_gate_helper_support_invalid.pto b/test/lit/vmi/vmi_layout_gate_helper_support_invalid.pto new file mode 100644 index 0000000000..1cf2548b79 --- /dev/null +++ b/test/lit/vmi/vmi_layout_gate_helper_support_invalid.pto @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_gate_helper_support_invalid( + %value: !pto.vmi.vreg<8xf32, #pto.vmi.layout>) { + %bad = pto.vmi.ensure_layout %value + : !pto.vmi.vreg<8xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.ensure_layout has no registered materialization support +// CHECK-SAME: unsupported source/result layout pair diff --git a/test/lit/vmi/vmi_layout_gate_store_support_invalid.pto b/test/lit/vmi/vmi_layout_gate_store_support_invalid.pto new file mode 100644 index 0000000000..14e874beac --- /dev/null +++ b/test/lit/vmi/vmi_layout_gate_store_support_invalid.pto @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -split-input-file -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_gate_store_deint_tail_invalid( + %value: !pto.vmi.vreg<129xf32, #pto.vmi.layout>, + %dst: !pto.ptr, %offset: index) { + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.store has no registered contiguous-memory layout support + // CHECK-SAME: requires arity divisible by layout factor + pto.vmi.store %value, %dst[%offset] + : !pto.vmi.vreg<129xf32, #pto.vmi.layout>, + !pto.ptr + return + } +} diff --git a/test/lit/vmi/vmi_layout_gate_support.pto b/test/lit/vmi/vmi_layout_gate_support.pto new file mode 100644 index 0000000000..9a48ea0721 --- /dev/null +++ b/test/lit/vmi/vmi_layout_gate_support.pto @@ -0,0 +1,24 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -pto-validate-vmi-layout-ir | FileCheck %s + +module { + func.func @vmi_layout_gate_support( + %source: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) { + %out = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.mask<64xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> + return + } +} + +// CHECK-LABEL: func.func @vmi_layout_gate_support( +// CHECK: pto.vmi.group_reduce_addf diff --git a/test/lit/vmi/vmi_layout_gate_surface_invalid.pto b/test/lit/vmi/vmi_layout_gate_surface_invalid.pto new file mode 100644 index 0000000000..1b1bfdfb52 --- /dev/null +++ b/test/lit/vmi/vmi_layout_gate_surface_invalid.pto @@ -0,0 +1,18 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_gate_surface_invalid(%a: !pto.vmi.vreg<128xf32>) { + return + } +} + +// CHECK: VMI-PASS-INVARIANT +// CHECK: layout-assigned VMI IR requires !pto.vmi.vreg with layout diff --git a/test/lit/vmi/vmi_layout_gate_surface_mask_invalid.pto b/test/lit/vmi/vmi_layout_gate_surface_mask_invalid.pto new file mode 100644 index 0000000000..79425740d8 --- /dev/null +++ b/test/lit/vmi/vmi_layout_gate_surface_mask_invalid.pto @@ -0,0 +1,20 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_gate_surface_mask_invalid( + %m: !pto.vmi.mask<128xpred>) { + return + } +} + +// CHECK: VMI-PASS-INVARIANT +// CHECK: layout-assigned VMI IR requires !pto.vmi.vreg with layout +// CHECK-SAME: !pto.vmi.mask with b8/b16/b32 granularity plus layout diff --git a/test/lit/vmi/vmi_layout_gate_truncf_support_invalid.pto b/test/lit/vmi/vmi_layout_gate_truncf_support_invalid.pto new file mode 100644 index 0000000000..385aa56191 --- /dev/null +++ b/test/lit/vmi/vmi_layout_gate_truncf_support_invalid.pto @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_gate_truncf_support_invalid( + %source: !pto.vmi.vreg<8xf32, #pto.vmi.layout>) { + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.truncf has no registered layout support + // CHECK-SAME: group-slot truncf requires matching group_slots(num_groups=G, slots=1) + // CHECK: note: see current operation: %{{.*}} = "pto.vmi.truncf" + %out = pto.vmi.truncf %source + : !pto.vmi.vreg<8xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<8xf16, #pto.vmi.layout> + return + } +} diff --git a/test/lit/vmi/vmi_layout_gate_type_attr_nested_physical_invalid.pto b/test/lit/vmi/vmi_layout_gate_type_attr_nested_physical_invalid.pto new file mode 100644 index 0000000000..7494367606 --- /dev/null +++ b/test/lit/vmi/vmi_layout_gate_type_attr_nested_physical_invalid.pto @@ -0,0 +1,17 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s + +module attributes { + pto.hidden_physical_state = [{nested = !pto.vreg<64xf32>}] +} { +} + +// CHECK: VMI-PASS-INVARIANT +// CHECK: VMI or physical VPTO type appears in a non-signature attribute diff --git a/test/lit/vmi/vmi_layout_gate_type_attr_surface_invalid.pto b/test/lit/vmi/vmi_layout_gate_type_attr_surface_invalid.pto new file mode 100644 index 0000000000..78549ed3e6 --- /dev/null +++ b/test/lit/vmi/vmi_layout_gate_type_attr_surface_invalid.pto @@ -0,0 +1,17 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s + +module attributes { + pto.hidden_vmi_type = !pto.vmi.mask<128xpred> +} { +} + +// CHECK: VMI-PASS-INVARIANT +// CHECK: VMI or physical VPTO type appears in a non-signature attribute diff --git a/test/lit/vmi/vmi_layout_gate_valid.pto b/test/lit/vmi/vmi_layout_gate_valid.pto new file mode 100644 index 0000000000..ebc5778f34 --- /dev/null +++ b/test/lit/vmi/vmi_layout_gate_valid.pto @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -pto-validate-vmi-layout-ir + +module { + func.func @vmi_layout_gate_valid( + %m: !pto.vmi.mask<128xb32, #pto.vmi.layout>, + %a: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %b: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) { + %sel = pto.vmi.select %m, %a, %b + : !pto.vmi.mask<128xb32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + return + } +} diff --git a/test/lit/vmi/vmi_layout_group_slots_invalid.pto b/test/lit/vmi/vmi_layout_group_slots_invalid.pto new file mode 100644 index 0000000000..0f3717ee4b --- /dev/null +++ b/test/lit/vmi/vmi_layout_group_slots_invalid.pto @@ -0,0 +1,18 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_group_slots_invalid( + %arg0: !pto.vmi.vreg<10xf32, #pto.vmi.layout>) { + return + } +} + +// CHECK: #pto.vmi.layout requires slots to be omitted or positive diff --git a/test/lit/vmi/vmi_layout_rematerialize_data.pto b/test/lit/vmi/vmi_layout_rematerialize_data.pto new file mode 100644 index 0000000000..22a03d88a5 --- /dev/null +++ b/test/lit/vmi/vmi_layout_rematerialize_data.pto @@ -0,0 +1,66 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-rematerialize | FileCheck %s + +module { + func.func @vmi_layout_rematerialize_data( + %scalar: f32, + %base: f32) + -> (!pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout>) { + %broadcast = pto.vmi.broadcast %scalar + : f32 -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %broadcast_deint = pto.vmi.ensure_layout %broadcast + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + + %iota = pto.vmi.iota %base + : f32 -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %iota_deint = pto.vmi.ensure_layout %iota + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + + %constant = "pto.vmi.constant"() { + value = dense<1.000000e+00> : tensor<128xf32> + } : () -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %constant_deint = pto.vmi.ensure_layout %constant + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + + return %broadcast_deint, %iota_deint, %constant_deint + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + } + + func.func @vmi_layout_rematerialize_keeps_load_helper( + %src: !pto.ptr, %off: index) + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> { + %load = pto.vmi.load %src[%off] + : !pto.ptr -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %load_deint = pto.vmi.ensure_layout %load + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + return %load_deint + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + } +} + +// CHECK-LABEL: func.func @vmi_layout_rematerialize_data( +// CHECK: %[[BCAST:.*]] = pto.vmi.broadcast %arg0 : f32 -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: %[[IOTA:.*]] = pto.vmi.iota %arg1 : f32 -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: %[[CONST:.*]] = "pto.vmi.constant"(){{.*}}dense<1.000000e+00> : 
tensor<128xf32>{{.*}}!pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK-NOT: pto.vmi.ensure_layout +// CHECK: return %[[BCAST]], %[[IOTA]], %[[CONST]] + +// CHECK-LABEL: func.func @vmi_layout_rematerialize_keeps_load_helper( +// CHECK: %[[LOAD:.*]] = pto.vmi.load +// CHECK: %[[LOAD_DEINT:.*]] = pto.vmi.ensure_layout %[[LOAD]] +// CHECK: return %[[LOAD_DEINT]] diff --git a/test/lit/vmi/vmi_layout_rematerialize_mask.pto b/test/lit/vmi/vmi_layout_rematerialize_mask.pto new file mode 100644 index 0000000000..6c3bb60053 --- /dev/null +++ b/test/lit/vmi/vmi_layout_rematerialize_mask.pto @@ -0,0 +1,55 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-rematerialize | FileCheck %s + +module { + func.func @vmi_layout_rematerialize_mask(%active: index) + -> (!pto.vmi.mask<128xb16, #pto.vmi.layout>, + !pto.vmi.mask<128xb32, #pto.vmi.layout>, + !pto.vmi.mask<128xb32, #pto.vmi.layout>, + !pto.vmi.mask<128xb16, #pto.vmi.layout>) { + %mask = pto.vmi.create_mask %active + : index -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %mask_b16 = pto.vmi.ensure_mask_granularity %mask + : !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb16, #pto.vmi.layout> + %mask_deint = pto.vmi.ensure_mask_layout %mask + : !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + + %group_mask = pto.vmi.create_group_mask %active + {num_groups = 8, group_size = 16} + : index -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %group_mask_deint = pto.vmi.ensure_mask_layout %group_mask + : !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + + %constant_mask = "pto.vmi.constant_mask"() { + value = dense : tensor<128xi1> + } : () -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %constant_mask_b16 = pto.vmi.ensure_mask_granularity %constant_mask + : !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb16, #pto.vmi.layout> + + return %mask_b16, %mask_deint, %group_mask_deint, %constant_mask_b16 + : !pto.vmi.mask<128xb16, #pto.vmi.layout>, + !pto.vmi.mask<128xb32, #pto.vmi.layout>, + !pto.vmi.mask<128xb32, #pto.vmi.layout>, + !pto.vmi.mask<128xb16, #pto.vmi.layout> + } +} + +// CHECK-LABEL: func.func @vmi_layout_rematerialize_mask( +// CHECK: %[[M16:.*]] = pto.vmi.create_mask %arg0 : index -> !pto.vmi.mask<128xb16, #pto.vmi.layout> +// CHECK: %[[MDEINT:.*]] = pto.vmi.create_mask %arg0 : index -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK: %[[GMDEINT:.*]] = pto.vmi.create_group_mask %arg0{{.*}}group_size = 16{{.*}}num_groups = 8{{.*}}index -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK: %[[CM16:.*]] = 
"pto.vmi.constant_mask"(){{.*}}dense : tensor<128xi1>{{.*}}!pto.vmi.mask<128xb16, #pto.vmi.layout> +// CHECK-NOT: pto.vmi.ensure_mask_layout +// CHECK-NOT: pto.vmi.ensure_mask_granularity +// CHECK: return %[[M16]], %[[MDEINT]], %[[GMDEINT]], %[[CM16]] diff --git a/test/lit/vmi/vmi_layout_rematerialize_relation.pto b/test/lit/vmi/vmi_layout_rematerialize_relation.pto new file mode 100644 index 0000000000..b9174c2527 --- /dev/null +++ b/test/lit/vmi/vmi_layout_rematerialize_relation.pto @@ -0,0 +1,151 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-rematerialize -canonicalize | FileCheck %s --check-prefix=REMAT +// RUN: pto-test-opt %s -vmi-layout-rematerialize -canonicalize -vmi-layout-fold -canonicalize -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_rematerialize_direct_ext( + %src: !pto.ptr, %off: index) + -> !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> { + %x16 = pto.vmi.load %src[%off] + : !pto.ptr + -> !pto.vmi.vreg<256xf16, #pto.vmi.layout> + %x32_d2 = pto.vmi.extf %x16 + : !pto.vmi.vreg<256xf16, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + %x32_d4 = pto.vmi.ensure_layout %x32_d2 + : !pto.vmi.vreg<256xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + %y = pto.vmi.truncf %x32_d4 + : !pto.vmi.vreg<256xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> + return %y : !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> + } + + func.func @vmi_layout_rematerialize_mulf_chain( + %lhs: !pto.ptr, %rhs: !pto.ptr, %off: index) + -> !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> { + %lhs16 = pto.vmi.load %lhs[%off] + : !pto.ptr + -> !pto.vmi.vreg<256xf16, #pto.vmi.layout> + %rhs16 = pto.vmi.load %rhs[%off] + : !pto.ptr + -> !pto.vmi.vreg<256xf16, #pto.vmi.layout> + %lhs32_d2 = pto.vmi.extf %lhs16 + : !pto.vmi.vreg<256xf16, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + %rhs32_d2 = pto.vmi.extf %rhs16 + : !pto.vmi.vreg<256xf16, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + %mul_d2 = pto.vmi.mulf %lhs32_d2, %rhs32_d2 + : !pto.vmi.vreg<256xf32, #pto.vmi.layout>, + !pto.vmi.vreg<256xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + %mul_d4 = pto.vmi.ensure_layout %mul_d2 + : !pto.vmi.vreg<256xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + %y = pto.vmi.truncf %mul_d4 + : !pto.vmi.vreg<256xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> + return %y : !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> + } 
+ + func.func @vmi_layout_rematerialize_ext_multi_consumer( + %src: !pto.ptr, %off: index) + -> (!pto.vmi.vreg<256xf16, #pto.vmi.layout>, + !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout>) { + %x16 = pto.vmi.load %src[%off] + : !pto.ptr + -> !pto.vmi.vreg<256xf16, #pto.vmi.layout> + %x32_d2 = pto.vmi.extf %x16 + : !pto.vmi.vreg<256xf16, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + %y16 = pto.vmi.truncf %x32_d2 + : !pto.vmi.vreg<256xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf16, #pto.vmi.layout> + %x32_d4 = pto.vmi.ensure_layout %x32_d2 + : !pto.vmi.vreg<256xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + %y8 = pto.vmi.truncf %x32_d4 + : !pto.vmi.vreg<256xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> + return %y16, %y8 + : !pto.vmi.vreg<256xf16, #pto.vmi.layout>, + !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> + } + + func.func @vmi_layout_rematerialize_trunci_source_ensure( + %x32_d4: !pto.vmi.vreg<256xi32, #pto.vmi.layout>) + -> !pto.vmi.vreg<256xui16, #pto.vmi.layout> { + %x32_d2 = pto.vmi.ensure_layout %x32_d4 + : !pto.vmi.vreg<256xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xi32, #pto.vmi.layout> + %y16 = pto.vmi.trunci %x32_d2 + : !pto.vmi.vreg<256xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xui16, #pto.vmi.layout> + return %y16 : !pto.vmi.vreg<256xui16, #pto.vmi.layout> + } +} + +// REMAT-LABEL: func.func @vmi_layout_rematerialize_direct_ext( +// REMAT: %[[X16:.*]] = pto.vmi.load +// REMAT: %[[X16_D2:.*]] = pto.vmi.ensure_layout %[[X16]] +// REMAT-SAME: !pto.vmi.vreg<256xf16, #pto.vmi.layout> -> !pto.vmi.vreg<256xf16, #pto.vmi.layout> +// REMAT: %[[X32_D4:.*]] = pto.vmi.extf %[[X16_D2]] +// REMAT-SAME: !pto.vmi.vreg<256xf16, #pto.vmi.layout> -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// REMAT: pto.vmi.truncf %[[X32_D4]] +// REMAT-NOT: !pto.vmi.vreg<256xf32, #pto.vmi.layout> -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +// REMAT-LABEL: func.func @vmi_layout_rematerialize_mulf_chain( +// 
REMAT-DAG: %[[LHS_D2:.*]] = pto.vmi.ensure_layout +// REMAT-DAG: %[[RHS_D2:.*]] = pto.vmi.ensure_layout +// REMAT-DAG: %[[LHS_D4:.*]] = pto.vmi.extf %[[LHS_D2]] +// REMAT-DAG: %[[RHS_D4:.*]] = pto.vmi.extf %[[RHS_D2]] +// REMAT: %[[MUL_D4:.*]] = pto.vmi.mulf %[[LHS_D4]], %[[RHS_D4]] +// REMAT-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// REMAT: pto.vmi.truncf %[[MUL_D4]] +// REMAT-NOT: !pto.vmi.vreg<256xf32, #pto.vmi.layout> -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +// REMAT-LABEL: func.func @vmi_layout_rematerialize_ext_multi_consumer( +// REMAT: %[[X16:.*]] = pto.vmi.load +// REMAT: %[[X32_D2:.*]] = pto.vmi.extf %[[X16]] +// REMAT-SAME: !pto.vmi.vreg<256xf16, #pto.vmi.layout> -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// REMAT: pto.vmi.truncf %[[X32_D2]] +// REMAT: %[[X16_D2:.*]] = pto.vmi.ensure_layout %[[X16]] +// REMAT: %[[X32_D4:.*]] = pto.vmi.extf %[[X16_D2]] +// REMAT-SAME: !pto.vmi.vreg<256xf16, #pto.vmi.layout> -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// REMAT: pto.vmi.truncf %[[X32_D4]] +// REMAT-NOT: !pto.vmi.vreg<256xf32, #pto.vmi.layout> -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +// REMAT-LABEL: func.func @vmi_layout_rematerialize_trunci_source_ensure( +// REMAT-SAME: %[[X32_D4:.*]]: !pto.vmi.vreg<256xi32, #pto.vmi.layout> +// REMAT: %[[Y16_D2:.*]] = pto.vmi.trunci %[[X32_D4]] +// REMAT-SAME: !pto.vmi.vreg<256xi32, #pto.vmi.layout> -> !pto.vmi.vreg<256xui16, #pto.vmi.layout> +// REMAT: pto.vmi.ensure_layout %[[Y16_D2]] +// REMAT-SAME: !pto.vmi.vreg<256xui16, #pto.vmi.layout> -> !pto.vmi.vreg<256xui16, #pto.vmi.layout> +// REMAT-NOT: !pto.vmi.vreg<256xi32, #pto.vmi.layout> -> !pto.vmi.vreg<256xi32, #pto.vmi.layout> + +// LOWER-LABEL: func.func @vmi_layout_rematerialize_direct_ext( +// LOWER: pto.vldsx2 +// LOWER-SAME: "DINTLV_B16" +// LOWER-COUNT-2: pto.vcvt {{.*}} {part = "EVEN"} +// LOWER-COUNT-2: pto.vcvt {{.*}} {part = "ODD"} +// LOWER: pto.vcvt {{.*}} {part = "P0" +// LOWER: pto.vcvt {{.*}} {part = "P1" +// LOWER: pto.vcvt {{.*}} 
{part = "P2" +// LOWER: pto.vcvt {{.*}} {part = "P3" +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. + +// LOWER-LABEL: func.func @vmi_layout_rematerialize_mulf_chain( +// LOWER-COUNT-2: pto.vldsx2 +// LOWER-SAME: "DINTLV_B16" +// LOWER: pto.vmul +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_layout_sink_materialization_binary.pto b/test/lit/vmi/vmi_layout_sink_materialization_binary.pto new file mode 100644 index 0000000000..eb21fae758 --- /dev/null +++ b/test/lit/vmi/vmi_layout_sink_materialization_binary.pto @@ -0,0 +1,324 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-sink-materialization -pto-validate-vmi-layout-ir | FileCheck %s + +module { + func.func @vmi_layout_sink_materialization_addf( + %lhs: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %rhs: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> { + %lhs_deint = pto.vmi.ensure_layout %lhs + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %rhs_deint = pto.vmi.ensure_layout %rhs + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %sum = pto.vmi.addf %lhs_deint, %rhs_deint + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + return %sum : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + } + + func.func @vmi_layout_sink_materialization_muli( + %lhs: !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + %rhs: !pto.vmi.vreg<128xi32, #pto.vmi.layout>) + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> { + %lhs_deint = pto.vmi.ensure_layout %lhs + : !pto.vmi.vreg<128xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + %rhs_deint = pto.vmi.ensure_layout %rhs + : !pto.vmi.vreg<128xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + %prod = pto.vmi.muli %lhs_deint, %rhs_deint + : !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + !pto.vmi.vreg<128xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + return %prod : !pto.vmi.vreg<128xi32, #pto.vmi.layout> + } + + func.func @vmi_layout_sink_materialization_single_ensure_kept( + %lhs: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %rhs: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> { + %lhs_deint = pto.vmi.ensure_layout %lhs + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %sum = pto.vmi.addf %lhs_deint, %rhs + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> 
!pto.vmi.vreg<128xf32, #pto.vmi.layout> + return %sum : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + } + + func.func @vmi_layout_sink_materialization_fma( + %lhs: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %rhs: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %acc: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> { + %lhs_deint = pto.vmi.ensure_layout %lhs + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %rhs_deint = pto.vmi.ensure_layout %rhs + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %acc_deint = pto.vmi.ensure_layout %acc + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %out = pto.vmi.fma %lhs_deint, %rhs_deint, %acc_deint + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + return %out : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + } + + func.func @vmi_layout_sink_materialization_cmpf( + %lhs: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %rhs: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> { + %lhs_deint = pto.vmi.ensure_layout %lhs + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %rhs_deint = pto.vmi.ensure_layout %rhs + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %mask = pto.vmi.cmpf "olt", %lhs_deint, %rhs_deint + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + return %mask : !pto.vmi.mask<128xb32, #pto.vmi.layout> + } + + func.func @vmi_layout_sink_materialization_cmpi( + %lhs: !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + %rhs: !pto.vmi.vreg<128xi32, #pto.vmi.layout>) + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> { + %lhs_deint = pto.vmi.ensure_layout %lhs + : 
!pto.vmi.vreg<128xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + %rhs_deint = pto.vmi.ensure_layout %rhs + : !pto.vmi.vreg<128xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + %mask = pto.vmi.cmpi "slt", %lhs_deint, %rhs_deint + : !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + !pto.vmi.vreg<128xi32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + return %mask : !pto.vmi.mask<128xb32, #pto.vmi.layout> + } + + func.func @vmi_layout_sink_materialization_select( + %mask: !pto.vmi.mask<128xb32, #pto.vmi.layout>, + %true_value: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %false_value: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> { + %mask_deint = pto.vmi.ensure_mask_layout %mask + : !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %true_deint = pto.vmi.ensure_layout %true_value + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %false_deint = pto.vmi.ensure_layout %false_value + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %selected = pto.vmi.select %mask_deint, %true_deint, %false_deint + : !pto.vmi.mask<128xb32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + return %selected + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + } + + func.func @vmi_layout_sink_materialization_unary( + %src: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> { + %src_deint = pto.vmi.ensure_layout %src + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %neg = pto.vmi.negf %src_deint + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + return %neg : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + } + + func.func @vmi_layout_sink_materialization_unary_integer( + %src: 
!pto.vmi.vreg<128xi32, #pto.vmi.layout>) + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> { + %src_deint = pto.vmi.ensure_layout %src + : !pto.vmi.vreg<128xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + %abs = pto.vmi.absi %src_deint + : !pto.vmi.vreg<128xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + return %abs : !pto.vmi.vreg<128xi32, #pto.vmi.layout> + } + + func.func @vmi_layout_sink_materialization_bitwise( + %lhs: !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + %rhs: !pto.vmi.vreg<128xi32, #pto.vmi.layout>) + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> { + %lhs_deint = pto.vmi.ensure_layout %lhs + : !pto.vmi.vreg<128xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + %rhs_deint = pto.vmi.ensure_layout %rhs + : !pto.vmi.vreg<128xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + %and = pto.vmi.andi %lhs_deint, %rhs_deint + : !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + !pto.vmi.vreg<128xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + return %and : !pto.vmi.vreg<128xi32, #pto.vmi.layout> + } + + func.func @vmi_layout_sink_materialization_shift( + %lhs: !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + %rhs: !pto.vmi.vreg<128xi32, #pto.vmi.layout>) + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> { + %lhs_deint = pto.vmi.ensure_layout %lhs + : !pto.vmi.vreg<128xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + %rhs_deint = pto.vmi.ensure_layout %rhs + : !pto.vmi.vreg<128xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + %shift = pto.vmi.shli %lhs_deint, %rhs_deint + : !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + !pto.vmi.vreg<128xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + return %shift : !pto.vmi.vreg<128xi32, #pto.vmi.layout> + } + + func.func @vmi_layout_sink_materialization_not( + %src: !pto.vmi.vreg<128xi32, #pto.vmi.layout>) + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> { + %src_deint = pto.vmi.ensure_layout %src + : 
!pto.vmi.vreg<128xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + %not = pto.vmi.not %src_deint + : !pto.vmi.vreg<128xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + return %not : !pto.vmi.vreg<128xi32, #pto.vmi.layout> + } +} + +// CHECK-LABEL: func.func @vmi_layout_sink_materialization_addf( +// CHECK-NOT: pto.vmi.ensure_layout %arg0 +// CHECK-NOT: pto.vmi.ensure_layout %arg1 +// CHECK: %[[SUM:.*]] = pto.vmi.addf %arg0, %arg1 +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: %[[SUM_DEINT:.*]] = pto.vmi.ensure_layout %[[SUM]] +// CHECK-SAME: #pto.vmi.layout +// CHECK-SAME: #pto.vmi.layout +// CHECK: return %[[SUM_DEINT]] + +// CHECK-LABEL: func.func @vmi_layout_sink_materialization_muli( +// CHECK-NOT: pto.vmi.ensure_layout %arg0 +// CHECK-NOT: pto.vmi.ensure_layout %arg1 +// CHECK: %[[PROD:.*]] = pto.vmi.muli %arg0, %arg1 +// CHECK-SAME: !pto.vmi.vreg<128xi32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> +// CHECK: %[[PROD_DEINT:.*]] = pto.vmi.ensure_layout %[[PROD]] +// CHECK-SAME: #pto.vmi.layout +// CHECK: return %[[PROD_DEINT]] + +// CHECK-LABEL: func.func @vmi_layout_sink_materialization_single_ensure_kept( +// CHECK: %[[LHS_DEINT:.*]] = pto.vmi.ensure_layout %arg0 +// CHECK: %[[SUM2:.*]] = pto.vmi.addf %[[LHS_DEINT]], %arg1 +// CHECK: return %[[SUM2]] + +// CHECK-LABEL: func.func @vmi_layout_sink_materialization_fma( +// CHECK-NOT: pto.vmi.ensure_layout %arg0 +// CHECK-NOT: pto.vmi.ensure_layout %arg1 +// CHECK-NOT: pto.vmi.ensure_layout %arg2 +// CHECK: %[[FMA:.*]] = pto.vmi.fma %arg0, %arg1, %arg2 +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: %[[FMA_DEINT:.*]] = pto.vmi.ensure_layout %[[FMA]] +// CHECK-SAME: #pto.vmi.layout +// CHECK: return %[[FMA_DEINT]] + +// CHECK-LABEL: func.func @vmi_layout_sink_materialization_cmpf( +// 
CHECK-NOT: pto.vmi.ensure_layout %arg0 +// CHECK-NOT: pto.vmi.ensure_layout %arg1 +// CHECK: %[[CMPF:.*]] = pto.vmi.cmpf "olt", %arg0, %arg1 +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK: %[[CMPF_DEINT:.*]] = pto.vmi.ensure_mask_layout %[[CMPF]] +// CHECK-SAME: #pto.vmi.layout +// CHECK: return %[[CMPF_DEINT]] + +// CHECK-LABEL: func.func @vmi_layout_sink_materialization_cmpi( +// CHECK-NOT: pto.vmi.ensure_layout %arg0 +// CHECK-NOT: pto.vmi.ensure_layout %arg1 +// CHECK: %[[CMPI:.*]] = pto.vmi.cmpi "slt", %arg0, %arg1 +// CHECK-SAME: !pto.vmi.vreg<128xi32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK: %[[CMPI_DEINT:.*]] = pto.vmi.ensure_mask_layout %[[CMPI]] +// CHECK-SAME: #pto.vmi.layout +// CHECK: return %[[CMPI_DEINT]] + +// CHECK-LABEL: func.func @vmi_layout_sink_materialization_select( +// CHECK-NOT: pto.vmi.ensure_mask_layout %arg0 +// CHECK-NOT: pto.vmi.ensure_layout %arg1 +// CHECK-NOT: pto.vmi.ensure_layout %arg2 +// CHECK: %[[SELECT:.*]] = pto.vmi.select %arg0, %arg1, %arg2 +// CHECK-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: %[[SELECT_DEINT:.*]] = pto.vmi.ensure_layout %[[SELECT]] +// CHECK-SAME: #pto.vmi.layout +// CHECK: return %[[SELECT_DEINT]] + +// CHECK-LABEL: func.func @vmi_layout_sink_materialization_unary( +// CHECK-NOT: pto.vmi.ensure_layout %arg0 +// CHECK: %[[NEG:.*]] = pto.vmi.negf %arg0 +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: %[[NEG_DEINT:.*]] = pto.vmi.ensure_layout %[[NEG]] +// CHECK-SAME: #pto.vmi.layout +// CHECK: return %[[NEG_DEINT]] + +// CHECK-LABEL: func.func @vmi_layout_sink_materialization_unary_integer( +// CHECK-NOT: pto.vmi.ensure_layout %arg0 +// CHECK: %[[ABS:.*]] = pto.vmi.absi %arg0 +// 
CHECK-SAME: !pto.vmi.vreg<128xi32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> +// CHECK: %[[ABS_DEINT:.*]] = pto.vmi.ensure_layout %[[ABS]] +// CHECK-SAME: #pto.vmi.layout +// CHECK: return %[[ABS_DEINT]] + +// CHECK-LABEL: func.func @vmi_layout_sink_materialization_bitwise( +// CHECK-NOT: pto.vmi.ensure_layout %arg0 +// CHECK-NOT: pto.vmi.ensure_layout %arg1 +// CHECK: %[[AND:.*]] = pto.vmi.andi %arg0, %arg1 +// CHECK-SAME: !pto.vmi.vreg<128xi32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> +// CHECK: %[[AND_DEINT:.*]] = pto.vmi.ensure_layout %[[AND]] +// CHECK-SAME: #pto.vmi.layout +// CHECK: return %[[AND_DEINT]] + +// CHECK-LABEL: func.func @vmi_layout_sink_materialization_shift( +// CHECK-NOT: pto.vmi.ensure_layout %arg0 +// CHECK-NOT: pto.vmi.ensure_layout %arg1 +// CHECK: %[[SHIFT:.*]] = pto.vmi.shli %arg0, %arg1 +// CHECK-SAME: !pto.vmi.vreg<128xi32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> +// CHECK: %[[SHIFT_DEINT:.*]] = pto.vmi.ensure_layout %[[SHIFT]] +// CHECK-SAME: #pto.vmi.layout +// CHECK: return %[[SHIFT_DEINT]] + +// CHECK-LABEL: func.func @vmi_layout_sink_materialization_not( +// CHECK-NOT: pto.vmi.ensure_layout %arg0 +// CHECK: %[[NOT:.*]] = pto.vmi.not %arg0 +// CHECK-SAME: !pto.vmi.vreg<128xi32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> +// CHECK: %[[NOT_DEINT:.*]] = pto.vmi.ensure_layout %[[NOT]] +// CHECK-SAME: #pto.vmi.layout +// CHECK: return %[[NOT_DEINT]] diff --git a/test/lit/vmi/vmi_layout_sink_materialization_mask.pto b/test/lit/vmi/vmi_layout_sink_materialization_mask.pto new file mode 100644 index 0000000000..0effb48323 --- /dev/null +++ b/test/lit/vmi/vmi_layout_sink_materialization_mask.pto @@ -0,0 +1,86 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-sink-materialization -pto-validate-vmi-layout-ir | FileCheck %s + +module { + func.func @vmi_layout_sink_mask_layout_binary( + %lhs: !pto.vmi.mask<128xb32, #pto.vmi.layout>, + %rhs: !pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> { + %lhs_deint = pto.vmi.ensure_mask_layout %lhs + : !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %rhs_deint = pto.vmi.ensure_mask_layout %rhs + : !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %out = pto.vmi.mask_and %lhs_deint, %rhs_deint + : !pto.vmi.mask<128xb32, #pto.vmi.layout>, + !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + return %out : !pto.vmi.mask<128xb32, #pto.vmi.layout> + } + + func.func @vmi_layout_sink_mask_granularity_binary( + %lhs: !pto.vmi.mask<128xb32, #pto.vmi.layout>, + %rhs: !pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> !pto.vmi.mask<128xb16, #pto.vmi.layout> { + %lhs_b16 = pto.vmi.ensure_mask_granularity %lhs + : !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb16, #pto.vmi.layout> + %rhs_b16 = pto.vmi.ensure_mask_granularity %rhs + : !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb16, #pto.vmi.layout> + %out = pto.vmi.mask_or %lhs_b16, %rhs_b16 + : !pto.vmi.mask<128xb16, #pto.vmi.layout>, + 
!pto.vmi.mask<128xb16, #pto.vmi.layout> + -> !pto.vmi.mask<128xb16, #pto.vmi.layout> + return %out : !pto.vmi.mask<128xb16, #pto.vmi.layout> + } + + func.func @vmi_layout_sink_mask_layout_unary( + %source: !pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> { + %source_deint = pto.vmi.ensure_mask_layout %source + : !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %out = pto.vmi.mask_not %source_deint + : !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + return %out : !pto.vmi.mask<128xb32, #pto.vmi.layout> + } +} + +// CHECK-LABEL: func.func @vmi_layout_sink_mask_layout_binary( +// CHECK-NOT: pto.vmi.ensure_mask_layout %arg0 +// CHECK-NOT: pto.vmi.ensure_mask_layout %arg1 +// CHECK: %[[OUT:.*]] = pto.vmi.mask_and %arg0, %arg1 +// CHECK-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK: %[[OUT_DEINT:.*]] = pto.vmi.ensure_mask_layout %[[OUT]] +// CHECK-SAME: #pto.vmi.layout +// CHECK: return %[[OUT_DEINT]] + +// CHECK-LABEL: func.func @vmi_layout_sink_mask_granularity_binary( +// CHECK-NOT: pto.vmi.ensure_mask_granularity %arg0 +// CHECK-NOT: pto.vmi.ensure_mask_granularity %arg1 +// CHECK: %[[OUT:.*]] = pto.vmi.mask_or %arg0, %arg1 +// CHECK-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK: %[[OUT_B16:.*]] = pto.vmi.ensure_mask_granularity %[[OUT]] +// CHECK-SAME: !pto.vmi.mask<128xb16, #pto.vmi.layout> +// CHECK: return %[[OUT_B16]] + +// CHECK-LABEL: func.func @vmi_layout_sink_mask_layout_unary( +// CHECK-NOT: pto.vmi.ensure_mask_layout %arg0 +// CHECK: %[[OUT:.*]] = pto.vmi.mask_not %arg0 +// CHECK-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK: %[[OUT_DEINT:.*]] = pto.vmi.ensure_mask_layout %[[OUT]] +// CHECK-SAME: #pto.vmi.layout +// CHECK: return 
%[[OUT_DEINT]] diff --git a/test/lit/vmi/vmi_legalize_arith_select.pto b/test/lit/vmi/vmi_legalize_arith_select.pto new file mode 100644 index 0000000000..0661b6764e --- /dev/null +++ b/test/lit/vmi/vmi_legalize_arith_select.pto @@ -0,0 +1,47 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-legalize-arith-select -pto-validate-vmi-layout-ir | FileCheck %s + +module { + func.func @vmi_legalize_arith_select_vreg( + %cond: i1, + %lhs: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %rhs: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> { + %selected = arith.select %cond, %lhs, %rhs + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + return %selected : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + } + + func.func @vmi_legalize_arith_select_mask( + %cond: i1, + %lhs: !pto.vmi.mask<128xb32, #pto.vmi.layout>, + %rhs: !pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> { + %selected = arith.select %cond, %lhs, %rhs + : !pto.vmi.mask<128xb32, #pto.vmi.layout> + return %selected : !pto.vmi.mask<128xb32, #pto.vmi.layout> + } +} + +// CHECK-LABEL: func.func @vmi_legalize_arith_select_vreg( +// CHECK-NOT: arith.select +// CHECK: %[[IF:.*]] = scf.if %arg0 -> (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) { +// CHECK: scf.yield %arg1 +// CHECK: } else { +// CHECK: scf.yield %arg2 +// CHECK: return %[[IF]] + 
+// CHECK-LABEL: func.func @vmi_legalize_arith_select_mask( +// CHECK-NOT: arith.select +// CHECK: %[[IF:.*]] = scf.if %arg0 -> (!pto.vmi.mask<128xb32, #pto.vmi.layout>) { +// CHECK: scf.yield %arg1 +// CHECK: } else { +// CHECK: scf.yield %arg2 +// CHECK: return %[[IF]] diff --git a/test/lit/vmi/vmi_mask_concrete_without_layout_invalid.pto b/test/lit/vmi/vmi_mask_concrete_without_layout_invalid.pto new file mode 100644 index 0000000000..43aca3fd30 --- /dev/null +++ b/test/lit/vmi/vmi_mask_concrete_without_layout_invalid.pto @@ -0,0 +1,18 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_mask_concrete_without_layout_invalid( + %arg0: !pto.vmi.mask<128xb32>) { + return + } +} + +// CHECK: concrete mask granularity requires layout diff --git a/test/lit/vmi/vmi_mask_granularity_invalid.pto b/test/lit/vmi/vmi_mask_granularity_invalid.pto new file mode 100644 index 0000000000..4d85cc9aa0 --- /dev/null +++ b/test/lit/vmi/vmi_mask_granularity_invalid.pto @@ -0,0 +1,18 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. 
You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_mask_granularity_invalid( + %arg0: !pto.vmi.mask<128xb64, #pto.vmi.layout>) { + return + } +} + +// CHECK: expected granularity to be one of pred, b8, b16, b32 diff --git a/test/lit/vmi/vmi_mask_logic_invalid.pto b/test/lit/vmi/vmi_mask_logic_invalid.pto new file mode 100644 index 0000000000..49798b742b --- /dev/null +++ b/test/lit/vmi/vmi_mask_logic_invalid.pto @@ -0,0 +1,67 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -split-input-file 2>&1 | FileCheck %s + +module { + func.func @vmi_mask_and_lane_mismatch( + %lhs: !pto.vmi.mask<128xpred>, + %rhs: !pto.vmi.mask<64xpred>) { + %and = pto.vmi.mask_and %lhs, %rhs + : !pto.vmi.mask<128xpred>, !pto.vmi.mask<64xpred> + -> !pto.vmi.mask<128xpred> + return + } +} + +// CHECK: 'pto.vmi.mask_and' op requires all VMI mask values to have the same logical lane count + +// ----- + +module { + func.func @vmi_mask_or_granularity_mismatch( + %lhs: !pto.vmi.mask<128xb16, #pto.vmi.layout>, + %rhs: !pto.vmi.mask<128xb32, #pto.vmi.layout>) { + %or = pto.vmi.mask_or %lhs, %rhs + : !pto.vmi.mask<128xb16, #pto.vmi.layout>, + !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb16, #pto.vmi.layout> + return + } +} + +// CHECK: 'pto.vmi.mask_or' op requires all VMI mask values to have the same granularity + +// ----- + +module { + func.func @vmi_mask_xor_lane_mismatch( + %lhs: !pto.vmi.mask<128xpred>, + %rhs: !pto.vmi.mask<64xpred>) { + %xor = pto.vmi.mask_xor %lhs, %rhs + : !pto.vmi.mask<128xpred>, !pto.vmi.mask<64xpred> + -> !pto.vmi.mask<128xpred> + return + } +} + +// CHECK: 'pto.vmi.mask_xor' op requires all VMI mask values to have the same logical lane count + +// ----- + +module { + func.func @vmi_mask_not_granularity_mismatch( + %src: !pto.vmi.mask<128xb16, #pto.vmi.layout>) { + %not = pto.vmi.mask_not %src + : !pto.vmi.mask<128xb16, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + return + } +} + +// CHECK: 'pto.vmi.mask_not' op requires all VMI mask values to have the same granularity diff --git a/test/lit/vmi/vmi_mask_pred_with_layout_invalid.pto b/test/lit/vmi/vmi_mask_pred_with_layout_invalid.pto new file mode 100644 index 0000000000..e7d949242e --- /dev/null +++ b/test/lit/vmi/vmi_mask_pred_with_layout_invalid.pto @@ -0,0 +1,18 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_mask_pred_with_layout_invalid( + %arg0: !pto.vmi.mask<128xpred, #pto.vmi.layout>) { + return + } +} + +// CHECK: pred mask must not carry layout diff --git a/test/lit/vmi/vmi_masked_store_mask_granularity_invalid.pto b/test/lit/vmi/vmi_masked_store_mask_granularity_invalid.pto new file mode 100644 index 0000000000..4b3a672049 --- /dev/null +++ b/test/lit/vmi/vmi_masked_store_mask_granularity_invalid.pto @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_masked_store_mask_granularity_invalid( + %value: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<128xb16, #pto.vmi.layout>, + %dst: !pto.ptr, + %offset: index) { + pto.vmi.masked_store %value, %dst[%offset], %mask + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.ptr, + !pto.vmi.mask<128xb16, #pto.vmi.layout> + return + } +} + +// CHECK: 'pto.vmi.masked_store' op requires mask granularity to match data element width diff --git a/test/lit/vmi/vmi_memory_element_type_invalid.pto b/test/lit/vmi/vmi_memory_element_type_invalid.pto new file mode 100644 index 0000000000..5f1b8133bf --- /dev/null +++ b/test/lit/vmi/vmi_memory_element_type_invalid.pto @@ -0,0 +1,32 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -split-input-file -vmi-layout-assignment 2>&1 | FileCheck %s + +module { + func.func @vmi_load_element_type_invalid(%src: !pto.ptr, %offset: index) { + %value = pto.vmi.load %src[%offset] + : !pto.ptr -> !pto.vmi.vreg<128xf16> + return + } +} + +// CHECK: 'pto.vmi.load' op requires memory source element type to match VMI data element type + +// ----- + +module { + func.func @vmi_store_element_type_invalid( + %value: !pto.vmi.vreg<128xf16>, %dst: !pto.ptr, %offset: index) { + pto.vmi.store %value, %dst[%offset] + : !pto.vmi.vreg<128xf16>, !pto.ptr + return + } +} + +// CHECK: 'pto.vmi.store' op requires memory destination element type to match VMI data element type diff --git a/test/lit/vmi/vmi_min_max_integer_invalid.pto b/test/lit/vmi/vmi_min_max_integer_invalid.pto new file mode 100644 index 0000000000..71d0861e82 --- /dev/null +++ b/test/lit/vmi/vmi_min_max_integer_invalid.pto @@ -0,0 +1,37 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -split-input-file 2>&1 | FileCheck %s + +module { + func.func @vmi_minf_integer_invalid( + %lhs: !pto.vmi.vreg<128xi32>, + %rhs: !pto.vmi.vreg<128xi32>) { + %min = pto.vmi.minf %lhs, %rhs + : !pto.vmi.vreg<128xi32>, !pto.vmi.vreg<128xi32> + -> !pto.vmi.vreg<128xi32> + return + } +} + +// CHECK: 'pto.vmi.minf' op requires floating-point-like VMI element type + +// ----- + +module { + func.func @vmi_maxf_integer_invalid( + %lhs: !pto.vmi.vreg<128xi32>, + %rhs: !pto.vmi.vreg<128xi32>) { + %max = pto.vmi.maxf %lhs, %rhs + : !pto.vmi.vreg<128xi32>, !pto.vmi.vreg<128xi32> + -> !pto.vmi.vreg<128xi32> + return + } +} + +// CHECK: 'pto.vmi.maxf' op requires floating-point-like VMI element type diff --git a/test/lit/vmi/vmi_negf_integer_invalid.pto b/test/lit/vmi/vmi_negf_integer_invalid.pto new file mode 100644 index 0000000000..6b28584b64 --- /dev/null +++ b/test/lit/vmi/vmi_negf_integer_invalid.pto @@ -0,0 +1,19 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_negf_integer_invalid(%value: !pto.vmi.vreg<128xi32>) { + %neg = pto.vmi.negf %value + : !pto.vmi.vreg<128xi32> -> !pto.vmi.vreg<128xi32> + return + } +} + +// CHECK: 'pto.vmi.negf' op requires floating-point-like VMI element type diff --git a/test/lit/vmi/vmi_op_verifier_basic.pto b/test/lit/vmi/vmi_op_verifier_basic.pto new file mode 100644 index 0000000000..38575c8a47 --- /dev/null +++ b/test/lit/vmi/vmi_op_verifier_basic.pto @@ -0,0 +1,131 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_op_verifier_basic( + %ptr: !pto.ptr, + %layouted: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %mask_b16: !pto.vmi.mask<128xb16, #pto.vmi.layout>, + %mask_b32: !pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> (!pto.vmi.vreg<128xf32>, + !pto.vmi.mask<128xpred>, + !pto.vmi.vreg<128xf32>, + !pto.vmi.vreg<128xf16>, + !pto.vmi.vreg<4xf32>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.mask<128xb32, #pto.vmi.layout>, + !pto.vmi.mask<128xb32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<128xi32>, + !pto.vmi.mask<128xpred>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %f32 = arith.constant 1.000000e+00 : f32 + %f16 = arith.constant 1.000000e+00 : f16 + %active = arith.constant 64 : index + + %const = "pto.vmi.constant"() { + value = dense<1.000000e+00> : tensor<128xf32> + } : () -> !pto.vmi.vreg<128xf32> + %broadcast = pto.vmi.broadcast %f32 : f32 -> !pto.vmi.vreg<128xf32> + %broadcast16 = pto.vmi.broadcast %f16 : f16 -> !pto.vmi.vreg<128xf16> + %mask = pto.vmi.create_mask %active : index -> !pto.vmi.mask<128xpred> + %mask_const = "pto.vmi.constant_mask"() { + value = dense : tensor<128xi1> + } : () -> !pto.vmi.mask<128xpred> + + %add = pto.vmi.addf %broadcast, %const + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + %cmp = pto.vmi.cmpf "olt", %broadcast, %const + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> -> !pto.vmi.mask<128xpred> + %sel = pto.vmi.select %mask, %broadcast, %const + : !pto.vmi.mask<128xpred>, !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + %ext = pto.vmi.extf %broadcast16 : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %trunc = pto.vmi.truncf %ext : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> + + %loaded = pto.vmi.load %ptr[%c0] : !pto.ptr -> !pto.vmi.vreg<128xf32> + %slot_loaded = pto.vmi.group_slot_load %ptr[%c0], %c1 {num_groups = 8} + : 
!pto.ptr -> !pto.vmi.vreg<8xf32> + pto.vmi.group_store %slot_loaded, %ptr[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xf32>, !pto.ptr + pto.vmi.store %loaded, %ptr[%c0] : !pto.vmi.vreg<128xf32>, !pto.ptr + + %small = "pto.vmi.shuffle"(%broadcast) { + indices = array + } : (!pto.vmi.vreg<128xf32>) -> !pto.vmi.vreg<4xf32> + %split0, %split1 = "pto.vmi.channel_split"(%small) + : (!pto.vmi.vreg<4xf32>) -> (!pto.vmi.vreg<2xf32>, !pto.vmi.vreg<2xf32>) + %merged = "pto.vmi.channel_merge"(%split0, %split1) + : (!pto.vmi.vreg<2xf32>, !pto.vmi.vreg<2xf32>) -> !pto.vmi.vreg<4xf32> + + %ensure = pto.vmi.ensure_layout %layouted + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %layouted_ext = pto.vmi.extf %ensure + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf64, #pto.vmi.layout> + %layouted_trunc = pto.vmi.truncf %layouted_ext + : !pto.vmi.vreg<128xf64, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %mask_layout = pto.vmi.ensure_mask_layout %mask_b32 + : !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %mask_granularity = pto.vmi.ensure_mask_granularity %mask_b16 + : !pto.vmi.mask<128xb16, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + + %part0, %part1 = "pto.vmi.unpack"(%layouted) + : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + %packed = "pto.vmi.pack"(%part0, %part1) + : (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + + %i0 = arith.constant 0 : i32 + %iv0 = pto.vmi.broadcast %i0 : i32 -> !pto.vmi.vreg<128xi32> + %iadd = pto.vmi.addi %iv0, %iv0 + : !pto.vmi.vreg<128xi32>, !pto.vmi.vreg<128xi32> -> !pto.vmi.vreg<128xi32> + %icmp = pto.vmi.cmpi "slt", %iv0, %iv0 + : !pto.vmi.vreg<128xi32>, !pto.vmi.vreg<128xi32> -> !pto.vmi.mask<128xpred> + + return %add, %cmp, %sel, %trunc, %merged, %layouted_trunc, %mask_layout, + %mask_granularity, %packed, %iadd, %icmp + : 
!pto.vmi.vreg<128xf32>, + !pto.vmi.mask<128xpred>, + !pto.vmi.vreg<128xf32>, + !pto.vmi.vreg<128xf16>, + !pto.vmi.vreg<4xf32>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.mask<128xb32, #pto.vmi.layout>, + !pto.vmi.mask<128xb32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<128xi32>, + !pto.vmi.mask<128xpred> + } +} + +// CHECK-LABEL: func.func @vmi_op_verifier_basic +// CHECK: pto.vmi.broadcast +// CHECK: pto.vmi.addf +// CHECK: pto.vmi.cmpf +// CHECK: pto.vmi.select +// CHECK: pto.vmi.extf +// CHECK: pto.vmi.truncf +// CHECK: pto.vmi.load +// CHECK: pto.vmi.group_slot_load +// CHECK: pto.vmi.group_store +// CHECK: pto.vmi.store +// CHECK: pto.vmi.ensure_layout +// CHECK: pto.vmi.ensure_mask_layout +// CHECK: pto.vmi.ensure_mask_granularity +// CHECK: "pto.vmi.unpack" +// CHECK: "pto.vmi.pack" +// CHECK: pto.vmi.addi +// CHECK: pto.vmi.cmpi diff --git a/test/lit/vmi/vmi_pack_arity_invalid.pto b/test/lit/vmi/vmi_pack_arity_invalid.pto new file mode 100644 index 0000000000..4ba4eaa180 --- /dev/null +++ b/test/lit/vmi/vmi_pack_arity_invalid.pto @@ -0,0 +1,20 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_pack_arity_invalid(%p0: !pto.vreg<64xf32>) { + %a = "pto.vmi.pack"(%p0) + : (!pto.vreg<64xf32>) + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + return + } +} + +// CHECK: requires 2 physical parts, got 1 diff --git a/test/lit/vmi/vmi_pre_assignment_combine_group_slot_broadcast_load.pto b/test/lit/vmi/vmi_pre_assignment_combine_group_slot_broadcast_load.pto new file mode 100644 index 0000000000..5092f998be --- /dev/null +++ b/test/lit/vmi/vmi_pre_assignment_combine_group_slot_broadcast_load.pto @@ -0,0 +1,78 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-pre-assignment-combine | FileCheck %s +// RUN: pto-test-opt %s -vmi-pre-assignment-combine -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @e2b_candidate(%src: !pto.ptr, %off: index) + -> !pto.vmi.vreg<256xf16> { + %c1 = arith.constant 1 : index + %slots = pto.vmi.group_slot_load %src[%off], %c1 {num_groups = 16} + : !pto.ptr -> !pto.vmi.vreg<16xf16> + %out = pto.vmi.group_broadcast %slots {num_groups = 16} + : !pto.vmi.vreg<16xf16> -> !pto.vmi.vreg<256xf16> + return %out : !pto.vmi.vreg<256xf16> + } + + func.func @not_e2b_candidate(%src: !pto.ptr, %off: index) + -> !pto.vmi.vreg<256xf32> { + %c1 = arith.constant 1 : index + %slots = pto.vmi.group_slot_load %src[%off], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<8xf32> + %out = pto.vmi.group_broadcast %slots {num_groups = 8} + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<256xf32> + return %out : !pto.vmi.vreg<256xf32> + } + + func.func @not_e2b_consumer_deint2(%scale: !pto.ptr, + %x: !pto.vmi.vreg<256xf16>, %off: index) + -> !pto.vmi.vreg<256xf32> { + %c1 = arith.constant 1 : index + %slots = pto.vmi.group_slot_load %scale[%off], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<8xf32> + %sf = pto.vmi.group_broadcast %slots {num_groups = 8} + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<256xf32> + %x32 = pto.vmi.extf %x + : !pto.vmi.vreg<256xf16> -> !pto.vmi.vreg<256xf32> + %out = pto.vmi.mulf %x32, %sf + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + return %out : !pto.vmi.vreg<256xf32> + } +} + +// CHECK-LABEL: func.func @e2b_candidate +// CHECK-NOT: pto.vmi.group_slot_load +// CHECK: pto.vmi.group_broadcast_load +// CHECK-NOT: pto.vmi.group_broadcast + +// CHECK-LABEL: func.func @not_e2b_candidate +// CHECK-NOT: pto.vmi.group_slot_load +// CHECK: pto.vmi.group_broadcast_load +// CHECK-NOT: pto.vmi.group_broadcast + +// CHECK-LABEL: func.func @not_e2b_consumer_deint2 +// CHECK-NOT: pto.vmi.group_slot_load 
+// CHECK: pto.vmi.group_broadcast_load +// CHECK-NOT: pto.vmi.group_broadcast + +// LOWER-LABEL: func.func @e2b_candidate +// LOWER: pto.vlds {{.*}} {dist = "E2B_B16"} + +// LOWER-LABEL: func.func @not_e2b_candidate +// LOWER-NOT: E2B_B32 +// LOWER: pto.vsldb +// LOWER: pto.vselr +// LOWER-NOT: pto.vmi. + +// LOWER-LABEL: func.func @not_e2b_consumer_deint2 +// LOWER-NOT: E2B_B32 +// LOWER: pto.vsldb +// LOWER: pto.vselr +// LOWER-NOT: pto.vmi. diff --git a/test/lit/vmi/vmi_producer_boundary_helper_invalid.pto b/test/lit/vmi/vmi_producer_boundary_helper_invalid.pto new file mode 100644 index 0000000000..81805f2a28 --- /dev/null +++ b/test/lit/vmi/vmi_producer_boundary_helper_invalid.pto @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -pto-validate-vmi-ir 2>&1 | FileCheck %s + +module { + func.func @vmi_producer_boundary_helper_invalid( + %a: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) { + %r = pto.vmi.ensure_layout %a + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI-PASS-INVARIANT +// CHECK: VMI producer boundary requires surface diff --git a/test/lit/vmi/vmi_producer_boundary_layout_invalid.pto b/test/lit/vmi/vmi_producer_boundary_layout_invalid.pto new file mode 100644 index 0000000000..be6a6414f9 --- /dev/null +++ b/test/lit/vmi/vmi_producer_boundary_layout_invalid.pto @@ -0,0 +1,19 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -pto-validate-vmi-ir 2>&1 | FileCheck %s + +module { + func.func @vmi_producer_boundary_layout_invalid( + %a: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) { + return + } +} + +// CHECK: VMI-PASS-INVARIANT +// CHECK: VMI producer boundary requires surface diff --git a/test/lit/vmi/vmi_producer_boundary_mask_layout_invalid.pto b/test/lit/vmi/vmi_producer_boundary_mask_layout_invalid.pto new file mode 100644 index 0000000000..3d3727bdaa --- /dev/null +++ b/test/lit/vmi/vmi_producer_boundary_mask_layout_invalid.pto @@ -0,0 +1,19 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -pto-validate-vmi-ir 2>&1 | FileCheck %s + +module { + func.func @vmi_producer_boundary_mask_layout_invalid( + %m: !pto.vmi.mask<128xb32, #pto.vmi.layout>) { + return + } +} + +// CHECK: VMI-PASS-INVARIANT +// CHECK: VMI producer boundary requires surface !pto.vmi.vreg or !pto.vmi.mask type diff --git a/test/lit/vmi/vmi_producer_boundary_non_vmi_op_invalid.pto b/test/lit/vmi/vmi_producer_boundary_non_vmi_op_invalid.pto new file mode 100644 index 0000000000..c5aa0676f0 --- /dev/null +++ b/test/lit/vmi/vmi_producer_boundary_non_vmi_op_invalid.pto @@ -0,0 +1,21 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -pto-validate-vmi-ir 2>&1 | FileCheck %s + +module { + func.func @vmi_producer_boundary_non_vmi_op_invalid( + %a: !pto.vmi.vreg<128xf32>) { + %0 = builtin.unrealized_conversion_cast %a + : !pto.vmi.vreg<128xf32> to !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK: VMI-PASS-INVARIANT +// CHECK: VMI typed value is used by a non-VMI semantic op diff --git a/test/lit/vmi/vmi_producer_boundary_physical_invalid.pto b/test/lit/vmi/vmi_producer_boundary_physical_invalid.pto new file mode 100644 index 0000000000..c2a3996eb9 --- /dev/null +++ b/test/lit/vmi/vmi_producer_boundary_physical_invalid.pto @@ -0,0 +1,30 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -split-input-file -pto-validate-vmi-ir 2>&1 | FileCheck %s + +module { + func.func @vmi_producer_boundary_physical_invalid(%a: !pto.vreg<64xf32>) { + return + } +} + +// CHECK: VMI-PASS-INVARIANT +// CHECK: physical VPTO register type appears before VMI-to-VPTO + +// ----- + +module { + func.func @vmi_producer_boundary_physical_op_invalid() { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + return + } +} + +// CHECK: VMI-PASS-INVARIANT +// CHECK: physical VPTO register type appears before VMI-to-VPTO diff --git a/test/lit/vmi/vmi_producer_boundary_type_attr_layout_invalid.pto b/test/lit/vmi/vmi_producer_boundary_type_attr_layout_invalid.pto new file mode 100644 index 0000000000..8deed1cecb --- /dev/null +++ b/test/lit/vmi/vmi_producer_boundary_type_attr_layout_invalid.pto @@ -0,0 +1,17 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -pto-validate-vmi-ir 2>&1 | FileCheck %s + +module attributes { + pto.hidden_vmi_type = !pto.vmi.vreg<128xf32, #pto.vmi.layout> +} { +} + +// CHECK: VMI-PASS-INVARIANT +// CHECK: VMI or physical VPTO type appears in a non-signature attribute diff --git a/test/lit/vmi/vmi_producer_boundary_type_attr_nested_invalid.pto b/test/lit/vmi/vmi_producer_boundary_type_attr_nested_invalid.pto new file mode 100644 index 0000000000..4163dcfb16 --- /dev/null +++ b/test/lit/vmi/vmi_producer_boundary_type_attr_nested_invalid.pto @@ -0,0 +1,17 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -pto-validate-vmi-ir 2>&1 | FileCheck %s + +module attributes { + pto.hidden_vmi_state = {nested = [!pto.vmi.vreg<128xf32>]} +} { +} + +// CHECK: VMI-PASS-INVARIANT +// CHECK: VMI or physical VPTO type appears in a non-signature attribute diff --git a/test/lit/vmi/vmi_producer_boundary_type_attr_surface_invalid.pto b/test/lit/vmi/vmi_producer_boundary_type_attr_surface_invalid.pto new file mode 100644 index 0000000000..8cd353ca13 --- /dev/null +++ b/test/lit/vmi/vmi_producer_boundary_type_attr_surface_invalid.pto @@ -0,0 +1,17 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -pto-validate-vmi-ir 2>&1 | FileCheck %s + +module attributes { + pto.hidden_vmi_type = !pto.vmi.vreg<128xf32> +} { +} + +// CHECK: VMI-PASS-INVARIANT +// CHECK: VMI or physical VPTO type appears in a non-signature attribute diff --git a/test/lit/vmi/vmi_producer_boundary_valid.pto b/test/lit/vmi/vmi_producer_boundary_valid.pto new file mode 100644 index 0000000000..dee731bd1f --- /dev/null +++ b/test/lit/vmi/vmi_producer_boundary_valid.pto @@ -0,0 +1,27 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -pto-validate-vmi-ir | FileCheck %s + +module { + func.func @vmi_producer_boundary_valid( + %a: !pto.vmi.vreg<128xf32>, + %b: !pto.vmi.vreg<128xf32>, + %m: !pto.vmi.mask<128xpred>) -> !pto.vmi.vreg<128xf32> { + %r = pto.vmi.addf %a, %b + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + %s = pto.vmi.select %m, %r, %a + : !pto.vmi.mask<128xpred>, !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return %s : !pto.vmi.vreg<128xf32> + } +} + +// CHECK-LABEL: func.func @vmi_producer_boundary_valid +// CHECK: pto.vmi.addf +// CHECK: pto.vmi.select diff --git a/test/lit/vmi/vmi_ptoas_backend_required_invalid.pto b/test/lit/vmi/vmi_ptoas_backend_required_invalid.pto new file mode 100644 index 0000000000..7379984b50 --- /dev/null +++ b/test/lit/vmi/vmi_ptoas_backend_required_invalid.pto @@ -0,0 +1,17 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --pto-arch=a5 --enable-vmi %s -o - 2>&1 | FileCheck %s + +module attributes {pto.target_arch = "a5"} { + func.func @vmi_ptoas_backend_required_invalid() { + return + } +} + +// CHECK: Error: --enable-vmi requires --pto-backend=vpto or pto.backend = "vpto". 
diff --git a/test/lit/vmi/vmi_ptoas_call_boundary_vecscope.pto b/test/lit/vmi/vmi_ptoas_call_boundary_vecscope.pto new file mode 100644 index 0000000000..771ae5904c --- /dev/null +++ b/test/lit/vmi/vmi_ptoas_call_boundary_vecscope.pto @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-vmi --emit-vpto %s -o - | FileCheck %s + +module attributes {pto.target_arch = "a5"} { + module attributes {pto.backend = "vpto", pto.kernel_kind = #pto.kernel_kind} { + func.func private @callee(%x: !pto.vmi.vreg<128xf32>) + -> !pto.vmi.vreg<128xf32> { + %sum = pto.vmi.addf %x, %x + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return %sum : !pto.vmi.vreg<128xf32> + } + + func.func @caller(%value: f32, %dst: !pto.ptr, %off: index) { + pto.vecscope { + %x = pto.vmi.broadcast %value : f32 -> !pto.vmi.vreg<128xf32> + %r = func.call @callee(%x) + : (!pto.vmi.vreg<128xf32>) -> !pto.vmi.vreg<128xf32> + pto.vmi.store %r, %dst[%off] + : !pto.vmi.vreg<128xf32>, !pto.ptr + } + return + } + } +} + +// CHECK-NOT: func.func private @callee +// CHECK-LABEL: func.func @caller +// CHECK: pto.vecscope +// CHECK: pto.vdup +// CHECK: pto.vadd +// CHECK: pto.vsts +// CHECK: pto.vsts +// CHECK-NOT: func.call @callee +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_ptoas_cli_control_flow.pto b/test/lit/vmi/vmi_ptoas_cli_control_flow.pto new file mode 100644 index 0000000000..cd29782d10 --- /dev/null +++ b/test/lit/vmi/vmi_ptoas_cli_control_flow.pto @@ -0,0 +1,43 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-vmi --emit-vpto %s -o - | FileCheck %s + +module attributes {pto.target_arch = "a5"} { + module attributes {pto.backend = "vpto", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_ptoas_cli_control_flow( + %cond: i1, + %lhs: f32, + %rhs: f32, + %dst: !pto.ptr, + %offset: index) { + %lhs_v = pto.vmi.broadcast %lhs + : f32 -> !pto.vmi.vreg<128xf32> + %rhs_v = pto.vmi.broadcast %rhs + : f32 -> !pto.vmi.vreg<128xf32> + %chosen = scf.if %cond -> !pto.vmi.vreg<128xf32> { + scf.yield %lhs_v : !pto.vmi.vreg<128xf32> + } else { + scf.yield %rhs_v : !pto.vmi.vreg<128xf32> + } + pto.vmi.store %chosen, %dst[%offset] + : !pto.vmi.vreg<128xf32>, !pto.ptr + return + } + } +} + +// CHECK-LABEL: func.func @vmi_ptoas_cli_control_flow +// CHECK: %[[LHS:.*]] = pto.vdup +// CHECK: %[[RHS:.*]] = pto.vdup +// CHECK: %[[CHOSEN:.*]] = arith.select {{.*}}, %[[LHS]], %[[RHS]] : !pto.vreg<64xf32> +// CHECK: pto.vsts %[[CHOSEN]] +// CHECK: pto.vsts %[[CHOSEN]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_ptoas_cli_licm.pto b/test/lit/vmi/vmi_ptoas_cli_licm.pto new file mode 100644 index 0000000000..b6f7261bce --- /dev/null +++ b/test/lit/vmi/vmi_ptoas_cli_licm.pto @@ -0,0 +1,39 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-vmi --emit-vpto %s -o - | FileCheck %s + +module attributes {pto.target_arch = "a5"} { + module attributes {pto.backend = "vpto", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_ptoas_cli_licm( + %src: !pto.ptr, + %dst: !pto.ptr, + %count: index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + pto.vecscope { + scf.for %i = %c0 to %count step %c1 { + %x16 = pto.vmi.load %src[%i] + : !pto.ptr -> !pto.vmi.vreg<128xf16> + %x32 = pto.vmi.extf %x16 + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + pto.vmi.store %x32, %dst[%i] + : !pto.vmi.vreg<128xf32>, !pto.ptr + } + } + return + } + } +} + +// CHECK-LABEL: func.func @vmi_ptoas_cli_licm +// CHECK: pto.vecscope +// CHECK: %[[MASK:.*]] = pto.pset_b16 "PAT_ALL" : !pto.mask +// CHECK: scf.for +// CHECK-NOT: pto.pset_b16 +// CHECK: pto.vcvt {{.*}}, %[[MASK]] diff --git a/test/lit/vmi/vmi_ptoas_cli_pipeline.pto b/test/lit/vmi/vmi_ptoas_cli_pipeline.pto new file mode 100644 index 0000000000..0a136e8c0f --- /dev/null +++ 
b/test/lit/vmi/vmi_ptoas_cli_pipeline.pto @@ -0,0 +1,67 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-vmi --emit-vpto %s -o - | FileCheck %s +// RUN: ptoas --pto-arch=a5 --enable-vmi --emit-vpto %s -o - | FileCheck %s --check-prefix=ATTR +// RUN: not ptoas --pto-backend=emitc --enable-vmi %s -o - 2>&1 | FileCheck %s --check-prefix=EMITC + +module attributes {pto.target_arch = "a5"} { + module attributes {pto.backend = "vpto", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_ptoas_cli_pipeline( + %scalar: f32, + %dst: !pto.ptr, + %offset: index) { + %value = pto.vmi.broadcast %scalar + : f32 -> !pto.vmi.vreg<128xf32> + pto.vmi.store %value, %dst[%offset] + : !pto.vmi.vreg<128xf32>, !pto.ptr + return + } + + func.func @vmi_ptoas_cli_fold_pipeline( + %src: !pto.ptr, + %dst: !pto.ptr, + %offset: index) { + %x16 = pto.vmi.load %src[%offset] + : !pto.ptr -> !pto.vmi.vreg<128xf16> + %x32 = pto.vmi.extf %x16 + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + pto.vmi.store %x32, %dst[%offset] + : !pto.vmi.vreg<128xf32>, !pto.ptr + return + } + } +} + +// CHECK-LABEL: func.func @vmi_ptoas_cli_pipeline +// CHECK: pto.vecscope +// CHECK: pto.vdup +// CHECK: pto.vsts +// CHECK: pto.vsts +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_ptoas_cli_fold_pipeline +// CHECK: pto.vlds +// CHECK: pto.vcvt {{.*}} {part = "EVEN"} +// CHECK: pto.vcvt {{.*}} {part = "ODD"} +// CHECK: pto.vstsx2 +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// ATTR-LABEL: func.func @vmi_ptoas_cli_pipeline +// ATTR: pto.vecscope +// ATTR: pto.vdup +// ATTR: pto.vsts +// ATTR-NOT: pto.vmi. +// ATTR-NOT: !pto.vmi. +// ATTR-NOT: unrealized_conversion_cast + +// EMITC: Error: --enable-vmi requires --pto-backend=vpto or pto.backend = "vpto". diff --git a/test/lit/vmi/vmi_ptoas_private_call_inline.pto b/test/lit/vmi/vmi_ptoas_private_call_inline.pto new file mode 100644 index 0000000000..c5e1604bec --- /dev/null +++ b/test/lit/vmi/vmi_ptoas_private_call_inline.pto @@ -0,0 +1,42 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-vmi --emit-vpto %s -o - | FileCheck %s + +module attributes {pto.target_arch = "a5"} { + module attributes {pto.backend = "vpto", pto.kernel_kind = #pto.kernel_kind} { + func.func private @producer(%scalar: f32) + -> !pto.vmi.vreg<128xf32> { + %value = pto.vmi.broadcast %scalar + : f32 -> !pto.vmi.vreg<128xf32> + return %value : !pto.vmi.vreg<128xf32> + } + + func.func @vmi_ptoas_private_call_inline( + %scalar: f32, + %dst: !pto.ptr, + %offset: index) { + %value = call @producer(%scalar) + : (f32) -> !pto.vmi.vreg<128xf32> + pto.vmi.store %value, %dst[%offset] + : !pto.vmi.vreg<128xf32>, !pto.ptr + return + } + } +} + +// CHECK-NOT: func.func private @producer +// CHECK-LABEL: func.func @vmi_ptoas_private_call_inline +// CHECK: pto.vecscope +// CHECK: pto.vdup +// CHECK: pto.vsts +// CHECK: pto.vsts +// CHECK-NOT: call @producer +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_ptoas_public_abi_invalid.pto b/test/lit/vmi/vmi_ptoas_public_abi_invalid.pto new file mode 100644 index 0000000000..79b146acd8 --- /dev/null +++ b/test/lit/vmi/vmi_ptoas_public_abi_invalid.pto @@ -0,0 +1,20 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not ptoas --pto-arch=a5 --pto-backend=vpto --enable-vmi --emit-vpto %s -o - 2>&1 | FileCheck %s + +module attributes {pto.target_arch = "a5"} { + module attributes {pto.backend = "vpto", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_ptoas_public_abi_invalid( + %value: !pto.vmi.vreg<128xf32>) { + return + } + } +} + +// CHECK: VMI-LAYOUT-CONTRACT: public VMI typed function requires an explicit external ABI materialization plan diff --git a/test/lit/vmi/vmi_ptoas_public_result_abi_invalid.pto b/test/lit/vmi/vmi_ptoas_public_result_abi_invalid.pto new file mode 100644 index 0000000000..a27067e62c --- /dev/null +++ b/test/lit/vmi/vmi_ptoas_public_result_abi_invalid.pto @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not ptoas --pto-arch=a5 --pto-backend=vpto --enable-vmi --emit-vpto %s -o - 2>&1 | FileCheck %s + +module attributes {pto.target_arch = "a5"} { + module attributes {pto.backend = "vpto", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_ptoas_public_result_abi_invalid( + %scalar: f32) -> !pto.vmi.vreg<128xf32> { + %value = pto.vmi.broadcast %scalar + : f32 -> !pto.vmi.vreg<128xf32> + return %value : !pto.vmi.vreg<128xf32> + } + } +} + +// CHECK: VMI-LAYOUT-CONTRACT: public VMI typed function requires an explicit external ABI materialization plan diff --git a/test/lit/vmi/vmi_reduce_addf_missing_reassoc_invalid.pto b/test/lit/vmi/vmi_reduce_addf_missing_reassoc_invalid.pto new file mode 100644 index 0000000000..47dc112c04 --- /dev/null +++ b/test/lit/vmi/vmi_reduce_addf_missing_reassoc_invalid.pto @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -vmi-layout-assignment 2>&1 | FileCheck %s + +module { + func.func @vmi_reduce_addf_missing_reassoc_invalid( + %source: !pto.vmi.vreg<64xf32>, + %init: !pto.vmi.vreg<1xf32>, + %mask: !pto.vmi.mask<64xpred>) { + %out = pto.vmi.reduce_addf %source, %init, %mask + : !pto.vmi.vreg<64xf32>, !pto.vmi.vreg<1xf32>, + !pto.vmi.mask<64xpred> -> !pto.vmi.vreg<1xf32> + return + } +} + +// CHECK: 'pto.vmi.reduce_addf' op requires reassoc attr because VPTO vcadd performs pair-wise floating-point reduction diff --git a/test/lit/vmi/vmi_scatter_indices_invalid.pto b/test/lit/vmi/vmi_scatter_indices_invalid.pto new file mode 100644 index 0000000000..e16d6905f0 --- /dev/null +++ b/test/lit/vmi/vmi_scatter_indices_invalid.pto @@ -0,0 +1,24 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -vmi-layout-assignment 2>&1 | FileCheck %s + +module { + func.func @vmi_scatter_indices_invalid( + %value: !pto.vmi.vreg<64xf32>, + %dst: !pto.ptr, + %indices: !pto.vmi.vreg<64xf32>, + %mask: !pto.vmi.mask<64xpred>) { + pto.vmi.scatter %value, %dst[%indices], %mask + : !pto.vmi.vreg<64xf32>, !pto.ptr, + !pto.vmi.vreg<64xf32>, !pto.vmi.mask<64xpred> + return + } +} + +// CHECK: 'pto.vmi.scatter' op requires signless or unsigned 32-bit integer indices diff --git a/test/lit/vmi/vmi_select_mask_granularity_invalid.pto b/test/lit/vmi/vmi_select_mask_granularity_invalid.pto new file mode 100644 index 0000000000..2e6b9d10f9 --- /dev/null +++ b/test/lit/vmi/vmi_select_mask_granularity_invalid.pto @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_select_mask_granularity_invalid( + %m: !pto.vmi.mask<128xb16, #pto.vmi.layout>, + %a: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %b: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) { + %r = pto.vmi.select %m, %a, %b + : !pto.vmi.mask<128xb16, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + return + } +} + +// CHECK: requires mask granularity to match data element width diff --git a/test/lit/vmi/vmi_shli_float_invalid.pto b/test/lit/vmi/vmi_shli_float_invalid.pto new file mode 100644 index 0000000000..e73ee9c232 --- /dev/null +++ b/test/lit/vmi/vmi_shli_float_invalid.pto @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_shli_float_invalid( + %lhs: !pto.vmi.vreg<128xf32>, + %rhs: !pto.vmi.vreg<128xf32>) { + %shifted = pto.vmi.shli %lhs, %rhs + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK: 'pto.vmi.shli' op requires integer-like VMI element type diff --git a/test/lit/vmi/vmi_shrui_float_invalid.pto b/test/lit/vmi/vmi_shrui_float_invalid.pto new file mode 100644 index 0000000000..5de50dfff1 --- /dev/null +++ b/test/lit/vmi/vmi_shrui_float_invalid.pto @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_shrui_float_invalid( + %lhs: !pto.vmi.vreg<128xf32>, + %rhs: !pto.vmi.vreg<128xf32>) { + %shifted = pto.vmi.shrui %lhs, %rhs + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK: 'pto.vmi.shrui' op requires signless or unsigned integer VMI element type diff --git a/test/lit/vmi/vmi_shrui_signed_invalid.pto b/test/lit/vmi/vmi_shrui_signed_invalid.pto new file mode 100644 index 0000000000..c3c57a52e9 --- /dev/null +++ b/test/lit/vmi/vmi_shrui_signed_invalid.pto @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_shrui_signed_invalid( + %lhs: !pto.vmi.vreg<128xsi16>, + %rhs: !pto.vmi.vreg<128xsi16>) { + %shifted = pto.vmi.shrui %lhs, %rhs + : !pto.vmi.vreg<128xsi16>, !pto.vmi.vreg<128xsi16> + -> !pto.vmi.vreg<128xsi16> + return + } +} + +// CHECK: 'pto.vmi.shrui' op requires signless or unsigned integer VMI element type diff --git a/test/lit/vmi/vmi_shuffle_indices_invalid.pto b/test/lit/vmi/vmi_shuffle_indices_invalid.pto new file mode 100644 index 0000000000..fe6582ee86 --- /dev/null +++ b/test/lit/vmi/vmi_shuffle_indices_invalid.pto @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -split-input-file 2>&1 | FileCheck %s + +module { + func.func @vmi_shuffle_index_count_invalid(%src: !pto.vmi.vreg<8xui8>) { + %out = "pto.vmi.shuffle"(%src) { + indices = array + } : (!pto.vmi.vreg<8xui8>) -> !pto.vmi.vreg<4xui8> + return + } +} + +// CHECK: 'pto.vmi.shuffle' op requires shuffle index count to match result logical lane count + +// ----- + +module { + func.func @vmi_shuffle_index_oob_invalid(%src: !pto.vmi.vreg<8xui8>) { + %out = "pto.vmi.shuffle"(%src) { + indices = array + } : (!pto.vmi.vreg<8xui8>) -> !pto.vmi.vreg<4xui8> + return + } +} + +// CHECK: 'pto.vmi.shuffle' op requires every shuffle index to select an existing source logical lane diff --git a/test/lit/vmi/vmi_to_vpto_abs.pto b/test/lit/vmi/vmi_to_vpto_abs.pto new file mode 100644 index 0000000000..247a239f66 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_abs.pto @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_absf( + %value: !pto.vmi.vreg<128xf32>) -> !pto.vmi.vreg<128xf32> { + %abs = pto.vmi.absf %value + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + return %abs : !pto.vmi.vreg<128xf32> + } + + func.func @vmi_to_vpto_absi( + %value: !pto.vmi.vreg<256xi16>) -> !pto.vmi.vreg<256xi16> { + %abs = pto.vmi.absi %value + : !pto.vmi.vreg<256xi16> -> !pto.vmi.vreg<256xi16> + return %abs : !pto.vmi.vreg<256xi16> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_absf( +// CHECK-SAME: %[[F0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[F1:[^)]+]]: !pto.vreg<64xf32> +// CHECK-SAME: -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) +// CHECK-DAG: %[[AF0:.*]] = pto.vabs %[[F0]], {{.*}} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-DAG: %[[AF1:.*]] = pto.vabs %[[F1]], {{.*}} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[AF0]], %[[AF1]] + +// CHECK-LABEL: func.func @vmi_to_vpto_absi( +// CHECK-SAME: %[[I0:[^,]+]]: !pto.vreg<128xi16> +// CHECK-SAME: %[[I1:[^)]+]]: !pto.vreg<128xi16> +// CHECK-SAME: -> (!pto.vreg<128xi16>, !pto.vreg<128xi16>) +// CHECK-DAG: %[[AI0:.*]] = pto.vabs %[[I0]], {{.*}} : !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> +// CHECK-DAG: %[[AI1:.*]] = pto.vabs %[[I1]], {{.*}} : !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> +// CHECK: return %[[AI0]], %[[AI1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_active_prefix_index.pto b/test/lit/vmi/vmi_to_vpto_active_prefix_index.pto new file mode 100644 index 0000000000..7d64e0ec0f --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_active_prefix_index.pto @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_active_prefix_index( + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) + -> !pto.vreg<64xi32> { + %idx = pto.vmi.active_prefix_index %mask + : !pto.vmi.mask<64xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<64xi32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%idx) + : (!pto.vmi.vreg<64xi32, #pto.vmi.layout>) + -> !pto.vreg<64xi32> + return %part : !pto.vreg<64xi32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_active_prefix_index( +// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : i32 +// CHECK-DAG: %[[M:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: %[[CARRIER:.*]] = pto.vdup %[[ZERO]], %[[M]] : i32, !pto.mask -> !pto.vreg<64xi32> +// CHECK: %[[IDX:.*]] = pto.vusqz %[[CARRIER]], %arg0 : !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> +// CHECK: return %[[IDX]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_active_prefix_index_multichunk_invalid.pto b/test/lit/vmi/vmi_to_vpto_active_prefix_index_multichunk_invalid.pto new file mode 100644 index 0000000000..cb655b0e4f --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_active_prefix_index_multichunk_invalid.pto @@ -0,0 +1,26 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_active_prefix_index_multichunk_invalid( + %mask: !pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> (!pto.vreg<64xi32>, !pto.vreg<64xi32>) { + %idx = pto.vmi.active_prefix_index %mask + : !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%idx) + : (!pto.vmi.vreg<128xi32, #pto.vmi.layout>) + -> (!pto.vreg<64xi32>, !pto.vreg<64xi32>) + return %p0, %p1 : !pto.vreg<64xi32>, !pto.vreg<64xi32> + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.active_prefix_index lowers through pto.vusqz only for one contiguous physical chunk +// CHECK-SAME: multi-chunk prefix needs cross-chunk carry diff --git a/test/lit/vmi/vmi_to_vpto_active_prefix_index_tail_invalid.pto b/test/lit/vmi/vmi_to_vpto_active_prefix_index_tail_invalid.pto new file mode 100644 index 0000000000..07fd5307e0 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_active_prefix_index_tail_invalid.pto @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_active_prefix_index_tail_invalid( + %mask: !pto.vmi.mask<32xb32, #pto.vmi.layout>) { + %idx = pto.vmi.active_prefix_index %mask + : !pto.vmi.mask<32xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<32xi32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.active_prefix_index lowers through pto.vusqz only for one contiguous physical chunk +// CHECK-SAME: padding mask lanes cannot affect the observable prefix diff --git a/test/lit/vmi/vmi_to_vpto_add.pto b/test/lit/vmi/vmi_to_vpto_add.pto new file mode 100644 index 0000000000..49b5fdeca3 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_add.pto @@ -0,0 +1,57 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_addf( + %a: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %b: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %sum = pto.vmi.addf %a, %b + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%sum) + : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + + func.func @vmi_to_vpto_addi( + %a: !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + %b: !pto.vmi.vreg<128xi32, #pto.vmi.layout>) + -> (!pto.vreg<64xi32>, !pto.vreg<64xi32>) { + %sum = pto.vmi.addi %a, %b + : !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + !pto.vmi.vreg<128xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%sum) + : (!pto.vmi.vreg<128xi32, #pto.vmi.layout>) + -> (!pto.vreg<64xi32>, !pto.vreg<64xi32>) + return %p0, %p1 : !pto.vreg<64xi32>, !pto.vreg<64xi32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_addf( +// CHECK: %[[M0:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: pto.vadd {{.*}}, {{.*}}, %[[M0]] +// CHECK-SAME: !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: %[[M1:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: pto.vadd {{.*}}, {{.*}}, %[[M1]] +// CHECK-SAME: !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-LABEL: func.func @vmi_to_vpto_addi( +// CHECK: %[[IM0:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: pto.vadd {{.*}}, {{.*}}, %[[IM0]] +// CHECK-SAME: !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> +// CHECK: %[[IM1:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: pto.vadd {{.*}}, {{.*}}, %[[IM1]] +// CHECK-SAME: !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> +// CHECK-NOT: pto.vmi.add +// 
CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_bf16_arith.pto b/test/lit/vmi/vmi_to_vpto_bf16_arith.pto new file mode 100644 index 0000000000..c7357b5abd --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_bf16_arith.pto @@ -0,0 +1,50 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_bf16_arith( + %lhs: !pto.vmi.vreg<128xbf16, #pto.vmi.layout>, + %rhs: !pto.vmi.vreg<128xbf16, #pto.vmi.layout>) + -> (!pto.vreg<128xbf16>, !pto.vreg<128xbf16>, !pto.vreg<128xbf16>) { + %sum = pto.vmi.addf %lhs, %rhs + : !pto.vmi.vreg<128xbf16, #pto.vmi.layout>, + !pto.vmi.vreg<128xbf16, #pto.vmi.layout> + -> !pto.vmi.vreg<128xbf16, #pto.vmi.layout> + %min = pto.vmi.minf %lhs, %rhs + : !pto.vmi.vreg<128xbf16, #pto.vmi.layout>, + !pto.vmi.vreg<128xbf16, #pto.vmi.layout> + -> !pto.vmi.vreg<128xbf16, #pto.vmi.layout> + %max = pto.vmi.maxf %lhs, %rhs + : !pto.vmi.vreg<128xbf16, #pto.vmi.layout>, + !pto.vmi.vreg<128xbf16, #pto.vmi.layout> + -> !pto.vmi.vreg<128xbf16, #pto.vmi.layout> + %sum_part = "pto.vmi.unpack"(%sum) + : (!pto.vmi.vreg<128xbf16, #pto.vmi.layout>) + -> !pto.vreg<128xbf16> + %min_part = "pto.vmi.unpack"(%min) + : (!pto.vmi.vreg<128xbf16, #pto.vmi.layout>) + -> !pto.vreg<128xbf16> + %max_part = "pto.vmi.unpack"(%max) + : (!pto.vmi.vreg<128xbf16, 
#pto.vmi.layout>) + -> !pto.vreg<128xbf16> + return %sum_part, %min_part, %max_part + : !pto.vreg<128xbf16>, !pto.vreg<128xbf16>, !pto.vreg<128xbf16> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_bf16_arith( +// CHECK: %[[MASK:.*]] = pto.pset_b16 "PAT_ALL" : !pto.mask +// CHECK: %[[ADD:.*]] = pto.vadd %arg0, %arg1, %[[MASK]] : !pto.vreg<128xbf16>, !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<128xbf16> +// CHECK: %[[MIN:.*]] = pto.vmin %arg0, %arg1, %{{.*}} : !pto.vreg<128xbf16>, !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<128xbf16> +// CHECK: %[[MAX:.*]] = pto.vmax %arg0, %arg1, %{{.*}} : !pto.vreg<128xbf16>, !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<128xbf16> +// CHECK: return %[[ADD]], %[[MIN]], %[[MAX]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_bitcast.pto b/test/lit/vmi/vmi_to_vpto_bitcast.pto new file mode 100644 index 0000000000..f73ffbe68a --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_bitcast.pto @@ -0,0 +1,29 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_bitcast_f32_to_i16( + %value: !pto.vmi.vreg<128xf32>) -> !pto.vmi.vreg<256xi16> { + %cast = pto.vmi.bitcast %value + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<256xi16> + return %cast : !pto.vmi.vreg<256xi16> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_bitcast_f32_to_i16( +// CHECK-SAME: %[[V0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[V1:[^)]+]]: !pto.vreg<64xf32> +// CHECK-SAME: -> (!pto.vreg<128xi16>, !pto.vreg<128xi16>) +// CHECK-DAG: %[[B0:.*]] = pto.vbitcast %[[V0]] : !pto.vreg<64xf32> -> !pto.vreg<128xi16> +// CHECK-DAG: %[[B1:.*]] = pto.vbitcast %[[V1]] : !pto.vreg<64xf32> -> !pto.vreg<128xi16> +// CHECK: return %[[B0]], %[[B1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_bitcast_deint_tail.pto b/test/lit/vmi/vmi_to_vpto_bitcast_deint_tail.pto new file mode 100644 index 0000000000..fa1a5524dc --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_bitcast_deint_tail.pto @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_bitcast_deint_tail( + %value: !pto.vmi.vreg<129xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<129xi32, #pto.vmi.layout> { + %cast = pto.vmi.bitcast %value + : !pto.vmi.vreg<129xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<129xi32, #pto.vmi.layout> + return %cast : !pto.vmi.vreg<129xi32, #pto.vmi.layout> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_bitcast_deint_tail( +// CHECK-SAME: %[[V0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[V1:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[V2:[^)]+]]: !pto.vreg<64xf32> +// CHECK-SAME: -> (!pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.vreg<64xi32>) +// CHECK-DAG: %[[B0:.*]] = pto.vbitcast %[[V0]] : !pto.vreg<64xf32> -> !pto.vreg<64xi32> +// CHECK-DAG: %[[B1:.*]] = pto.vbitcast %[[V1]] : !pto.vreg<64xf32> -> !pto.vreg<64xi32> +// CHECK-DAG: %[[B2:.*]] = pto.vbitcast %[[V2]] : !pto.vreg<64xf32> -> !pto.vreg<64xi32> +// CHECK: return %[[B0]], %[[B1]], %[[B2]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_bitcast_footprint_invalid.pto b/test/lit/vmi/vmi_to_vpto_bitcast_footprint_invalid.pto new file mode 100644 index 0000000000..2d7b904af1 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_bitcast_footprint_invalid.pto @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_bitcast_footprint_invalid( + %source: !pto.vmi.vreg<65xf32, #pto.vmi.layout>) { + %out = pto.vmi.bitcast %source + : !pto.vmi.vreg<65xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<130xi16, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.bitcast requires matching source/result layouts +// CHECK-SAME: identical physical arity and matching per-chunk logical bit footprints +// CHECK-SAME: requires matching logical bit footprint in every physical chunk diff --git a/test/lit/vmi/vmi_to_vpto_bitcast_group_slots.pto b/test/lit/vmi/vmi_to_vpto_bitcast_group_slots.pto new file mode 100644 index 0000000000..e9ccc14ac2 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_bitcast_group_slots.pto @@ -0,0 +1,29 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_bitcast_group_slots( + %source: !pto.vmi.vreg<8xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<8xi32, #pto.vmi.layout> { + %out = pto.vmi.bitcast %source + : !pto.vmi.vreg<8xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<8xi32, #pto.vmi.layout> + return %out : !pto.vmi.vreg<8xi32, #pto.vmi.layout> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_bitcast_group_slots( +// CHECK-SAME: %[[V0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: -> !pto.vreg<64xi32> +// CHECK: %[[B0:.*]] = pto.vbitcast %[[V0]] : !pto.vreg<64xf32> -> !pto.vreg<64xi32> +// CHECK: return %[[B0]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_bitcast_partial.pto b/test/lit/vmi/vmi_to_vpto_bitcast_partial.pto new file mode 100644 index 0000000000..e2a1b3c789 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_bitcast_partial.pto @@ -0,0 +1,29 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_bitcast_partial( + %value: !pto.vmi.vreg<65xf32>) -> !pto.vmi.vreg<130xi16> { + %cast = pto.vmi.bitcast %value + : !pto.vmi.vreg<65xf32> -> !pto.vmi.vreg<130xi16> + return %cast : !pto.vmi.vreg<130xi16> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_bitcast_partial( +// CHECK-SAME: %[[S0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[S1:[^)]+]]: !pto.vreg<64xf32> +// CHECK-SAME: -> (!pto.vreg<128xi16>, !pto.vreg<128xi16>) +// CHECK-DAG: %[[B0:.*]] = pto.vbitcast %[[S0]] : !pto.vreg<64xf32> -> !pto.vreg<128xi16> +// CHECK-DAG: %[[B1:.*]] = pto.vbitcast %[[S1]] : !pto.vreg<64xf32> -> !pto.vreg<128xi16> +// CHECK: return %[[B0]], %[[B1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_bitwise.pto b/test/lit/vmi/vmi_to_vpto_bitwise.pto new file mode 100644 index 0000000000..80a665ccd9 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_bitwise.pto @@ -0,0 +1,53 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_bitwise( + %a: !pto.vmi.vreg<256xi16>, + %b: !pto.vmi.vreg<256xi16>) + -> (!pto.vmi.vreg<256xi16>, !pto.vmi.vreg<256xi16>, + !pto.vmi.vreg<256xi16>, + !pto.vmi.vreg<256xi16>) { + %and = pto.vmi.andi %a, %b + : !pto.vmi.vreg<256xi16>, !pto.vmi.vreg<256xi16> + -> !pto.vmi.vreg<256xi16> + %or = pto.vmi.ori %a, %b + : !pto.vmi.vreg<256xi16>, !pto.vmi.vreg<256xi16> + -> !pto.vmi.vreg<256xi16> + %xor = pto.vmi.xori %a, %b + : !pto.vmi.vreg<256xi16>, !pto.vmi.vreg<256xi16> + -> !pto.vmi.vreg<256xi16> + %not = pto.vmi.not %a + : !pto.vmi.vreg<256xi16> -> !pto.vmi.vreg<256xi16> + return %and, %or, %xor, %not + : !pto.vmi.vreg<256xi16>, !pto.vmi.vreg<256xi16>, + !pto.vmi.vreg<256xi16>, + !pto.vmi.vreg<256xi16> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_bitwise( +// CHECK-SAME: %[[A0:[^,]+]]: !pto.vreg<128xi16> +// CHECK-SAME: %[[A1:[^,]+]]: !pto.vreg<128xi16> +// CHECK-SAME: %[[B0:[^,]+]]: !pto.vreg<128xi16> +// CHECK-SAME: %[[B1:[^)]+]]: !pto.vreg<128xi16> +// CHECK-SAME: -> (!pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.vreg<128xi16>) +// CHECK-DAG: %[[AND0:.*]] = pto.vand %[[A0]], %[[B0]], {{.*}} : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> +// CHECK-DAG: %[[AND1:.*]] = pto.vand %[[A1]], %[[B1]], {{.*}} : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> +// CHECK-DAG: %[[OR0:.*]] = pto.vor %[[A0]], %[[B0]], {{.*}} : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> +// CHECK-DAG: %[[OR1:.*]] = pto.vor %[[A1]], %[[B1]], {{.*}} : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> +// CHECK-DAG: %[[XOR0:.*]] = pto.vxor %[[A0]], %[[B0]], {{.*}} : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> +// CHECK-DAG: %[[XOR1:.*]] = pto.vxor %[[A1]], 
%[[B1]], {{.*}} : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> +// CHECK-DAG: %[[NOT0:.*]] = pto.vnot %[[A0]], {{.*}} : !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> +// CHECK-DAG: %[[NOT1:.*]] = pto.vnot %[[A1]], {{.*}} : !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> +// CHECK: return %[[AND0]], %[[AND1]], %[[OR0]], %[[OR1]], %[[XOR0]], %[[XOR1]], %[[NOT0]], %[[NOT1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_broadcast.pto b/test/lit/vmi/vmi_to_vpto_broadcast.pto new file mode 100644 index 0000000000..9cdbf92e1e --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_broadcast.pto @@ -0,0 +1,69 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_broadcast_contiguous(%scalar: f32) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %value = pto.vmi.broadcast %scalar + : f32 -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + + func.func @vmi_to_vpto_broadcast_deint4(%scalar: f32) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %value = pto.vmi.broadcast %scalar + : f32 -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + %p0, %p1, %p2, %p3 = "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<256xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1, %p2, %p3 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + + func.func @vmi_to_vpto_broadcast_rank0( + %scalar: !pto.vmi.vreg<1xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %value = pto.vmi.broadcast %scalar + : !pto.vmi.vreg<1xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + %p0, %p1, %p2, %p3 = "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<256xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1, %p2, %p3 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_broadcast_contiguous( +// CHECK: %[[M0:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: %[[P0:.*]] = pto.vdup %arg0, %[[M0]] : f32, !pto.mask -> !pto.vreg<64xf32> +// CHECK: %[[M1:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: %[[P1:.*]] = pto.vdup %arg0, %[[M1]] : f32, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[P0]], %[[P1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_broadcast_deint4( +// CHECK-COUNT-4: pto.vdup %arg0 +// CHECK: return +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_broadcast_rank0( +// CHECK-COUNT-4: pto.vdup %arg0{{.*}}{position = "LOWEST"} +// CHECK: return +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_call_boundary.pto b/test/lit/vmi/vmi_to_vpto_call_boundary.pto new file mode 100644 index 0000000000..0a34ebe197 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_call_boundary.pto @@ -0,0 +1,52 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func private @callee(%x: !pto.vmi.vreg<128xf32>) + -> !pto.vmi.vreg<128xf32> { + %sum = pto.vmi.addf %x, %x + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return %sum : !pto.vmi.vreg<128xf32> + } + + func.func @caller(%a: !pto.vmi.vreg<128xf16>) + -> !pto.vmi.vreg<128xf32> { + %ea = pto.vmi.extf %a + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %r = call @callee(%ea) + : (!pto.vmi.vreg<128xf32>) -> !pto.vmi.vreg<128xf32> + %sum = pto.vmi.addf %r, %r + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return %sum : !pto.vmi.vreg<128xf32> + } +} + +// CHECK-LABEL: func.func private @callee( +// CHECK-SAME: %[[C0:[^:]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[C1:[^)]+]]: !pto.vreg<64xf32> +// CHECK-SAME: -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) +// CHECK-DAG: %[[CM0:.*]] = pto.vadd %[[C0]], %[[C0]] +// CHECK-DAG: %[[CM1:.*]] = pto.vadd %[[C1]], %[[C1]] +// CHECK: return %[[CM0]], %[[CM1]] : !pto.vreg<64xf32>, !pto.vreg<64xf32> + +// CHECK-LABEL: func.func @caller( +// CHECK-SAME: %[[A:[^)]+]]: !pto.vreg<128xf16> +// CHECK-DAG: %[[EA0:.*]] = pto.vcvt %[[A]] +// CHECK-DAG: %[[EA1:.*]] = pto.vcvt %[[A]] +// CHECK: %[[R:.*]]:2 = call @callee(%[[EA0]], %[[EA1]]) +// CHECK-SAME: (!pto.vreg<64xf32>, !pto.vreg<64xf32>) -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) +// CHECK-DAG: %[[S0:.*]] = pto.vadd %[[R]]#0, %[[R]]#0 +// CHECK-DAG: %[[S1:.*]] = pto.vadd %[[R]]#1, %[[R]]#1 +// CHECK: return %[[S0]], %[[S1]] : !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_cf_branch.pto b/test/lit/vmi/vmi_to_vpto_cf_branch.pto new file mode 100644 index 0000000000..0a4cf70e1d --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_cf_branch.pto @@ -0,0 +1,78 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_cf_branch( + %cond: i1, + %a: !pto.vmi.vreg<128xf16>, + %b: !pto.vmi.vreg<128xf16>) -> !pto.vmi.vreg<128xf32> { + cf.cond_br %cond, ^then(%a : !pto.vmi.vreg<128xf16>), + ^else(%b : !pto.vmi.vreg<128xf16>) + + ^then(%then_arg: !pto.vmi.vreg<128xf16>): + %then_value = pto.vmi.extf %then_arg + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + cf.br ^join(%then_value : !pto.vmi.vreg<128xf32>) + + ^else(%else_arg: !pto.vmi.vreg<128xf16>): + %else_value = pto.vmi.extf %else_arg + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %else_sum = pto.vmi.addf %else_value, %else_value + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + cf.br ^join(%else_sum : !pto.vmi.vreg<128xf32>) + + ^join(%value: !pto.vmi.vreg<128xf32>): + %sum = pto.vmi.addf %value, %value + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return %sum : !pto.vmi.vreg<128xf32> + } + + func.func @vmi_to_vpto_cf_cond_branch_operands( + %cond: i1, + %a: !pto.vmi.vreg<128xf32>, + %b: !pto.vmi.vreg<128xf32>) -> !pto.vmi.vreg<128xf32> { + cf.cond_br %cond, ^join(%a : !pto.vmi.vreg<128xf32>), + ^join(%b : !pto.vmi.vreg<128xf32>) + + ^join(%value: !pto.vmi.vreg<128xf32>): + return %value : !pto.vmi.vreg<128xf32> + } +} + +// CHECK-LABEL: func.func 
@vmi_to_vpto_cf_branch( +// CHECK-SAME: %[[COND:[^,]+]]: i1 +// CHECK-SAME: %[[A:[^,]+]]: !pto.vreg<128xf16> +// CHECK-SAME: %[[B:[^)]+]]: !pto.vreg<128xf16> +// CHECK-SAME: -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) +// CHECK: cf.cond_br %[[COND]], ^[[THEN:.*]], ^[[ELSE:.*]] +// CHECK: ^[[THEN]]: +// CHECK-DAG: %[[THEN_P0:.*]] = pto.vcvt %[[A]] +// CHECK-DAG: %[[THEN_P1:.*]] = pto.vcvt %[[A]] +// CHECK: cf.br ^[[JOIN:.*]](%[[THEN_P0]], %[[THEN_P1]] : !pto.vreg<64xf32>, !pto.vreg<64xf32>) +// CHECK: ^[[ELSE]]: +// CHECK-DAG: %[[ELSE_P0:.*]] = pto.vcvt %[[B]] +// CHECK-DAG: %[[ELSE_P1:.*]] = pto.vcvt %[[B]] +// CHECK: cf.br ^[[JOIN]]({{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32>) +// CHECK: ^[[JOIN]](%{{.*}}: !pto.vreg<64xf32>, %{{.*}}: !pto.vreg<64xf32>): +// CHECK: pto.vadd +// CHECK-LABEL: func.func @vmi_to_vpto_cf_cond_branch_operands( +// CHECK-SAME: %[[COND2:[^,]+]]: i1 +// CHECK-SAME: %[[A0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[A1:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[B0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[B1:[^)]+]]: !pto.vreg<64xf32> +// CHECK: cf.cond_br %[[COND2]], ^[[CB_JOIN:.*]](%[[A0]], %[[A1]] : !pto.vreg<64xf32>, !pto.vreg<64xf32>), ^[[CB_JOIN]](%[[B0]], %[[B1]] : !pto.vreg<64xf32>, !pto.vreg<64xf32>) +// CHECK: ^[[CB_JOIN]](%{{.*}}: !pto.vreg<64xf32>, %{{.*}}: !pto.vreg<64xf32>): +// CHECK: return {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_channel_merge4_contiguous.pto b/test/lit/vmi/vmi_to_vpto_channel_merge4_contiguous.pto new file mode 100644 index 0000000000..4ffb8e384d --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_channel_merge4_contiguous.pto @@ -0,0 +1,40 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). 
+// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_channel_merge4_contiguous( + %ch0: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %ch1: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %ch2: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %ch3: !pto.vmi.vreg<64xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> { + %merged = "pto.vmi.channel_merge"(%ch0, %ch1, %ch2, %ch3) + : (!pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + return %merged : !pto.vmi.vreg<256xf32, #pto.vmi.layout> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_channel_merge4_contiguous( +// CHECK-SAME: %[[P0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[P1:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[P2:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[P3:[^)]+]]: !pto.vreg<64xf32> +// CHECK: %[[E0:.*]], %[[E1:.*]] = pto.vintlv %[[P0]], %[[P2]] +// CHECK: %[[O0:.*]], %[[O1:.*]] = pto.vintlv %[[P1]], %[[P3]] +// CHECK: %[[L0:.*]], %[[L1:.*]] = pto.vintlv %[[E0]], %[[O0]] +// CHECK: %[[H0:.*]], %[[H1:.*]] = pto.vintlv %[[E1]], %[[O1]] +// CHECK: return %[[L0]], %[[L1]], %[[H0]], %[[H1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_channel_merge_count_unsupported_invalid.pto b/test/lit/vmi/vmi_to_vpto_channel_merge_count_unsupported_invalid.pto new file mode 100644 index 0000000000..8bdc2beb6a --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_channel_merge_count_unsupported_invalid.pto @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_channel_merge_count_unsupported_invalid( + %ch0: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %ch1: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %ch2: !pto.vmi.vreg<64xf32, #pto.vmi.layout>) { + %merged = "pto.vmi.channel_merge"(%ch0, %ch1, %ch2) + : (!pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<192xf32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.channel_merge supports only 2 or 4 channels diff --git a/test/lit/vmi/vmi_to_vpto_channel_merge_layout_invalid.pto b/test/lit/vmi/vmi_to_vpto_channel_merge_layout_invalid.pto new file mode 100644 index 0000000000..867cbdce65 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_channel_merge_layout_invalid.pto @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_channel_merge_layout_invalid( + %ch0: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %ch1: !pto.vmi.vreg<64xf32, #pto.vmi.layout>) { + %merged = "pto.vmi.channel_merge"(%ch0, %ch1) + : (!pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + return + } +} + +// CHECK: 'pto.vmi.channel_merge' op requires layout-assigned channel_merge inputs to be contiguous diff --git a/test/lit/vmi/vmi_to_vpto_channel_merge_partial_group_invalid.pto b/test/lit/vmi/vmi_to_vpto_channel_merge_partial_group_invalid.pto new file mode 100644 index 0000000000..443a5fedae --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_channel_merge_partial_group_invalid.pto @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_channel_merge_partial_group_invalid( + %ch0: !pto.vmi.vreg<2xf32, #pto.vmi.layout>, + %ch1: !pto.vmi.vreg<2xf32, #pto.vmi.layout>) { + %merged = "pto.vmi.channel_merge"(%ch0, %ch1) + : (!pto.vmi.vreg<2xf32, #pto.vmi.layout>, + !pto.vmi.vreg<2xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<4xf32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.channel_merge requires every input layout to be contiguous +// CHECK-SAME: complete physical channel groups +// CHECK-SAME: requires source and result to have the same physical arity diff --git a/test/lit/vmi/vmi_to_vpto_channel_split_count_unsupported_invalid.pto b/test/lit/vmi/vmi_to_vpto_channel_split_count_unsupported_invalid.pto new file mode 100644 index 0000000000..1bc963d400 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_channel_split_count_unsupported_invalid.pto @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_channel_split_count_unsupported_invalid( + %src: !pto.vmi.vreg<192xf32, #pto.vmi.layout>) { + %ch0, %ch1, %ch2 = "pto.vmi.channel_split"(%src) + : (!pto.vmi.vreg<192xf32, #pto.vmi.layout>) + -> (!pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout>) + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.channel_split supports only 2 or 4 channels diff --git a/test/lit/vmi/vmi_to_vpto_channel_split_layout_invalid.pto b/test/lit/vmi/vmi_to_vpto_channel_split_layout_invalid.pto new file mode 100644 index 0000000000..55c9ea862e --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_channel_split_layout_invalid.pto @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_channel_split_layout_invalid( + %src: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) { + %ch0, %ch1 = "pto.vmi.channel_split"(%src) + : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout>) + return + } +} + +// CHECK: 'pto.vmi.channel_split' op requires layout-assigned channel_split source to be contiguous or deinterleaved by result count diff --git a/test/lit/vmi/vmi_to_vpto_channel_split_merge.pto b/test/lit/vmi/vmi_to_vpto_channel_split_merge.pto new file mode 100644 index 0000000000..10d90d2869 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_channel_split_merge.pto @@ -0,0 +1,95 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_channel_split_merge2( + %src: !pto.vmi.vreg<128xf32>) -> !pto.vmi.vreg<128xf32> { + %ch0, %ch1 = "pto.vmi.channel_split"(%src) + : (!pto.vmi.vreg<128xf32>) + -> (!pto.vmi.vreg<64xf32>, !pto.vmi.vreg<64xf32>) + %merged = "pto.vmi.channel_merge"(%ch0, %ch1) + : (!pto.vmi.vreg<64xf32>, !pto.vmi.vreg<64xf32>) + -> !pto.vmi.vreg<128xf32> + return %merged : !pto.vmi.vreg<128xf32> + } + + func.func @vmi_channel_split4( + %src: !pto.vmi.vreg<256xf32, #pto.vmi.layout>) + -> (!pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout>) { + %ch0, %ch1, %ch2, %ch3 = "pto.vmi.channel_split"(%src) + : (!pto.vmi.vreg<256xf32, #pto.vmi.layout>) + -> (!pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout>) + return %ch0, %ch1, %ch2, %ch3 + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout> + } + + func.func @vmi_channel_split_deint2_identity( + %src: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout>) { + %ch0, %ch1 = "pto.vmi.channel_split"(%src) + : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout>) + return %ch0, %ch1 + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout> + } + + func.func @vmi_channel_merge_deint2_identity( + %ch0: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %ch1: !pto.vmi.vreg<64xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> { + %merged = "pto.vmi.channel_merge"(%ch0, %ch1) + : (!pto.vmi.vreg<64xf32, #pto.vmi.layout>, + 
!pto.vmi.vreg<64xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + return %merged : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + } +} + +// CHECK-LABEL: func.func @vmi_channel_split_merge2( +// CHECK-SAME: %[[D0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[D1:[^)]+]]: !pto.vreg<64xf32> +// CHECK: %[[CH0:.*]], %[[CH1:.*]] = pto.vdintlv %[[D0]], %[[D1]] +// CHECK: return %[[CH0]], %[[CH1]] : !pto.vreg<64xf32>, !pto.vreg<64xf32> + +// CHECK-LABEL: func.func @vmi_channel_split4( +// CHECK-SAME: %[[S0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[S1:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[S2:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[S3:[^)]+]]: !pto.vreg<64xf32> +// CHECK: %[[A0:.*]], %[[A1:.*]] = pto.vdintlv %[[S0]], %[[S1]] +// CHECK: %[[B0:.*]], %[[B1:.*]] = pto.vdintlv %[[S2]], %[[S3]] +// CHECK: %[[C0:.*]], %[[C2:.*]] = pto.vdintlv %[[A0]], %[[B0]] +// CHECK: %[[C1:.*]], %[[C3:.*]] = pto.vdintlv %[[A1]], %[[B1]] +// CHECK: return %[[C0]], %[[C1]], %[[C2]], %[[C3]] +// CHECK-LABEL: func.func @vmi_channel_split_deint2_identity( +// CHECK-SAME: %[[P0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[P1:[^)]+]]: !pto.vreg<64xf32> +// CHECK-NOT: pto.vdintlv +// CHECK: return %[[P0]], %[[P1]] : !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK-LABEL: func.func @vmi_channel_merge_deint2_identity( +// CHECK-SAME: %[[M0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[M1:[^)]+]]: !pto.vreg<64xf32> +// CHECK-NOT: pto.vintlv +// CHECK: return %[[M0]], %[[M1]] : !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_channel_split_merge_tail.pto b/test/lit/vmi/vmi_to_vpto_channel_split_merge_tail.pto new file mode 100644 index 0000000000..25afa0d016 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_channel_split_merge_tail.pto @@ -0,0 +1,35 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_channel_split_merge2_tail( + %src: !pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<100xf32, #pto.vmi.layout> { + %ch0, %ch1 = "pto.vmi.channel_split"(%src) + : (!pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> (!pto.vmi.vreg<50xf32, #pto.vmi.layout>, + !pto.vmi.vreg<50xf32, #pto.vmi.layout>) + %merged = "pto.vmi.channel_merge"(%ch0, %ch1) + : (!pto.vmi.vreg<50xf32, #pto.vmi.layout>, + !pto.vmi.vreg<50xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<100xf32, #pto.vmi.layout> + return %merged : !pto.vmi.vreg<100xf32, #pto.vmi.layout> + } +} + +// CHECK-LABEL: func.func @vmi_channel_split_merge2_tail( +// CHECK-SAME: %[[S0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[S1:[^)]+]]: !pto.vreg<64xf32> +// CHECK: %[[CH0:.*]], %[[CH1:.*]] = pto.vdintlv %[[S0]], %[[S1]] +// CHECK: %[[D0:.*]], %[[D1:.*]] = pto.vintlv %[[CH0]], %[[CH1]] +// CHECK: return %[[D0]], %[[D1]] : !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_channel_split_partial_group_invalid.pto b/test/lit/vmi/vmi_to_vpto_channel_split_partial_group_invalid.pto new file mode 100644 index 0000000000..f45b7cdfda --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_channel_split_partial_group_invalid.pto @@ -0,0 +1,24 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_channel_split_partial_group_invalid( + %src: !pto.vmi.vreg<4xf32, #pto.vmi.layout>) { + %ch0, %ch1 = "pto.vmi.channel_split"(%src) + : (!pto.vmi.vreg<4xf32, #pto.vmi.layout>) + -> (!pto.vmi.vreg<2xf32, #pto.vmi.layout>, + !pto.vmi.vreg<2xf32, #pto.vmi.layout>) + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.channel_split requires source layout to be contiguous or matching deinterleaved channel layout +// CHECK-SAME: complete physical channel groups +// CHECK-SAME: requires source and result to have the same physical arity diff --git a/test/lit/vmi/vmi_to_vpto_chist_semantics_invalid.pto b/test/lit/vmi/vmi_to_vpto_chist_semantics_invalid.pto new file mode 100644 index 0000000000..1049cbdf2e --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_chist_semantics_invalid.pto @@ -0,0 +1,27 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_chist_semantics_invalid( + %acc: !pto.vmi.vreg<256xui16, #pto.vmi.layout>, + %source: !pto.vmi.vreg<256xui8, #pto.vmi.layout>, + %mask: !pto.vmi.mask<256xb8, #pto.vmi.layout>) { + %hist = pto.vmi.chist %acc, %source, %mask + : !pto.vmi.vreg<256xui16, #pto.vmi.layout>, + !pto.vmi.vreg<256xui8, #pto.vmi.layout>, + !pto.vmi.mask<256xb8, #pto.vmi.layout> + -> !pto.vmi.vreg<256xui16, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUP{{.*}} pto.vmi.chist requires a verified CHISTv2 range semantics contract before lowering diff --git a/test/lit/vmi/vmi_to_vpto_cmp_element_type_invalid.pto b/test/lit/vmi/vmi_to_vpto_cmp_element_type_invalid.pto new file mode 100644 index 0000000000..100f4b7378 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_cmp_element_type_invalid.pto @@ -0,0 +1,24 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_cmpf_f8_invalid( + %lhs: !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout>, + %rhs: !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout>) { + %mask = pto.vmi.cmpf "lt", %lhs, %rhs + : !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout>, + !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> + -> !pto.vmi.mask<256xb8, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.cmpf direct lowering requires f16/bf16/f32 element type +// CHECK-SAME: requires f16/bf16/f32 element type for direct VPTO lowering diff --git a/test/lit/vmi/vmi_to_vpto_cmp_predicate_unsupported_invalid.pto b/test/lit/vmi/vmi_to_vpto_cmp_predicate_unsupported_invalid.pto new file mode 100644 index 0000000000..8689bc8312 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_cmp_predicate_unsupported_invalid.pto @@ -0,0 +1,28 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_cmp_predicate_unsupported_invalid( + %a: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %b: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) { + %mask = pto.vmi.cmpf "uno", %a, %b + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %m0, %m1 = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %m0, %m1 : !pto.mask, !pto.mask + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} compare predicate uno cannot be lowered to pto.vcmp +// CHECK-SAME: supported predicates are eq/ne/lt/le/gt/ge, ordered FP forms diff --git a/test/lit/vmi/vmi_to_vpto_cmp_select.pto b/test/lit/vmi/vmi_to_vpto_cmp_select.pto new file mode 100644 index 0000000000..816913c8b2 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_cmp_select.pto @@ -0,0 +1,140 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_cmpf_select( + %a: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %b: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask, !pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %mask = pto.vmi.cmpf "lt", %a, %b + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %selected = pto.vmi.select %mask, %a, %b + : !pto.vmi.mask<128xb32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %m0, %m1 = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + %p0, %p1 = "pto.vmi.unpack"(%selected) + : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %m0, %m1, %p0, %p1 + : !pto.mask, !pto.mask, !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + + func.func @vmi_to_vpto_cmpi( + %a: !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + %b: !pto.vmi.vreg<128xi32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) { + %mask = pto.vmi.cmpi "ge", %a, %b + : !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + !pto.vmi.vreg<128xi32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %m0, %m1 = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %m0, %m1 : !pto.mask, !pto.mask + } + + func.func @vmi_to_vpto_cmpf_ordered_predicate( + %a: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %b: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) { + %mask = pto.vmi.cmpf "olt", %a, %b + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %m0, %m1 = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %m0, %m1 : !pto.mask, !pto.mask + } + + func.func 
@vmi_to_vpto_cmpi_signed_predicate( + %a: !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + %b: !pto.vmi.vreg<128xi32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) { + %mask = pto.vmi.cmpi "slt", %a, %b + : !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + !pto.vmi.vreg<128xi32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %m0, %m1 = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %m0, %m1 : !pto.mask, !pto.mask + } + + func.func @vmi_to_vpto_cmpf_bf16( + %a: !pto.vmi.vreg<128xbf16, #pto.vmi.layout>, + %b: !pto.vmi.vreg<128xbf16, #pto.vmi.layout>) + -> !pto.mask { + %mask = pto.vmi.cmpf "oge", %a, %b + : !pto.vmi.vreg<128xbf16, #pto.vmi.layout>, + !pto.vmi.vreg<128xbf16, #pto.vmi.layout> + -> !pto.vmi.mask<128xb16, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<128xb16, #pto.vmi.layout>) + -> !pto.mask + return %part : !pto.mask + } + + func.func @vmi_to_vpto_cmpi_ui16( + %a: !pto.vmi.vreg<128xui16, #pto.vmi.layout>, + %b: !pto.vmi.vreg<128xui16, #pto.vmi.layout>) + -> !pto.mask { + %mask = pto.vmi.cmpi "eq", %a, %b + : !pto.vmi.vreg<128xui16, #pto.vmi.layout>, + !pto.vmi.vreg<128xui16, #pto.vmi.layout> + -> !pto.vmi.mask<128xb16, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<128xb16, #pto.vmi.layout>) + -> !pto.mask + return %part : !pto.mask + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_cmpf_select( +// CHECK: %[[FM0:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: %[[CM0:.*]] = pto.vcmp {{.*}}, {{.*}}, %[[FM0]], "lt" +// CHECK-SAME: !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +// CHECK: %[[FM1:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: %[[CM1:.*]] = pto.vcmp {{.*}}, {{.*}}, %[[FM1]], "lt" +// CHECK-SAME: !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +// CHECK: pto.vsel {{.*}}, {{.*}}, %[[CM0]] +// CHECK-SAME: !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vsel {{.*}}, 
{{.*}}, %[[CM1]] +// CHECK-SAME: !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-LABEL: func.func @vmi_to_vpto_cmpi( +// CHECK: pto.vcmp {{.*}}, {{.*}}, {{.*}}, "ge" +// CHECK-SAME: !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.mask +// CHECK: pto.vcmp {{.*}}, {{.*}}, {{.*}}, "ge" +// CHECK-SAME: !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.mask +// CHECK-LABEL: func.func @vmi_to_vpto_cmpf_ordered_predicate( +// CHECK: pto.vcmp {{.*}}, {{.*}}, {{.*}}, "lt" +// CHECK-SAME: !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +// CHECK: pto.vcmp {{.*}}, {{.*}}, {{.*}}, "lt" +// CHECK-SAME: !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +// CHECK-LABEL: func.func @vmi_to_vpto_cmpi_signed_predicate( +// CHECK: pto.vcmp {{.*}}, {{.*}}, {{.*}}, "lt" +// CHECK-SAME: !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.mask +// CHECK: pto.vcmp {{.*}}, {{.*}}, {{.*}}, "lt" +// CHECK-SAME: !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.mask +// CHECK-LABEL: func.func @vmi_to_vpto_cmpf_bf16( +// CHECK: pto.vcmp {{.*}}, {{.*}}, {{.*}}, "ge" +// CHECK-SAME: !pto.vreg<128xbf16>, !pto.vreg<128xbf16>, !pto.mask -> !pto.mask +// CHECK-LABEL: func.func @vmi_to_vpto_cmpi_ui16( +// CHECK: pto.vcmp {{.*}}, {{.*}}, {{.*}}, "eq" +// CHECK-SAME: !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.mask +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_cmpi_unsigned_predicate_unsupported_invalid.pto b/test/lit/vmi/vmi_to_vpto_cmpi_unsigned_predicate_unsupported_invalid.pto new file mode 100644 index 0000000000..23b1e7f88f --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_cmpi_unsigned_predicate_unsupported_invalid.pto @@ -0,0 +1,28 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_cmpi_unsigned_predicate_unsupported_invalid( + %a: !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + %b: !pto.vmi.vreg<128xi32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) { + %mask = pto.vmi.cmpi "ult", %a, %b + : !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + !pto.vmi.vreg<128xi32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %m0, %m1 = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %m0, %m1 : !pto.mask, !pto.mask + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} compare predicate ult cannot be lowered to pto.vcmp +// CHECK-SAME: signed integer forms slt/sle/sgt/sge diff --git a/test/lit/vmi/vmi_to_vpto_compaction_deint_invalid.pto b/test/lit/vmi/vmi_to_vpto_compaction_deint_invalid.pto new file mode 100644 index 0000000000..b4b2af9879 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_compaction_deint_invalid.pto @@ -0,0 +1,58 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -split-input-file -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_active_prefix_index_deint_invalid( + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) { + %idx = pto.vmi.active_prefix_index %mask + : !pto.vmi.mask<64xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<64xi32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.active_prefix_index lowers through pto.vusqz only for one contiguous physical chunk +// CHECK-SAME: requires contiguous mask and result layouts + +// ----- + +module { + func.func @vmi_to_vpto_compress_deint_invalid( + %source: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) { + %out = pto.vmi.compress %source, %mask + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.mask<64xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.compress lowers through pto.vsqz only for one contiguous full physical chunk +// CHECK-SAME: requires contiguous source, mask, and result layouts + +// ----- + +module { + func.func @vmi_to_vpto_compress_store_deint_invalid( + %value: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %dst: !pto.ptr, + %offset: index, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) { + pto.vmi.compress_store %value, %dst[%offset], %mask + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.ptr, + !pto.vmi.mask<64xb32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.compress_store lowers through pto.vsqz + pto.vstur only for one contiguous full physical chunk +// CHECK-SAME: requires contiguous value and mask layouts diff --git 
a/test/lit/vmi/vmi_to_vpto_compress.pto b/test/lit/vmi/vmi_to_vpto_compress.pto new file mode 100644 index 0000000000..aba4da0228 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_compress.pto @@ -0,0 +1,32 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_compress( + %src: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> { + %out = pto.vmi.compress %src, %mask + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.mask<64xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<64xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> + return %part : !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_compress( +// CHECK: %[[OUT:.*]] = pto.vsqz %arg0, %arg1 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[OUT]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_compress_multichunk_invalid.pto b/test/lit/vmi/vmi_to_vpto_compress_multichunk_invalid.pto new file mode 100644 index 0000000000..3122bbb0ee --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_compress_multichunk_invalid.pto @@ -0,0 +1,28 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_compress_multichunk_invalid( + %src: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %out = pto.vmi.compress %src, %mask + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.compress lowers through pto.vsqz only for one contiguous full physical chunk +// CHECK-SAME: multi-chunk compress needs cross-chunk compaction diff --git a/test/lit/vmi/vmi_to_vpto_compress_store.pto b/test/lit/vmi/vmi_to_vpto_compress_store.pto new file mode 100644 index 0000000000..edf8565c5f --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_compress_store.pto @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_compress_store( + %value: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %dst: !pto.ptr, + %offset: index, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) { + pto.vmi.compress_store %value, %dst[%offset], %mask + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.ptr, + !pto.vmi.mask<64xb32, #pto.vmi.layout> + return + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_compress_store( +// CHECK: %[[BASE:.*]] = pto.addptr %arg1, %arg2 +// CHECK: %[[SQZ:.*]] = pto.vsqz %arg0, %arg3 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: %[[ALIGN0:.*]] = pto.init_align : !pto.align +// CHECK: %[[ALIGN1:.*]] = pto.vstur %[[ALIGN0]], %[[SQZ]], %[[BASE]], "POST_UPDATE" : !pto.align, !pto.vreg<64xf32>, !pto.ptr -> !pto.align +// CHECK: pto.vstar %[[ALIGN1]], %[[BASE]] : !pto.align, !pto.ptr +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_compress_store_multichunk_invalid.pto b/test/lit/vmi/vmi_to_vpto_compress_store_multichunk_invalid.pto new file mode 100644 index 0000000000..e4fc4738cc --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_compress_store_multichunk_invalid.pto @@ -0,0 +1,26 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_compress_store_multichunk_invalid( + %value: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %dst: !pto.ptr, + %offset: index, + %mask: !pto.vmi.mask<128xb32, #pto.vmi.layout>) { + pto.vmi.compress_store %value, %dst[%offset], %mask + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.ptr, + !pto.vmi.mask<128xb32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.compress_store lowers through pto.vsqz + pto.vstur only for one contiguous full physical chunk +// CHECK-SAME: multi-chunk compress_store needs cross-chunk compaction diff --git a/test/lit/vmi/vmi_to_vpto_compress_tail_invalid.pto b/test/lit/vmi/vmi_to_vpto_compress_tail_invalid.pto new file mode 100644 index 0000000000..4d97cf831d --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_compress_tail_invalid.pto @@ -0,0 +1,28 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_compress_tail_invalid( + %src: !pto.vmi.vreg<4xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<4xb32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> { + %out = pto.vmi.compress %src, %mask + : !pto.vmi.vreg<4xf32, #pto.vmi.layout>, + !pto.vmi.mask<4xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<4xf32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<4xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> + return %part : !pto.vreg<64xf32> + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.compress lowers through pto.vsqz only for one contiguous full physical chunk +// CHECK-SAME: padding mask lanes cannot be squeezed into the result diff --git a/test/lit/vmi/vmi_to_vpto_constant.pto b/test/lit/vmi/vmi_to_vpto_constant.pto new file mode 100644 index 0000000000..c5c93bf2db --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_constant.pto @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_constant_splat() + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %value = "pto.vmi.constant"() { + value = dense<1.000000e+00> : tensor<128xf32> + } : () -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_constant_splat +// CHECK: %[[CST:.*]] = arith.constant 1.000000e+00 : f32 +// CHECK: %[[M0:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: %[[P0:.*]] = pto.vdup %[[CST]], %[[M0]] : f32, !pto.mask -> !pto.vreg<64xf32> +// CHECK: %[[M1:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: %[[P1:.*]] = pto.vdup %[[CST]], %[[M1]] : f32, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[P0]], %[[P1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_constant_mask.pto b/test/lit/vmi/vmi_to_vpto_constant_mask.pto new file mode 100644 index 0000000000..9c38a62148 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_constant_mask.pto @@ -0,0 +1,128 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_constant_mask_all_true() + -> (!pto.mask, !pto.mask) { + %mask = "pto.vmi.constant_mask"() { + value = dense : tensor<128xi1> + } : () -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %p0, %p1 : !pto.mask, !pto.mask + } + + func.func @vmi_to_vpto_constant_mask_all_false() + -> (!pto.mask, !pto.mask) { + %mask = "pto.vmi.constant_mask"() { + value = dense : tensor<128xi1> + } : () -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %p0, %p1 : !pto.mask, !pto.mask + } + + func.func @vmi_to_vpto_constant_mask_b8_all_true() + -> (!pto.mask, !pto.mask) { + %mask = "pto.vmi.constant_mask"() { + value = dense : tensor<512xi1> + } : () -> !pto.vmi.mask<512xb8, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<512xb8, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %p0, %p1 : !pto.mask, !pto.mask + } + + func.func @vmi_to_vpto_constant_mask_b16_all_false() + -> (!pto.mask, !pto.mask) { + %mask = "pto.vmi.constant_mask"() { + value = dense : tensor<256xi1> + } : () -> !pto.vmi.mask<256xb16, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<256xb16, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %p0, %p1 : !pto.mask, !pto.mask + } + + func.func @vmi_to_vpto_constant_mask_plt_fallback() + -> !pto.mask { + %mask = "pto.vmi.constant_mask"() { + value = dense<[true, true, true, true, true, false, false, false]> : tensor<8xi1> + } : () -> !pto.vmi.mask<8xb32, #pto.vmi.layout> + %p0 = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<8xb32, #pto.vmi.layout>) -> !pto.mask + return %p0 : !pto.mask + } + + func.func @vmi_to_vpto_constant_mask_deinterleaved() + -> (!pto.mask, !pto.mask) { + %mask = 
"pto.vmi.constant_mask"() { + value = dense<[true, false, true, false, false, true, false, true]> : tensor<8xi1> + } : () -> !pto.vmi.mask<8xb32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<8xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %p0, %p1 : !pto.mask, !pto.mask + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_constant_mask_all_true +// CHECK: %[[M0:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: %[[M1:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: return %[[M0]], %[[M1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_constant_mask_all_false +// CHECK: %[[F0:.*]] = pto.pset_b32 "PAT_ALLF" : !pto.mask +// CHECK: %[[F1:.*]] = pto.pset_b32 "PAT_ALLF" : !pto.mask +// CHECK: return %[[F0]], %[[F1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_constant_mask_b8_all_true +// CHECK: %[[B8_0:.*]] = pto.pset_b8 "PAT_ALL" : !pto.mask +// CHECK: %[[B8_1:.*]] = pto.pset_b8 "PAT_ALL" : !pto.mask +// CHECK: return %[[B8_0]], %[[B8_1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_constant_mask_b16_all_false +// CHECK: %[[B16_0:.*]] = pto.pset_b16 "PAT_ALLF" : !pto.mask +// CHECK: %[[B16_1:.*]] = pto.pset_b16 "PAT_ALLF" : !pto.mask +// CHECK: return %[[B16_0]], %[[B16_1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_constant_mask_plt_fallback +// CHECK: %[[C5:.*]] = arith.constant 5 : i32 +// CHECK: %[[P0:.*]], %{{.*}} = pto.plt_b32 %[[C5]] : i32 -> !pto.mask, i32 +// CHECK: return %[[P0]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_constant_mask_deinterleaved +// CHECK: %[[PART0:.*]] = pto.pset_b32 "PAT_VL2" : !pto.mask +// CHECK: %[[ALL:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: %[[P4:.*]] = pto.pset_b32 "PAT_VL4" : !pto.mask +// CHECK: %[[P2:.*]] = pto.pset_b32 "PAT_VL2" : !pto.mask +// CHECK: %[[NOT_P2:.*]] = pto.pnot %[[P2]], %[[ALL]] : !pto.mask, !pto.mask -> !pto.mask +// CHECK: %[[PART1:.*]] = pto.pand %[[P4]], %[[NOT_P2]], %[[ALL]] : !pto.mask, !pto.mask, !pto.mask -> !pto.mask +// CHECK: return %[[PART0]], %[[PART1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_constant_mask_nonprefix.pto b/test/lit/vmi/vmi_to_vpto_constant_mask_nonprefix.pto new file mode 100644 index 0000000000..cc3f439e62 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_constant_mask_nonprefix.pto @@ -0,0 +1,34 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_constant_mask_nonprefix() + -> !pto.mask { + %mask = "pto.vmi.constant_mask"() { + value = dense<[true, false, true, false]> : tensor<4xi1> + } : () -> !pto.vmi.mask<4xb32, #pto.vmi.layout> + %p0 = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<4xb32, #pto.vmi.layout>) -> !pto.mask + return %p0 : !pto.mask + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_constant_mask_nonprefix +// CHECK: %[[ALL:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: %[[RUN0:.*]] = pto.pset_b32 "PAT_VL1" : !pto.mask +// CHECK: %[[P3:.*]] = pto.pset_b32 "PAT_VL3" : !pto.mask +// CHECK: %[[P2:.*]] = pto.pset_b32 "PAT_VL2" : !pto.mask +// CHECK: %[[NOT_P2:.*]] = pto.pnot %[[P2]], %[[ALL]] : !pto.mask, !pto.mask -> !pto.mask +// CHECK: %[[RUN1:.*]] = pto.pand %[[P3]], %[[NOT_P2]], %[[ALL]] : !pto.mask, !pto.mask, !pto.mask -> !pto.mask +// CHECK: %[[OUT:.*]] = pto.por %[[RUN0]], %[[RUN1]], %[[ALL]] : !pto.mask, !pto.mask, !pto.mask -> !pto.mask +// CHECK: return %[[OUT]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_constant_mask_rematerialize.pto b/test/lit/vmi/vmi_to_vpto_constant_mask_rematerialize.pto new file mode 100644 index 0000000000..55e1308b4b --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_constant_mask_rematerialize.pto @@ -0,0 +1,42 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-layout-rematerialize -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_constant_mask_rematerialize( + %a16: !pto.vmi.vreg<128xf16>, + %b16: !pto.vmi.vreg<128xf16>, + %a32: !pto.vmi.vreg<128xf32>, + %b32: !pto.vmi.vreg<128xf32>) + -> (!pto.vmi.vreg<128xf16>, !pto.vmi.vreg<128xf32>) { + %mask = "pto.vmi.constant_mask"() { + value = dense : tensor<128xi1> + } : () -> !pto.vmi.mask<128xpred> + %sel16 = pto.vmi.select %mask, %a16, %b16 + : !pto.vmi.mask<128xpred>, !pto.vmi.vreg<128xf16>, !pto.vmi.vreg<128xf16> + -> !pto.vmi.vreg<128xf16> + %sel32 = pto.vmi.select %mask, %a32, %b32 + : !pto.vmi.mask<128xpred>, !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return %sel16, %sel32 + : !pto.vmi.vreg<128xf16>, !pto.vmi.vreg<128xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_constant_mask_rematerialize( +// CHECK: %[[M32_0:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: %[[M32_1:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: %[[M16:.*]] = pto.pset_b16 "PAT_ALL" : !pto.mask +// CHECK: %[[S16:.*]] = pto.vsel %arg0, %arg1, %[[M16]] +// CHECK: %[[S32_0:.*]] = pto.vsel %arg2, %arg4, %[[M32_0]] +// CHECK: %[[S32_1:.*]] = pto.vsel %arg3, %arg5, %[[M32_1]] +// CHECK: return %[[S16]], %[[S32_0]], %[[S32_1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_constant_nonsplat_invalid.pto b/test/lit/vmi/vmi_to_vpto_constant_nonsplat_invalid.pto new file mode 100644 index 0000000000..1d9fe8377e --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_constant_nonsplat_invalid.pto @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_constant_nonsplat_invalid() + -> (!pto.vreg<64xf32>) { + %value = "pto.vmi.constant"() { + value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf32> + } : () -> !pto.vmi.vreg<4xf32, #pto.vmi.layout> + %p0 = "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<4xf32, #pto.vmi.layout>) -> !pto.vreg<64xf32> + return %p0 : !pto.vreg<64xf32> + } +} + +// CHECK: VMI{{.*}}non-splat pto.vmi.constant requires a vreg immediate or scratch materialization plan diff --git a/test/lit/vmi/vmi_to_vpto_construction_width_invalid.pto b/test/lit/vmi/vmi_to_vpto_construction_width_invalid.pto new file mode 100644 index 0000000000..1c6fdea4b0 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_construction_width_invalid.pto @@ -0,0 +1,34 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -split-input-file -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_broadcast_f64_unsupported(%scalar: f64) { + %value = pto.vmi.broadcast %scalar + : f64 -> !pto.vmi.vreg<32xf64, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.broadcast direct lowering requires physical vreg parts with b8/b16/b32 predicate masks +// CHECK-SAME: requires an 8/16/32-bit element type + +// ----- + +module { + func.func @vmi_constant_f64_unsupported() { + %value = "pto.vmi.constant"() { + value = dense<1.000000e+00> : tensor<32xf64> + } : () -> !pto.vmi.vreg<32xf64, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.constant direct lowering requires physical vreg parts with b8/b16/b32 predicate masks +// CHECK-SAME: requires an 8/16/32-bit element type diff --git a/test/lit/vmi/vmi_to_vpto_create_mask.pto b/test/lit/vmi/vmi_to_vpto_create_mask.pto new file mode 100644 index 0000000000..63417a8a99 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_create_mask.pto @@ -0,0 +1,87 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_create_mask_contiguous() + -> (!pto.mask, !pto.mask) { + %active = arith.constant 96 : index + %mask = pto.vmi.create_mask %active + : index -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %p0, %p1 : !pto.mask, !pto.mask + } + + func.func @vmi_to_vpto_create_mask_deint2() + -> (!pto.mask, !pto.mask) { + %active = arith.constant 64 : index + %mask = pto.vmi.create_mask %active + : index -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %p0, %p1 : !pto.mask, !pto.mask + } + + func.func @vmi_to_vpto_create_mask_b8_contiguous() + -> (!pto.mask, !pto.mask) { + %active = arith.constant 320 : index + %mask = pto.vmi.create_mask %active + : index -> !pto.vmi.mask<512xb8, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<512xb8, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %p0, %p1 : !pto.mask, !pto.mask + } + + func.func @vmi_to_vpto_create_mask_b16_deint2() + -> (!pto.mask, !pto.mask) { + %active = arith.constant 64 : index + %mask = pto.vmi.create_mask %active + : index -> !pto.vmi.mask<256xb16, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<256xb16, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %p0, %p1 : !pto.mask, !pto.mask + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_create_mask_contiguous +// CHECK: %[[M0:.*]] = pto.pge_b32 "PAT_ALL" : !pto.mask +// CHECK: %[[M1:.*]] = pto.pge_b32 "PAT_VL32" : !pto.mask +// CHECK: return %[[M0]], %[[M1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_create_mask_deint2 +// CHECK: %[[P0:.*]] = pto.pge_b32 "PAT_VL32" : !pto.mask +// CHECK: %[[P1:.*]] = pto.pge_b32 "PAT_VL32" : !pto.mask +// CHECK: return %[[P0]], %[[P1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_create_mask_b8_contiguous +// CHECK: %[[B8_0:.*]] = pto.pge_b8 "PAT_ALL" : !pto.mask +// CHECK: %[[B8_1:.*]] = pto.pge_b8 "PAT_VL64" : !pto.mask +// CHECK: return %[[B8_0]], %[[B8_1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_create_mask_b16_deint2 +// CHECK: %[[B16_0:.*]] = pto.pge_b16 "PAT_VL32" : !pto.mask +// CHECK: %[[B16_1:.*]] = pto.pge_b16 "PAT_VL32" : !pto.mask +// CHECK: return %[[B16_0]], %[[B16_1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_create_mask_dynamic.pto b/test/lit/vmi/vmi_to_vpto_create_mask_dynamic.pto new file mode 100644 index 0000000000..c702d80529 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_create_mask_dynamic.pto @@ -0,0 +1,132 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_create_mask_dynamic_contiguous(%active: index) + -> (!pto.mask, !pto.mask) { + %mask = pto.vmi.create_mask %active + : index -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %p0, %p1 : !pto.mask, !pto.mask + } + + func.func @vmi_to_vpto_create_mask_dynamic_deint2(%active: index) + -> (!pto.mask, !pto.mask) { + %mask = pto.vmi.create_mask %active + : index -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %p0, %p1 : !pto.mask, !pto.mask + } + + func.func @vmi_to_vpto_create_mask_dynamic_deint4(%active: index) + -> (!pto.mask, !pto.mask, !pto.mask, !pto.mask) { + %mask = pto.vmi.create_mask %active + : index -> !pto.vmi.mask<256xb32, #pto.vmi.layout> + %p0, %p1, %p2, %p3 = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<256xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask, !pto.mask, !pto.mask) + return %p0, %p1, %p2, %p3 + : !pto.mask, !pto.mask, !pto.mask, !pto.mask + } + + func.func @vmi_to_vpto_create_mask_dynamic_b8_contiguous(%active: index) + -> (!pto.mask, !pto.mask) { + %mask = pto.vmi.create_mask %active + : index -> !pto.vmi.mask<512xb8, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<512xb8, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %p0, %p1 : !pto.mask, !pto.mask + } + + func.func @vmi_to_vpto_create_mask_dynamic_b16_deint2(%active: index) + -> (!pto.mask, !pto.mask) { + %mask = pto.vmi.create_mask %active + : index -> !pto.vmi.mask<256xb16, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<256xb16, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %p0, %p1 : !pto.mask, !pto.mask + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_create_mask_dynamic_contiguous +// CHECK: 
%[[ACTIVE:.*]] = arith.index_cast %arg0 : index to i32 +// CHECK: %[[NONNEG:.*]] = arith.maxsi %[[ACTIVE]], {{.*}} : i32 +// CHECK: %[[CLAMPED:.*]] = arith.minui %[[NONNEG]], {{.*}} : i32 +// CHECK: %[[P0:.*]], %[[REM:.*]] = pto.plt_b32 %[[CLAMPED]] : i32 -> !pto.mask, i32 +// CHECK: %[[P1:.*]], %{{.*}} = pto.plt_b32 %[[REM]] : i32 -> !pto.mask, i32 +// CHECK: return %[[P0]], %[[P1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_create_mask_dynamic_deint2 +// CHECK: %[[ACTIVE2:.*]] = arith.index_cast %arg0 : index to i32 +// CHECK: %[[NONNEG2:.*]] = arith.maxsi %[[ACTIVE2]], {{.*}} : i32 +// CHECK: %[[CLAMPED2:.*]] = arith.minui %[[NONNEG2]], {{.*}} : i32 +// CHECK: %[[BIAS2:.*]] = arith.addi %[[CLAMPED2]], {{.*}} : i32 +// CHECK: %[[PART0:.*]] = arith.divui %[[BIAS2]], {{.*}} : i32 +// CHECK: %[[P2_0:.*]], %{{.*}} = pto.plt_b32 %[[PART0]] : i32 -> !pto.mask, i32 +// CHECK: %[[PART1:.*]] = arith.divui %[[CLAMPED2]], {{.*}} : i32 +// CHECK: %[[P2_1:.*]], %{{.*}} = pto.plt_b32 %[[PART1]] : i32 -> !pto.mask, i32 +// CHECK: return %[[P2_0]], %[[P2_1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_create_mask_dynamic_deint4 +// CHECK: %[[ACTIVE4:.*]] = arith.index_cast %arg0 : index to i32 +// CHECK: %[[NONNEG4:.*]] = arith.maxsi %[[ACTIVE4]], {{.*}} : i32 +// CHECK: %[[CLAMPED4:.*]] = arith.minui %[[NONNEG4]], {{.*}} : i32 +// CHECK: %[[BIAS4_0:.*]] = arith.addi %[[CLAMPED4]], {{.*}} : i32 +// CHECK: %[[PART4_0:.*]] = arith.divui %[[BIAS4_0]], {{.*}} : i32 +// CHECK: %[[P4_0:.*]], %{{.*}} = pto.plt_b32 %[[PART4_0]] : i32 -> !pto.mask, i32 +// CHECK: %[[BIAS4_1:.*]] = arith.addi %[[CLAMPED4]], {{.*}} : i32 +// CHECK: %[[PART4_1:.*]] = arith.divui %[[BIAS4_1]], {{.*}} : i32 +// CHECK: %[[P4_1:.*]], %{{.*}} = pto.plt_b32 %[[PART4_1]] : i32 -> !pto.mask, i32 +// CHECK: %[[BIAS4_2:.*]] = arith.addi %[[CLAMPED4]], {{.*}} : i32 +// CHECK: %[[PART4_2:.*]] = arith.divui %[[BIAS4_2]], {{.*}} : i32 +// CHECK: %[[P4_2:.*]], %{{.*}} = pto.plt_b32 %[[PART4_2]] : i32 -> !pto.mask, i32 +// CHECK: %[[PART4_3:.*]] = arith.divui %[[CLAMPED4]], {{.*}} : i32 +// CHECK: %[[P4_3:.*]], %{{.*}} = pto.plt_b32 %[[PART4_3]] : i32 -> !pto.mask, i32 +// CHECK: return %[[P4_0]], %[[P4_1]], %[[P4_2]], %[[P4_3]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_create_mask_dynamic_b8_contiguous +// CHECK: %[[ACTIVE8:.*]] = arith.index_cast %arg0 : index to i32 +// CHECK: %[[NONNEG8:.*]] = arith.maxsi %[[ACTIVE8]], {{.*}} : i32 +// CHECK: %[[CLAMPED8:.*]] = arith.minui %[[NONNEG8]], {{.*}} : i32 +// CHECK: %[[P8_0:.*]], %[[REM8:.*]] = pto.plt_b8 %[[CLAMPED8]] : i32 -> !pto.mask, i32 +// CHECK: %[[P8_1:.*]], %{{.*}} = pto.plt_b8 %[[REM8]] : i32 -> !pto.mask, i32 +// CHECK: return %[[P8_0]], %[[P8_1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_create_mask_dynamic_b16_deint2 +// CHECK: %[[ACTIVE16:.*]] = arith.index_cast %arg0 : index to i32 +// CHECK: %[[NONNEG16:.*]] = arith.maxsi %[[ACTIVE16]], {{.*}} : i32 +// CHECK: %[[CLAMPED16:.*]] = arith.minui %[[NONNEG16]], {{.*}} : i32 +// CHECK: %[[BIAS16:.*]] = arith.addi %[[CLAMPED16]], {{.*}} : i32 +// CHECK: %[[PART16_0:.*]] = arith.divui %[[BIAS16]], {{.*}} : i32 +// CHECK: %[[P16_0:.*]], %{{.*}} = pto.plt_b16 %[[PART16_0]] : i32 -> !pto.mask, i32 +// CHECK: %[[PART16_1:.*]] = arith.divui %[[CLAMPED16]], {{.*}} : i32 +// CHECK: %[[P16_1:.*]], %{{.*}} = pto.plt_b16 %[[PART16_1]] : i32 -> !pto.mask, i32 +// CHECK: return %[[P16_0]], %[[P16_1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_create_mask_plt_fallback.pto b/test/lit/vmi/vmi_to_vpto_create_mask_plt_fallback.pto new file mode 100644 index 0000000000..8cd9cd051c --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_create_mask_plt_fallback.pto @@ -0,0 +1,30 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_create_mask_plt_fallback() + -> !pto.mask { + %active = arith.constant 5 : index + %mask = pto.vmi.create_mask %active + : index -> !pto.vmi.mask<64xb32, #pto.vmi.layout> + %p0 = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<64xb32, #pto.vmi.layout>) + -> !pto.mask + return %p0 : !pto.mask + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_create_mask_plt_fallback( +// CHECK: %[[C5:.*]] = arith.constant 5 : i32 +// CHECK: %[[MASK:.*]], %{{.*}} = pto.plt_b32 %[[C5]] : i32 -> !pto.mask, i32 +// CHECK: return %[[MASK]] : !pto.mask +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_create_mask_rematerialize.pto b/test/lit/vmi/vmi_to_vpto_create_mask_rematerialize.pto new file mode 100644 index 0000000000..03add9ada4 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_create_mask_rematerialize.pto @@ -0,0 +1,47 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-layout-rematerialize -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_create_mask_rematerialize( + %active: index, + %a16: !pto.vmi.vreg<128xf16>, + %b16: !pto.vmi.vreg<128xf16>, + %a32: !pto.vmi.vreg<128xf32>, + %b32: !pto.vmi.vreg<128xf32>) + -> (!pto.vmi.vreg<128xf16>, !pto.vmi.vreg<128xf32>) { + %mask = pto.vmi.create_mask %active : index -> !pto.vmi.mask<128xpred> + %sel16 = pto.vmi.select %mask, %a16, %b16 + : !pto.vmi.mask<128xpred>, !pto.vmi.vreg<128xf16>, !pto.vmi.vreg<128xf16> + -> !pto.vmi.vreg<128xf16> + %sel32 = pto.vmi.select %mask, %a32, %b32 + : !pto.vmi.mask<128xpred>, !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return %sel16, %sel32 + : !pto.vmi.vreg<128xf16>, !pto.vmi.vreg<128xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_create_mask_rematerialize( +// CHECK: %[[ACTIVE32:.*]] = arith.index_cast %arg0 : index to i32 +// CHECK: %[[NONNEG32:.*]] = arith.maxsi %[[ACTIVE32]], {{.*}} : i32 +// CHECK: %[[CLAMP32:.*]] = arith.minui %[[NONNEG32]], {{.*}} : i32 +// CHECK: %[[M32_0:.*]], %[[REM32:.*]] = pto.plt_b32 %[[CLAMP32]] : i32 -> !pto.mask, i32 +// CHECK: %[[M32_1:.*]], %{{.*}} = pto.plt_b32 %[[REM32]] : i32 -> !pto.mask, i32 +// CHECK: %[[ACTIVE16:.*]] = arith.index_cast %arg0 : index to i32 +// CHECK: %[[NONNEG16:.*]] = arith.maxsi %[[ACTIVE16]], {{.*}} : i32 +// CHECK: %[[CLAMP16:.*]] = arith.minui %[[NONNEG16]], {{.*}} : i32 +// CHECK: %[[M16:.*]], %{{.*}} = pto.plt_b16 %[[CLAMP16]] : i32 -> !pto.mask, i32 +// CHECK: %[[S16:.*]] = pto.vsel %arg1, %arg2, %[[M16]] +// CHECK: %[[S32_0:.*]] = pto.vsel %arg3, %arg5, %[[M32_0]] +// CHECK: %[[S32_1:.*]] = pto.vsel %arg4, %arg6, %[[M32_1]] +// CHECK: return %[[S16]], %[[S32_0]], %[[S32_1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_dhist.pto b/test/lit/vmi/vmi_to_vpto_dhist.pto new file mode 100644 index 0000000000..b8a1113534 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_dhist.pto @@ -0,0 +1,41 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_dhist( + %acc: !pto.vmi.vreg<256xui16, #pto.vmi.layout>, + %source: !pto.vmi.vreg<256xui8, #pto.vmi.layout>, + %mask: !pto.vmi.mask<256xb8, #pto.vmi.layout>) + -> (!pto.vreg<128xui16>, !pto.vreg<128xui16>) { + %hist = pto.vmi.dhist %acc, %source, %mask + : !pto.vmi.vreg<256xui16, #pto.vmi.layout>, + !pto.vmi.vreg<256xui8, #pto.vmi.layout>, + !pto.vmi.mask<256xb8, #pto.vmi.layout> + -> !pto.vmi.vreg<256xui16, #pto.vmi.layout> + %lo, %hi = "pto.vmi.unpack"(%hist) + : (!pto.vmi.vreg<256xui16, #pto.vmi.layout>) + -> (!pto.vreg<128xui16>, !pto.vreg<128xui16>) + return %lo, %hi : !pto.vreg<128xui16>, !pto.vreg<128xui16> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_dhist( +// CHECK-SAME: %[[ACC0:[^,]+]]: !pto.vreg<128xui16> +// CHECK-SAME: %[[ACC1:[^,]+]]: !pto.vreg<128xui16> +// CHECK-SAME: %[[SRC:[^,]+]]: !pto.vreg<256xui8> +// CHECK-SAME: %[[MASK:[^)]+]]: !pto.mask +// CHECK-DAG: %[[BIN0:.*]] = arith.constant 0 : i32 +// CHECK-DAG: %[[BIN1:.*]] = arith.constant 1 : i32 +// CHECK: %[[LO:.*]] = 
pto.dhistv2 %[[ACC0]], %[[SRC]], %[[MASK]], %[[BIN0]] +// CHECK: %[[HI:.*]] = pto.dhistv2 %[[ACC1]], %[[SRC]], %[[MASK]], %[[BIN1]] +// CHECK: return %[[LO]], %[[HI]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_dhist_tail_mask.pto b/test/lit/vmi/vmi_to_vpto_dhist_tail_mask.pto new file mode 100644 index 0000000000..4aada7a188 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_dhist_tail_mask.pto @@ -0,0 +1,49 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_dhist_tail_mask( + %acc: !pto.vmi.vreg<256xui16, #pto.vmi.layout>, + %source: !pto.vmi.vreg<300xui8, #pto.vmi.layout>, + %mask: !pto.vmi.mask<300xb8, #pto.vmi.layout>) + -> (!pto.vreg<128xui16>, !pto.vreg<128xui16>) { + %hist = pto.vmi.dhist %acc, %source, %mask + : !pto.vmi.vreg<256xui16, #pto.vmi.layout>, + !pto.vmi.vreg<300xui8, #pto.vmi.layout>, + !pto.vmi.mask<300xb8, #pto.vmi.layout> + -> !pto.vmi.vreg<256xui16, #pto.vmi.layout> + %lo, %hi = "pto.vmi.unpack"(%hist) + : (!pto.vmi.vreg<256xui16, #pto.vmi.layout>) + -> (!pto.vreg<128xui16>, !pto.vreg<128xui16>) + return %lo, %hi : !pto.vreg<128xui16>, !pto.vreg<128xui16> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_dhist_tail_mask( +// CHECK-SAME: %[[ACC0:[^,]+]]: !pto.vreg<128xui16> +// CHECK-SAME: %[[ACC1:[^,]+]]: !pto.vreg<128xui16> +// CHECK-SAME: %[[SRC0:[^,]+]]: !pto.vreg<256xui8> +// CHECK-SAME: %[[SRC1:[^,]+]]: !pto.vreg<256xui8> +// CHECK-SAME: %[[MASK0:[^,]+]]: !pto.mask +// CHECK-SAME: %[[MASK1:[^)]+]]: !pto.mask +// CHECK-DAG: %[[BIN0:.*]] = arith.constant 0 : i32 +// CHECK-DAG: %[[BIN1:.*]] = arith.constant 1 : i32 +// CHECK-DAG: %[[C44:.*]] = arith.constant 44 : i32 +// CHECK: %[[LO0:.*]] = pto.dhistv2 %[[ACC0]], %[[SRC0]], %[[MASK0]], %[[BIN0]] +// CHECK: %[[HI0:.*]] = pto.dhistv2 %[[ACC1]], %[[SRC0]], %[[MASK0]], %[[BIN1]] +// CHECK: %[[TAIL:.*]], %{{.*}} = pto.plt_b8 %[[C44]] : i32 -> !pto.mask, i32 +// CHECK: %[[ALL:.*]] = pto.pset_b8 "PAT_ALL" : !pto.mask +// CHECK: %[[MASK1_VALID:.*]] = pto.pand %[[MASK1]], %[[TAIL]], %[[ALL]] +// CHECK: %[[LO1:.*]] = pto.dhistv2 %[[LO0]], %[[SRC1]], %[[MASK1_VALID]], %[[BIN0]] +// CHECK: %[[HI1:.*]] = pto.dhistv2 %[[HI0]], %[[SRC1]], %[[MASK1_VALID]], %[[BIN1]] +// CHECK: return %[[LO1]], %[[HI1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_divf.pto b/test/lit/vmi/vmi_to_vpto_divf.pto new file mode 100644 index 0000000000..be21ba5fdc --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_divf.pto @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_divf( + %lhs: !pto.vmi.vreg<128xf32>, + %rhs: !pto.vmi.vreg<128xf32>) -> !pto.vmi.vreg<128xf32> { + %quotient = pto.vmi.divf %lhs, %rhs + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return %quotient : !pto.vmi.vreg<128xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_divf( +// CHECK-SAME: %[[LHS0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[LHS1:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[RHS0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[RHS1:[^)]+]]: !pto.vreg<64xf32> +// CHECK-SAME: -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) +// CHECK-DAG: %[[DIV0:.*]] = pto.vdiv %[[LHS0]], %[[RHS0]], {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-DAG: %[[DIV1:.*]] = pto.vdiv %[[LHS1]], %[[RHS1]], {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[DIV0]], %[[DIV1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_e2e_widen_add_store.pto b/test/lit/vmi/vmi_to_vpto_e2e_widen_add_store.pto new file mode 100644 index 0000000000..f88f15a8eb --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_e2e_widen_add_store.pto @@ -0,0 +1,74 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_f16_widen_add_store( + %src: !pto.ptr, %dst: !pto.ptr, %offset: index) { + %narrow = pto.vmi.load %src[%offset] + : !pto.ptr -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> + %wide = pto.vmi.extf %narrow + : !pto.vmi.vreg<128xf16, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %sum = pto.vmi.addf %wide, %wide + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + pto.vmi.store %sum, %dst[%offset] + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, !pto.ptr + return + } + + func.func @vmi_to_vpto_f8_widen_add_store( + %src: !pto.ptr, %dst: !pto.ptr, %offset: index) { + %narrow = pto.vmi.load %src[%offset] + : !pto.ptr -> !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> + %wide = pto.vmi.extf %narrow + : !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + %sum = pto.vmi.addf %wide, %wide + : !pto.vmi.vreg<256xf32, #pto.vmi.layout>, + 
!pto.vmi.vreg<256xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + pto.vmi.store %sum, %dst[%offset] + : !pto.vmi.vreg<256xf32, #pto.vmi.layout>, !pto.ptr + return + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_f16_widen_add_store( +// CHECK: %[[NARROW:.*]] = pto.vlds %arg0[%arg2] : !pto.ptr -> !pto.vreg<128xf16> +// CHECK: %[[CVT_MASK:.*]] = pto.pset_b16 "PAT_ALL" : !pto.mask +// CHECK: %[[EVEN:.*]] = pto.vcvt %[[NARROW]], %[[CVT_MASK]] {part = "EVEN"} +// CHECK: %[[ODD:.*]] = pto.vcvt %[[NARROW]], %[[CVT_MASK]] {part = "ODD"} +// CHECK: %[[ADD_MASK0:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: %[[SUM0:.*]] = pto.vadd %[[EVEN]], %[[EVEN]], %[[ADD_MASK0]] +// CHECK: %[[ADD_MASK1:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: %[[SUM1:.*]] = pto.vadd %[[ODD]], %[[ODD]], %[[ADD_MASK1]] +// CHECK: %[[STORE_MASK:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: pto.vstsx2 %[[SUM0]], %[[SUM1]], %arg1[%arg2], "INTLV_B32", %[[STORE_MASK]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_f8_widen_add_store( +// CHECK: %[[NARROW8:.*]] = pto.vlds %arg0[%arg2] : !pto.ptr -> !pto.vreg<256xf8E4M3FN> +// CHECK: pto.vcvt %[[NARROW8]], {{.*}} {part = "P0"} +// CHECK: pto.vcvt %[[NARROW8]], {{.*}} {part = "P1"} +// CHECK: pto.vcvt %[[NARROW8]], {{.*}} {part = "P2"} +// CHECK: pto.vcvt %[[NARROW8]], {{.*}} {part = "P3"} +// CHECK-COUNT-4: pto.vadd +// CHECK: pto.vintlv +// CHECK: pto.vintlv +// CHECK: pto.vintlv +// CHECK: pto.vintlv +// CHECK-COUNT-4: pto.vsts +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_elementwise_width_invalid.pto b/test/lit/vmi/vmi_to_vpto_elementwise_width_invalid.pto new file mode 100644 index 0000000000..958e6f1f5a --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_elementwise_width_invalid.pto @@ -0,0 +1,41 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -split-input-file -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_addf_f64_unsupported( + %a: !pto.vmi.vreg<32xf64, #pto.vmi.layout>, + %b: !pto.vmi.vreg<32xf64, #pto.vmi.layout>) { + %sum = pto.vmi.addf %a, %b + : !pto.vmi.vreg<32xf64, #pto.vmi.layout>, + !pto.vmi.vreg<32xf64, #pto.vmi.layout> + -> !pto.vmi.vreg<32xf64, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.addf direct lowering requires f16/bf16/f32 element type and physical vreg parts with b8/b16/b32 predicate masks +// CHECK-SAME: requires an 8/16/32-bit element type + +// ----- + +module { + func.func @vmi_addi_index_unsupported( + %a: !pto.vmi.vreg<64xindex, #pto.vmi.layout>, + %b: !pto.vmi.vreg<64xindex, #pto.vmi.layout>) { + %sum = pto.vmi.addi %a, %b + : !pto.vmi.vreg<64xindex, #pto.vmi.layout>, + !pto.vmi.vreg<64xindex, #pto.vmi.layout> + -> !pto.vmi.vreg<64xindex, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.addi direct lowering requires physical 
vreg parts with b8/b16/b32 predicate masks +// CHECK-SAME: requires an 8/16/32-bit element type diff --git a/test/lit/vmi/vmi_to_vpto_ensure_identity.pto b/test/lit/vmi/vmi_to_vpto_ensure_identity.pto new file mode 100644 index 0000000000..783bc3428d --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_ensure_identity.pto @@ -0,0 +1,80 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_ensure_identity( + %v: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %m: !pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask, !pto.mask) { + %ev = "pto.vmi.ensure_layout"(%v) + : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %em0 = "pto.vmi.ensure_mask_layout"(%m) + : (!pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %em1 = "pto.vmi.ensure_mask_granularity"(%em0) + : (!pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%ev) + : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + %pm0, %pm1 = "pto.vmi.unpack"(%em1) + : (!pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %p0, %p1, %pm0, %pm1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask, !pto.mask + } + + func.func 
@vmi_to_vpto_ensure_identity_tail( + %v: !pto.vmi.vreg<100xf32, #pto.vmi.layout>, + %m: !pto.vmi.mask<100xb32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.mask, !pto.mask) { + %ev = "pto.vmi.ensure_layout"(%v) + : (!pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<100xf32, #pto.vmi.layout> + %em = "pto.vmi.ensure_mask_layout"(%m) + : (!pto.vmi.mask<100xb32, #pto.vmi.layout>) + -> !pto.vmi.mask<100xb32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%ev) + : (!pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + %pm0, %pm1 = "pto.vmi.unpack"(%em) + : (!pto.vmi.mask<100xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %p0, %p1, %pm0, %pm1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.mask, !pto.mask + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_ensure_identity( +// CHECK-SAME: !pto.vreg<64xf32> +// CHECK-SAME: !pto.vreg<64xf32> +// CHECK-SAME: !pto.mask +// CHECK-SAME: !pto.mask +// CHECK: return +// CHECK-SAME: !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask, !pto.mask +// CHECK-NOT: pto.vmi.ensure +// CHECK-NOT: pto.vmi.unpack +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_ensure_identity_tail( +// CHECK-SAME: !pto.vreg<64xf32> +// CHECK-SAME: !pto.vreg<64xf32> +// CHECK-SAME: !pto.mask +// CHECK-SAME: !pto.mask +// CHECK: return +// CHECK-SAME: !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask, !pto.mask +// CHECK-NOT: pto.vmi.ensure +// CHECK-NOT: pto.vmi.unpack +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_ensure_layout_deint4.pto b/test/lit/vmi/vmi_to_vpto_ensure_layout_deint4.pto new file mode 100644 index 0000000000..cd78e684c7 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_ensure_layout_deint4.pto @@ -0,0 +1,81 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_ensure_layout_deint4_to_contiguous( + %input: !pto.vmi.vreg<256xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %dense = pto.vmi.ensure_layout %input + : !pto.vmi.vreg<256xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + %p0, %p1, %p2, %p3 = "pto.vmi.unpack"(%dense) + : (!pto.vmi.vreg<256xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1, %p2, %p3 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + + func.func @vmi_to_vpto_ensure_layout_contiguous_to_deint4( + %input: !pto.vmi.vreg<256xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %split = pto.vmi.ensure_layout %input + : !pto.vmi.vreg<256xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + %p0, %p1, %p2, %p3 = "pto.vmi.unpack"(%split) + : (!pto.vmi.vreg<256xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1, %p2, %p3 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + + func.func @vmi_to_vpto_ensure_layout_deint2_to_deint4( + %input: !pto.vmi.vreg<256xf32, #pto.vmi.layout>) + 
-> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %split = pto.vmi.ensure_layout %input + : !pto.vmi.vreg<256xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + %p0, %p1, %p2, %p3 = "pto.vmi.unpack"(%split) + : (!pto.vmi.vreg<256xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1, %p2, %p3 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_ensure_layout_deint4_to_contiguous( +// CHECK: %[[A0:.*]], %[[A1:.*]] = pto.vintlv %arg0, %arg2 +// CHECK: %[[B0:.*]], %[[B1:.*]] = pto.vintlv %arg1, %arg3 +// CHECK: %[[D0:.*]], %[[D1:.*]] = pto.vintlv %[[A0]], %[[B0]] +// CHECK: %[[D2:.*]], %[[D3:.*]] = pto.vintlv %[[A1]], %[[B1]] +// CHECK: return %[[D0]], %[[D1]], %[[D2]], %[[D3]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_ensure_layout_contiguous_to_deint4( +// CHECK: %[[A0:.*]], %[[B0:.*]] = pto.vdintlv %arg0, %arg1 +// CHECK: %[[A1:.*]], %[[B1:.*]] = pto.vdintlv %arg2, %arg3 +// CHECK: %[[P0:.*]], %[[P2:.*]] = pto.vdintlv %[[A0]], %[[A1]] +// CHECK: %[[P1:.*]], %[[P3:.*]] = pto.vdintlv %[[B0]], %[[B1]] +// CHECK: return %[[P0]], %[[P1]], %[[P2]], %[[P3]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_ensure_layout_deint2_to_deint4( +// CHECK: pto.vintlv +// CHECK: pto.vintlv +// CHECK: pto.vdintlv +// CHECK: pto.vdintlv +// CHECK: pto.vdintlv +// CHECK: pto.vdintlv +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_ensure_layout_partial_invalid.pto b/test/lit/vmi/vmi_to_vpto_ensure_layout_partial_invalid.pto new file mode 100644 index 0000000000..bbfd4dcfd8 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_ensure_layout_partial_invalid.pto @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_ensure_layout_partial_invalid( + %input: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) { + %dense = pto.vmi.ensure_layout %input + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.ensure_layout cannot materialize the requested data layout conversion +// CHECK-SAME: requires source and result to have the same physical arity +// CHECK-SAME: partial/tail layout materialization requires an explicit packing plan diff --git a/test/lit/vmi/vmi_to_vpto_ensure_layout_vdintlv.pto b/test/lit/vmi/vmi_to_vpto_ensure_layout_vdintlv.pto new file mode 100644 index 0000000000..03661ac669 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_ensure_layout_vdintlv.pto @@ -0,0 +1,30 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_ensure_layout_contiguous_to_deint2( + %input: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %split = pto.vmi.ensure_layout %input + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%split) + : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_ensure_layout_contiguous_to_deint2( +// CHECK: %[[P0:.*]], %[[P1:.*]] = pto.vdintlv %arg0, %arg1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK: return %[[P0]], %[[P1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_ensure_layout_vintlv.pto b/test/lit/vmi/vmi_to_vpto_ensure_layout_vintlv.pto new file mode 100644 index 0000000000..e4506c86c2 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_ensure_layout_vintlv.pto @@ -0,0 +1,49 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). 
+// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_ensure_layout_deint2_to_contiguous( + %input: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %dense = pto.vmi.ensure_layout %input + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%dense) + : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + + func.func @vmi_to_vpto_ensure_layout_deint2_tail_to_contiguous( + %input: !pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %dense = pto.vmi.ensure_layout %input + : !pto.vmi.vreg<100xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<100xf32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%dense) + : (!pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_ensure_layout_deint2_to_contiguous( +// CHECK: %[[D0:.*]], %[[D1:.*]] = pto.vintlv %arg0, %arg1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK: return %[[D0]], %[[D1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_ensure_layout_deint2_tail_to_contiguous( +// CHECK: %[[D0:.*]], %[[D1:.*]] = pto.vintlv %arg0, %arg1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK: return %[[D0]], %[[D1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_ensure_mask_granularity.pto b/test/lit/vmi/vmi_to_vpto_ensure_mask_granularity.pto new file mode 100644 index 0000000000..989fd0cb74 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_ensure_mask_granularity.pto @@ -0,0 +1,40 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_ensure_mask_granularity( + %m: !pto.vmi.mask<128xpred>, + %a16: !pto.vmi.vreg<128xf16>, + %b16: !pto.vmi.vreg<128xf16>, + %a32: !pto.vmi.vreg<128xf32>, + %b32: !pto.vmi.vreg<128xf32>) + -> (!pto.vmi.vreg<128xf16>, !pto.vmi.vreg<128xf32>) { + %sel16 = pto.vmi.select %m, %a16, %b16 + : !pto.vmi.mask<128xpred>, !pto.vmi.vreg<128xf16>, !pto.vmi.vreg<128xf16> + -> !pto.vmi.vreg<128xf16> + %sel32 = pto.vmi.select %m, %a32, %b32 + : !pto.vmi.mask<128xpred>, !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return %sel16, %sel32 + : !pto.vmi.vreg<128xf16>, !pto.vmi.vreg<128xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_ensure_mask_granularity( +// CHECK: %[[LO:.*]] = pto.ppack %arg0, "LOWER" : !pto.mask -> !pto.mask +// CHECK: %[[HI:.*]] = pto.ppack %arg1, "HIGHER" : !pto.mask -> !pto.mask +// CHECK: %[[ALL:.*]] = pto.pset_b16 "PAT_ALL" : !pto.mask +// CHECK: %[[M16:.*]] = pto.por %[[LO]], %[[HI]], %[[ALL]] : !pto.mask, !pto.mask, !pto.mask -> !pto.mask +// CHECK: pto.vsel %arg2, %arg3, %[[M16]] +// CHECK: pto.vsel %arg4, %arg6, %arg0 +// CHECK: pto.vsel %arg5, %arg7, %arg1 +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_ensure_mask_granularity_direct.pto b/test/lit/vmi/vmi_to_vpto_ensure_mask_granularity_direct.pto new file mode 100644 index 0000000000..2512367b64 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_ensure_mask_granularity_direct.pto @@ -0,0 +1,31 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_ensure_mask_granularity_direct( + %m: !pto.vmi.mask<128xb16, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) { + %result = pto.vmi.ensure_mask_granularity %m + : !pto.vmi.mask<128xb16, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%result) + : (!pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %p0, %p1 : !pto.mask, !pto.mask + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_ensure_mask_granularity_direct( +// CHECK: %[[P0:.*]] = pto.punpack %arg0, "LOWER" : !pto.mask -> !pto.mask +// CHECK: %[[P1:.*]] = pto.punpack %arg0, "HIGHER" : !pto.mask -> !pto.mask +// CHECK: return %[[P0]], %[[P1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_ensure_mask_granularity_multistep.pto b/test/lit/vmi/vmi_to_vpto_ensure_mask_granularity_multistep.pto new file mode 100644 index 0000000000..29bb147489 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_ensure_mask_granularity_multistep.pto @@ -0,0 +1,34 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_ensure_mask_granularity_multistep( + %m: !pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> !pto.mask { + %result = pto.vmi.ensure_mask_granularity %m + : !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb8, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%result) + : (!pto.vmi.mask<128xb8, #pto.vmi.layout>) + -> !pto.mask + return %part : !pto.mask + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_ensure_mask_granularity_multistep( +// CHECK: %[[LO16:.*]] = pto.ppack %arg0, "LOWER" : !pto.mask -> !pto.mask +// CHECK: %[[HI16:.*]] = pto.ppack %arg1, "HIGHER" : !pto.mask -> !pto.mask +// CHECK: %[[ALL16:.*]] = pto.pset_b16 "PAT_ALL" : !pto.mask +// CHECK: %[[M16:.*]] = pto.por %[[LO16]], %[[HI16]], %[[ALL16]] : !pto.mask, !pto.mask, !pto.mask -> !pto.mask +// CHECK: %[[M8:.*]] = pto.ppack %[[M16]], "LOWER" : !pto.mask -> !pto.mask +// CHECK: return %[[M8]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_ensure_mask_layout.pto b/test/lit/vmi/vmi_to_vpto_ensure_mask_layout.pto new file mode 100644 index 0000000000..17a644834b --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_ensure_mask_layout.pto @@ -0,0 +1,114 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_mask_deint2_to_contiguous( + %m: !pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) { + %dense = pto.vmi.ensure_mask_layout %m + : !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%dense) + : (!pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %p0, %p1 : !pto.mask, !pto.mask + } + + func.func @vmi_to_vpto_mask_contiguous_to_deint2( + %m: !pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) { + %deint = pto.vmi.ensure_mask_layout %m + : !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%deint) + : (!pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %p0, %p1 : !pto.mask, !pto.mask + } + + func.func @vmi_to_vpto_mask_deint2_tail_to_contiguous( + %m: !pto.vmi.mask<100xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) { + %dense = pto.vmi.ensure_mask_layout %m + : !pto.vmi.mask<100xb32, #pto.vmi.layout> + -> !pto.vmi.mask<100xb32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%dense) + : (!pto.vmi.mask<100xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %p0, %p1 : !pto.mask, !pto.mask + } + + func.func @vmi_to_vpto_mask_deint4_to_contiguous( + %m: !pto.vmi.mask<256xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask, !pto.mask, !pto.mask) { + %dense = pto.vmi.ensure_mask_layout %m + : !pto.vmi.mask<256xb32, #pto.vmi.layout> + -> !pto.vmi.mask<256xb32, #pto.vmi.layout> + %p0, %p1, %p2, %p3 = "pto.vmi.unpack"(%dense) + : (!pto.vmi.mask<256xb32, 
#pto.vmi.layout>) + -> (!pto.mask, !pto.mask, !pto.mask, !pto.mask) + return %p0, %p1, %p2, %p3 + : !pto.mask, !pto.mask, !pto.mask, !pto.mask + } + + func.func @vmi_to_vpto_mask_contiguous_to_deint4( + %m: !pto.vmi.mask<256xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask, !pto.mask, !pto.mask) { + %deint = pto.vmi.ensure_mask_layout %m + : !pto.vmi.mask<256xb32, #pto.vmi.layout> + -> !pto.vmi.mask<256xb32, #pto.vmi.layout> + %p0, %p1, %p2, %p3 = "pto.vmi.unpack"(%deint) + : (!pto.vmi.mask<256xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask, !pto.mask, !pto.mask) + return %p0, %p1, %p2, %p3 + : !pto.mask, !pto.mask, !pto.mask, !pto.mask + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_mask_deint2_to_contiguous( +// CHECK: %[[D0:.*]], %[[D1:.*]] = pto.pintlv_b32 %arg0, %arg1 +// CHECK: return %[[D0]], %[[D1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_mask_contiguous_to_deint2( +// CHECK: %[[P0:.*]], %[[P1:.*]] = pto.pdintlv_b32 %arg0, %arg1 +// CHECK: return %[[P0]], %[[P1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_mask_deint2_tail_to_contiguous( +// CHECK: %[[D0:.*]], %[[D1:.*]] = pto.pintlv_b32 %arg0, %arg1 +// CHECK: return %[[D0]], %[[D1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_mask_deint4_to_contiguous( +// CHECK: %[[A0:.*]], %[[A1:.*]] = pto.pintlv_b32 %arg0, %arg2 +// CHECK: %[[B0:.*]], %[[B1:.*]] = pto.pintlv_b32 %arg1, %arg3 +// CHECK: %[[D0:.*]], %[[D1:.*]] = pto.pintlv_b32 %[[A0]], %[[B0]] +// CHECK: %[[D2:.*]], %[[D3:.*]] = pto.pintlv_b32 %[[A1]], %[[B1]] +// CHECK: return %[[D0]], %[[D1]], %[[D2]], %[[D3]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_mask_contiguous_to_deint4( +// CHECK: %[[A0:.*]], %[[B0:.*]] = pto.pdintlv_b32 %arg0, %arg1 +// CHECK: %[[A1:.*]], %[[B1:.*]] = pto.pdintlv_b32 %arg2, %arg3 +// CHECK: %[[P0:.*]], %[[P2:.*]] = pto.pdintlv_b32 %[[A0]], %[[A1]] +// CHECK: %[[P1:.*]], %[[P3:.*]] = pto.pdintlv_b32 %[[B0]], %[[B1]] +// CHECK: return %[[P0]], %[[P1]], %[[P2]], %[[P3]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_ensure_mask_layout_partial_invalid.pto b/test/lit/vmi/vmi_to_vpto_ensure_mask_layout_partial_invalid.pto new file mode 100644 index 0000000000..87edcee933 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_ensure_mask_layout_partial_invalid.pto @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_ensure_mask_layout_partial_invalid( + %input: !pto.vmi.mask<128xb32, #pto.vmi.layout>) { + %dense = pto.vmi.ensure_mask_layout %input + : !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.ensure_mask_layout cannot materialize the requested mask layout conversion +// CHECK-SAME: requires source and result to have the same physical arity +// CHECK-SAME: partial/tail predicate layout materialization requires an explicit packing plan diff --git a/test/lit/vmi/vmi_to_vpto_ensure_mask_layout_widths.pto b/test/lit/vmi/vmi_to_vpto_ensure_mask_layout_widths.pto new file mode 100644 index 0000000000..0c8b9a4120 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_ensure_mask_layout_widths.pto @@ -0,0 +1,78 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_mask_b8_deint2_to_contiguous( + %m: !pto.vmi.mask<512xb8, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) { + %dense = pto.vmi.ensure_mask_layout %m + : !pto.vmi.mask<512xb8, #pto.vmi.layout> + -> !pto.vmi.mask<512xb8, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%dense) + : (!pto.vmi.mask<512xb8, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %p0, %p1 : !pto.mask, !pto.mask + } + + func.func @vmi_to_vpto_mask_b8_contiguous_to_deint2( + %m: !pto.vmi.mask<512xb8, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) { + %deint = pto.vmi.ensure_mask_layout %m + : !pto.vmi.mask<512xb8, #pto.vmi.layout> + -> !pto.vmi.mask<512xb8, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%deint) + : (!pto.vmi.mask<512xb8, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %p0, %p1 : !pto.mask, !pto.mask + } + + func.func @vmi_to_vpto_mask_b16_deint2_to_contiguous( + %m: !pto.vmi.mask<256xb16, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) { + %dense = pto.vmi.ensure_mask_layout %m + : !pto.vmi.mask<256xb16, #pto.vmi.layout> + -> !pto.vmi.mask<256xb16, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%dense) + : (!pto.vmi.mask<256xb16, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %p0, %p1 : !pto.mask, !pto.mask + } + + func.func @vmi_to_vpto_mask_b16_contiguous_to_deint2( + %m: !pto.vmi.mask<256xb16, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) { + %deint = pto.vmi.ensure_mask_layout %m + : !pto.vmi.mask<256xb16, #pto.vmi.layout> + -> !pto.vmi.mask<256xb16, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%deint) + : (!pto.vmi.mask<256xb16, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %p0, %p1 : !pto.mask, !pto.mask + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_mask_b8_deint2_to_contiguous( +// CHECK: %[[D0:.*]], %[[D1:.*]] = pto.pintlv_b8 %arg0, %arg1 +// CHECK: return %[[D0]], %[[D1]] + +// CHECK-LABEL: func.func @vmi_to_vpto_mask_b8_contiguous_to_deint2( +// CHECK: 
%[[P0:.*]], %[[P1:.*]] = pto.pdintlv_b8 %arg0, %arg1 +// CHECK: return %[[P0]], %[[P1]] + +// CHECK-LABEL: func.func @vmi_to_vpto_mask_b16_deint2_to_contiguous( +// CHECK: %[[D2:.*]], %[[D3:.*]] = pto.pintlv_b16 %arg0, %arg1 +// CHECK: return %[[D2]], %[[D3]] + +// CHECK-LABEL: func.func @vmi_to_vpto_mask_b16_contiguous_to_deint2( +// CHECK: %[[P2:.*]], %[[P3:.*]] = pto.pdintlv_b16 %arg0, %arg1 +// CHECK: return %[[P2]], %[[P3]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_expand_load_all_active.pto b/test/lit/vmi/vmi_to_vpto_expand_load_all_active.pto new file mode 100644 index 0000000000..836e33dec0 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_expand_load_all_active.pto @@ -0,0 +1,66 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_expand_load_all_active( + %src: !pto.ptr, + %offset: index, + %passthru: !pto.vmi.vreg<64xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> { + %active = arith.constant 64 : index + %mask = pto.vmi.create_mask %active + : index -> !pto.vmi.mask<64xb32, #pto.vmi.layout> + %out = pto.vmi.expand_load %src[%offset], %mask, %passthru + : !pto.ptr, + !pto.vmi.mask<64xb32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<64xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> + return %part : !pto.vreg<64xf32> + } + + func.func @vmi_to_vpto_expand_load_all_active_safe_tail_memref_nonzero_offset( + %src: memref<132xf32>, + %passthru: !pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %active = arith.constant 100 : index + %offset = arith.constant 4 : index + %mask = pto.vmi.create_mask %active + : index -> !pto.vmi.mask<100xb32, #pto.vmi.layout> + %out = pto.vmi.expand_load %src[%offset], %mask, %passthru + : memref<132xf32>, + !pto.vmi.mask<100xb32, #pto.vmi.layout>, + !pto.vmi.vreg<100xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<100xf32, #pto.vmi.layout> + %part0, %part1 = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %part0, %part1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_expand_load_all_active( +// CHECK: %[[LOAD:.*]] = pto.vlds %arg0[%arg1] : !pto.ptr -> !pto.vreg<64xf32> +// CHECK: return %[[LOAD]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_expand_load_all_active_safe_tail_memref_nonzero_offset( +// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index +// CHECK-DAG: %[[C68:.*]] = arith.constant 68 : index +// CHECK: %[[P0:.*]] = pto.vlds %arg0[%[[C4]]] : memref<132xf32> -> !pto.vreg<64xf32> +// CHECK: %[[P1:.*]] = pto.vlds %arg0[%[[C68]]] : memref<132xf32> -> !pto.vreg<64xf32> +// CHECK: return %[[P0]], %[[P1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_expand_load_all_active_negative_offset_invalid.pto b/test/lit/vmi/vmi_to_vpto_expand_load_all_active_negative_offset_invalid.pto new file mode 100644 index 0000000000..d8733b4641 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_expand_load_all_active_negative_offset_invalid.pto @@ -0,0 +1,35 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_expand_load_all_active_negative_offset_invalid( + %src: memref<132xf32>, + %passthru: !pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %active = arith.constant 100 : index + %offset = arith.constant -1 : index + %mask = pto.vmi.create_mask %active + : index -> !pto.vmi.mask<100xb32, #pto.vmi.layout> + %out = pto.vmi.expand_load %src[%offset], %mask, %passthru + : memref<132xf32>, + !pto.vmi.mask<100xb32, #pto.vmi.layout>, + !pto.vmi.vreg<100xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<100xf32, #pto.vmi.layout> + %part0, %part1 = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %part0, %part1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.expand_load direct lowering is currently supported +// CHECK-SAME: all-active path requires full physical chunks or statically safe full-read footprint +// CHECK-SAME: safe-read proof requires non-negative offset +// CHECK-SAME: fallback decision: partial/tail read needs a scratch, guarded, or true masked/non-faulting load fallback diff --git a/test/lit/vmi/vmi_to_vpto_expand_load_partial_mask_invalid.pto b/test/lit/vmi/vmi_to_vpto_expand_load_partial_mask_invalid.pto new file mode 100644 index 0000000000..cdab169262 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_expand_load_partial_mask_invalid.pto @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_expand_load_partial_mask_invalid( + %src: !pto.ptr, + %offset: index, + %passthru: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %active = arith.constant 4 : index + %mask = pto.vmi.create_mask %active + : index -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %out = pto.vmi.expand_load %src[%offset], %mask, %passthru + : !pto.ptr, + !pto.vmi.mask<128xb32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %part0, %part1 = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %part0, %part1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.expand_load direct lowering is currently supported +// CHECK-SAME: one physical chunk diff --git a/test/lit/vmi/vmi_to_vpto_expand_load_runtime_mask.pto b/test/lit/vmi/vmi_to_vpto_expand_load_runtime_mask.pto new file mode 100644 index 0000000000..7c9d8a3a5b --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_expand_load_runtime_mask.pto @@ -0,0 +1,41 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_expand_load_runtime_mask( + %src: !pto.ptr, + %offset: index, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>, + %passthru: !pto.vmi.vreg<64xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> { + %out = pto.vmi.expand_load %src[%offset], %mask, %passthru + : !pto.ptr, + !pto.vmi.mask<64xb32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<64xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> + return %part : !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_expand_load_runtime_mask( +// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : i32 +// CHECK-DAG: %[[ALL:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK-DAG: %[[BASE:.*]] = pto.addptr %arg0, %arg1 +// CHECK: %[[CARRIER:.*]] = pto.vdup %[[ZERO]], %[[ALL]] : i32, !pto.mask -> !pto.vreg<64xi32> +// CHECK: %[[IDX:.*]] = pto.vusqz %[[CARRIER]], %arg2 : !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> +// CHECK: %[[LOAD:.*]] = pto.vgather2_bc %[[BASE]], %[[IDX]], %arg2 : !pto.ptr, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: %[[OUT:.*]] = pto.vsel %[[LOAD]], %arg3, %arg2 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[OUT]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_extf.pto b/test/lit/vmi/vmi_to_vpto_extf.pto new file mode 100644 index 0000000000..af4fbca903 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_extf.pto @@ -0,0 +1,74 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_extf_f16_to_f32( + %input: !pto.vmi.vreg<128xf16, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %wide = pto.vmi.extf %input + : !pto.vmi.vreg<128xf16, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%wide) + : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + + func.func @vmi_to_vpto_extf_f16_tail_to_f32( + %input: !pto.vmi.vreg<100xf16, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %wide = pto.vmi.extf %input + : !pto.vmi.vreg<100xf16, #pto.vmi.layout> + -> !pto.vmi.vreg<100xf32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%wide) + : (!pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + + func.func @vmi_to_vpto_extf_bf16_to_f32( + %input: !pto.vmi.vreg<128xbf16, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) 
{ + %wide = pto.vmi.extf %input + : !pto.vmi.vreg<128xbf16, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%wide) + : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_extf_f16_to_f32( +// CHECK-SAME: %[[INPUT:.*]]: !pto.vreg<128xf16> +// CHECK: %[[MASK:.*]] = pto.pset_b16 "PAT_ALL" : !pto.mask +// CHECK: pto.vcvt %[[INPUT]], %[[MASK]] {part = "EVEN"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vcvt %[[INPUT]], %[[MASK]] {part = "ODD"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_extf_f16_tail_to_f32( +// CHECK-SAME: %[[TAIL_INPUT:.*]]: !pto.vreg<128xf16> +// CHECK: %[[TAIL_MASK:.*]] = pto.pset_b16 "PAT_ALL" : !pto.mask +// CHECK: pto.vcvt %[[TAIL_INPUT]], %[[TAIL_MASK]] {part = "EVEN"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vcvt %[[TAIL_INPUT]], %[[TAIL_MASK]] {part = "ODD"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_extf_bf16_to_f32( +// CHECK-SAME: %[[BF16_INPUT:.*]]: !pto.vreg<128xbf16> +// CHECK: %[[BF16_MASK:.*]] = pto.pset_b16 "PAT_ALL" : !pto.mask +// CHECK: pto.vcvt %[[BF16_INPUT]], %[[BF16_MASK]] {part = "EVEN"} : !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vcvt %[[BF16_INPUT]], %[[BF16_MASK]] {part = "ODD"} : !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_extf_f8.pto b/test/lit/vmi/vmi_to_vpto_extf_f8.pto new file mode 100644 index 0000000000..c9ab157d0d --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_extf_f8.pto @@ -0,0 +1,59 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_extf_f8_to_f32( + %input: !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %wide = pto.vmi.extf %input + : !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + %p0, %p1, %p2, %p3 = "pto.vmi.unpack"(%wide) + : (!pto.vmi.vreg<256xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1, %p2, %p3 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + + func.func @vmi_to_vpto_extf_f8_tail_to_f32( + %input: !pto.vmi.vreg<100xf8E4M3FN, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %wide = pto.vmi.extf %input + : !pto.vmi.vreg<100xf8E4M3FN, #pto.vmi.layout> + -> !pto.vmi.vreg<100xf32, #pto.vmi.layout> + %p0, %p1, %p2, %p3 = "pto.vmi.unpack"(%wide) + : (!pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, 
!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1, %p2, %p3 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_extf_f8_to_f32( +// CHECK-SAME: %[[INPUT:.*]]: !pto.vreg<256xf8E4M3FN> +// CHECK: %[[MASK:.*]] = pto.pset_b8 "PAT_ALL" : !pto.mask +// CHECK: pto.vcvt %[[INPUT]], %[[MASK]] {part = "P0"} : !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vcvt %[[INPUT]], %[[MASK]] {part = "P1"} : !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vcvt %[[INPUT]], %[[MASK]] {part = "P2"} : !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vcvt %[[INPUT]], %[[MASK]] {part = "P3"} : !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_extf_f8_tail_to_f32( +// CHECK-SAME: %[[TAIL_INPUT:.*]]: !pto.vreg<256xf8E4M3FN> +// CHECK: %[[TAIL_MASK:.*]] = pto.pset_b8 "PAT_ALL" : !pto.mask +// CHECK: pto.vcvt %[[TAIL_INPUT]], %[[TAIL_MASK]] {part = "P0"} : !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vcvt %[[TAIL_INPUT]], %[[TAIL_MASK]] {part = "P1"} : !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vcvt %[[TAIL_INPUT]], %[[TAIL_MASK]] {part = "P2"} : !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vcvt %[[TAIL_INPUT]], %[[TAIL_MASK]] {part = "P3"} : !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_extf_multichunk.pto b/test/lit/vmi/vmi_to_vpto_extf_multichunk.pto new file mode 100644 index 0000000000..0803ccde1c --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_extf_multichunk.pto @@ -0,0 +1,35 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_extf_multichunk( + %input: !pto.vmi.vreg<256xf16, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %wide = pto.vmi.extf %input + : !pto.vmi.vreg<256xf16, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + %p0, %p1, %p2, %p3 = "pto.vmi.unpack"(%wide) + : (!pto.vmi.vreg<256xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1, %p2, %p3 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_extf_multichunk( +// CHECK: %[[MASK:.*]] = pto.pset_b16 "PAT_ALL" : !pto.mask +// CHECK: %[[EVEN0:.*]] = pto.vcvt %arg0, %[[MASK]] {part = "EVEN"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: %[[EVEN1:.*]] = pto.vcvt %arg1, %[[MASK]] {part = "EVEN"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: %[[ODD0:.*]] = pto.vcvt %arg0, %[[MASK]] {part = "ODD"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: %[[ODD1:.*]] = pto.vcvt %arg1, %[[MASK]] {part = "ODD"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[EVEN0]], %[[EVEN1]], %[[ODD0]], %[[ODD1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_fma.pto b/test/lit/vmi/vmi_to_vpto_fma.pto new file mode 100644 index 0000000000..d222c1ecb9 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_fma.pto @@ -0,0 +1,83 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_fma( + %lhs: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %rhs: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %acc: !pto.vmi.vreg<64xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> { + %out = pto.vmi.fma %lhs, %rhs, %acc + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<64xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> + return %part : !pto.vreg<64xf32> + } + + func.func @vmi_to_vpto_fma_f16( + %lhs: !pto.vmi.vreg<128xf16, #pto.vmi.layout>, + %rhs: !pto.vmi.vreg<128xf16, #pto.vmi.layout>, + %acc: !pto.vmi.vreg<128xf16, #pto.vmi.layout>) + -> !pto.vreg<128xf16> { + %out = pto.vmi.fma %lhs, %rhs, %acc + : !pto.vmi.vreg<128xf16, #pto.vmi.layout>, + !pto.vmi.vreg<128xf16, #pto.vmi.layout>, + !pto.vmi.vreg<128xf16, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<128xf16, 
#pto.vmi.layout>) + -> !pto.vreg<128xf16> + return %part : !pto.vreg<128xf16> + } + + func.func @vmi_to_vpto_fma_bf16( + %lhs: !pto.vmi.vreg<128xbf16, #pto.vmi.layout>, + %rhs: !pto.vmi.vreg<128xbf16, #pto.vmi.layout>, + %acc: !pto.vmi.vreg<128xbf16, #pto.vmi.layout>) + -> !pto.vreg<128xbf16> { + %out = pto.vmi.fma %lhs, %rhs, %acc + : !pto.vmi.vreg<128xbf16, #pto.vmi.layout>, + !pto.vmi.vreg<128xbf16, #pto.vmi.layout>, + !pto.vmi.vreg<128xbf16, #pto.vmi.layout> + -> !pto.vmi.vreg<128xbf16, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<128xbf16, #pto.vmi.layout>) + -> !pto.vreg<128xbf16> + return %part : !pto.vreg<128xbf16> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_fma( +// CHECK: %[[MASK:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: %[[OUT:.*]] = pto.vmula %arg2, %arg0, %arg1, %[[MASK]] : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[OUT]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_fma_f16( +// CHECK: %[[MASK16:.*]] = pto.pset_b16 "PAT_ALL" : !pto.mask +// CHECK: %[[OUT16:.*]] = pto.vmula %arg2, %arg0, %arg1, %[[MASK16]] : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> +// CHECK: return %[[OUT16]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_fma_bf16( +// CHECK: %[[MASKBF16:.*]] = pto.pset_b16 "PAT_ALL" : !pto.mask +// CHECK: %[[OUTBF16:.*]] = pto.vmula %arg2, %arg0, %arg1, %[[MASKBF16]] : !pto.vreg<128xbf16>, !pto.vreg<128xbf16>, !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<128xbf16> +// CHECK: return %[[OUTBF16]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_fma_element_type_invalid.pto b/test/lit/vmi/vmi_to_vpto_fma_element_type_invalid.pto new file mode 100644 index 0000000000..877568258b --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_fma_element_type_invalid.pto @@ -0,0 +1,26 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_fma_f8_invalid( + %lhs: !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout>, + %rhs: !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout>, + %acc: !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout>) { + %out = pto.vmi.fma %lhs, %rhs, %acc + : !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout>, + !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout>, + !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.fma lowers through pto.vmula only for f16/bf16/f32 element types +// CHECK-SAME: requires f16, bf16, or f32 element type for pto.vmula diff --git a/test/lit/vmi/vmi_to_vpto_function_type_layout_free_invalid.pto b/test/lit/vmi/vmi_to_vpto_function_type_layout_free_invalid.pto new file mode 100644 index 0000000000..0fedf0d694 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_function_type_layout_free_invalid.pto @@ -0,0 +1,16 @@ +// Copyright (c) 2026 Huawei Technologies 
Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func private @external(!pto.vmi.vreg<128xf32>) + -> !pto.vmi.vreg<128xf32> +} + +// CHECK: VMI-PASS-INVARIANT: vmi-to-vpto requires layout-assigned VMI types diff --git a/test/lit/vmi/vmi_to_vpto_gather.pto b/test/lit/vmi/vmi_to_vpto_gather.pto new file mode 100644 index 0000000000..d68e72c1d2 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_gather.pto @@ -0,0 +1,37 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_gather( + %src: !pto.ptr, + %indices: !pto.vmi.vreg<64xi32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>, + %passthru: !pto.vmi.vreg<64xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> { + %out = pto.vmi.gather %src[%indices], %mask, %passthru + : !pto.ptr, + !pto.vmi.vreg<64xi32, #pto.vmi.layout>, + !pto.vmi.mask<64xb32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<64xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> + return %part : !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_gather( +// CHECK: %[[GATHER:.*]] = pto.vgather2_bc %arg0, %arg1, %arg2 : !pto.ptr, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: %[[OUT:.*]] = pto.vsel %[[GATHER]], %arg3, %arg2 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[OUT]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_gather_f16_invalid.pto b/test/lit/vmi/vmi_to_vpto_gather_f16_invalid.pto new file mode 100644 index 0000000000..21c8753ec7 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_gather_f16_invalid.pto @@ -0,0 +1,28 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_gather_f16_invalid( + %src: !pto.ptr, + %indices: !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<128xb16, #pto.vmi.layout>, + %passthru: !pto.vmi.vreg<128xf16, #pto.vmi.layout>) { + %out = pto.vmi.gather %src[%indices], %mask, %passthru + : !pto.ptr, + !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + !pto.vmi.mask<128xb16, #pto.vmi.layout>, + !pto.vmi.vreg<128xf16, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.gather lowers through pto.vgather2_bc + pto.vsel only +// CHECK-SAME: 32-bit result elements diff --git a/test/lit/vmi/vmi_to_vpto_gather_scatter_shape_invalid.pto b/test/lit/vmi/vmi_to_vpto_gather_scatter_shape_invalid.pto new file mode 100644 index 0000000000..2e5afb7708 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_gather_scatter_shape_invalid.pto @@ -0,0 +1,91 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -split-input-file -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_gather_deint_invalid( + %src: !pto.ptr, + %indices: !pto.vmi.vreg<64xi32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>, + %passthru: !pto.vmi.vreg<64xf32, #pto.vmi.layout>) { + %out = pto.vmi.gather %src[%indices], %mask, %passthru + : !pto.ptr, + !pto.vmi.vreg<64xi32, #pto.vmi.layout>, + !pto.vmi.mask<64xb32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.gather lowers through pto.vgather2_bc + pto.vsel only +// CHECK-SAME: contiguous result, indices, passthru, and mask layouts + +// ----- + +module { + func.func @vmi_to_vpto_gather_tail_invalid( + %src: !pto.ptr, + %indices: !pto.vmi.vreg<32xi32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<32xb32, #pto.vmi.layout>, + %passthru: !pto.vmi.vreg<32xf32, #pto.vmi.layout>) { + %out = pto.vmi.gather %src[%indices], %mask, %passthru + : !pto.ptr, + !pto.vmi.vreg<32xi32, #pto.vmi.layout>, + !pto.vmi.mask<32xb32, #pto.vmi.layout>, + !pto.vmi.vreg<32xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<32xf32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.gather lowers through pto.vgather2_bc + pto.vsel only +// CHECK-SAME: result requires full physical chunks +// CHECK-SAME: found padding lane in physical chunk + +// ----- + +module { + func.func @vmi_to_vpto_scatter_deint_invalid( + %value: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %dst: !pto.ptr, + %indices: !pto.vmi.vreg<64xi32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) { + pto.vmi.scatter %value, %dst[%indices], %mask + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.ptr, + !pto.vmi.vreg<64xi32, #pto.vmi.layout>, + !pto.vmi.mask<64xb32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.scatter lowers through pto.vscatter only +// CHECK-SAME: 
contiguous value, indices, and mask layouts + +// ----- + +module { + func.func @vmi_to_vpto_scatter_tail_invalid( + %value: !pto.vmi.vreg<32xf32, #pto.vmi.layout>, + %dst: !pto.ptr, + %indices: !pto.vmi.vreg<32xi32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<32xb32, #pto.vmi.layout>) { + pto.vmi.scatter %value, %dst[%indices], %mask + : !pto.vmi.vreg<32xf32, #pto.vmi.layout>, + !pto.ptr, + !pto.vmi.vreg<32xi32, #pto.vmi.layout>, + !pto.vmi.mask<32xb32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.scatter lowers through pto.vscatter only +// CHECK-SAME: value requires full physical chunks +// CHECK-SAME: found padding lane in physical chunk diff --git a/test/lit/vmi/vmi_to_vpto_gather_u16.pto b/test/lit/vmi/vmi_to_vpto_gather_u16.pto new file mode 100644 index 0000000000..bcf0caede3 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_gather_u16.pto @@ -0,0 +1,37 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_gather_u16( + %src: !pto.ptr, + %indices: !pto.vmi.vreg<32xui16, #pto.vmi.layout>, + %mask: !pto.vmi.mask<32xb16, #pto.vmi.layout>, + %passthru: !pto.vmi.vreg<32xui16, #pto.vmi.layout>) + -> !pto.vreg<128xui16> { + %out = pto.vmi.gather %src[%indices], %mask, %passthru + : !pto.ptr, + !pto.vmi.vreg<32xui16, #pto.vmi.layout>, + !pto.vmi.mask<32xb16, #pto.vmi.layout>, + !pto.vmi.vreg<32xui16, #pto.vmi.layout> + -> !pto.vmi.vreg<32xui16, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<32xui16, #pto.vmi.layout>) + -> !pto.vreg<128xui16> + return %part : !pto.vreg<128xui16> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_gather_u16( +// CHECK: %[[GATHER:.*]] = pto.vgather2 %arg0, %arg1, %arg2 : !pto.ptr, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> +// CHECK: %[[OUT:.*]] = pto.vsel %[[GATHER]], %arg3, %arg2 : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> +// CHECK: return %[[OUT]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_group_broadcast_deint.pto b/test/lit/vmi/vmi_to_vpto_group_broadcast_deint.pto new file mode 100644 index 0000000000..7f9e02f144 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_broadcast_deint.pto @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_group_broadcast_deint( + %sum: !pto.vmi.vreg<2xf32, #pto.vmi.layout>, + %src_f8: !pto.vmi.vreg<512xf8E4M3FN>) + -> !pto.vmi.vreg<512xf32> { + %src_f32 = pto.vmi.extf %src_f8 + : !pto.vmi.vreg<512xf8E4M3FN> -> !pto.vmi.vreg<512xf32> + %sum_vec = pto.vmi.group_broadcast %sum {num_groups = 2} + : !pto.vmi.vreg<2xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<512xf32> + %out = pto.vmi.mulf %sum_vec, %src_f32 + : !pto.vmi.vreg<512xf32>, !pto.vmi.vreg<512xf32> -> !pto.vmi.vreg<512xf32> + return %out : !pto.vmi.vreg<512xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_group_broadcast_deint( +// CHECK-COUNT-8: {position = "LOWEST"} +// CHECK-COUNT-8: pto.vmul +// CHECK-NOT: pto.vselr +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_group_broadcast_load_e2b_b16.pto b/test/lit/vmi/vmi_to_vpto_group_broadcast_load_e2b_b16.pto new file mode 100644 index 0000000000..41c2304d23 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_broadcast_load_e2b_b16.pto @@ -0,0 +1,43 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_group_broadcast_load_e2b_b16( + %src: !pto.ptr, %off: index) + -> !pto.vmi.vreg<256xbf16, #pto.vmi.layout> { + %c1 = arith.constant 1 : index + %out = pto.vmi.group_broadcast_load %src[%off], %c1 + {num_groups = 8} + : !pto.ptr + -> !pto.vmi.vreg<256xbf16, #pto.vmi.layout> + return %out : !pto.vmi.vreg<256xbf16, #pto.vmi.layout> + } + + func.func @vmi_to_vpto_group_broadcast_load_e2b_b32( + %src: !pto.ptr, %off: index) + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> { + %c1 = arith.constant 1 : index + %out = pto.vmi.group_broadcast_load %src[%off], %c1 + {num_groups = 8} + : !pto.ptr + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + return %out : !pto.vmi.vreg<64xf32, #pto.vmi.layout> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_group_broadcast_load_e2b_b16 +// CHECK-SAME: (%[[SRC:.*]]: !pto.ptr, %[[OFF:.*]]: index) +// CHECK: %[[E2B:.*]] = pto.vlds %[[SRC]][%[[OFF]]] {dist = "E2B_B16"} : !pto.ptr -> !pto.vreg<128xbf16> +// CHECK: return %[[E2B]], %[[E2B]] : !pto.vreg<128xbf16>, !pto.vreg<128xbf16> + +// CHECK-LABEL: func.func @vmi_to_vpto_group_broadcast_load_e2b_b32 +// CHECK-SAME: (%[[SRC32:.*]]: !pto.ptr, %[[OFF32:.*]]: index) +// CHECK: %[[E2B32:.*]] = pto.vlds %[[SRC32]][%[[OFF32]]] {dist = "E2B_B32"} : !pto.ptr -> !pto.vreg<64xf32> +// CHECK: return %[[E2B32]] : !pto.vreg<64xf32> diff --git a/test/lit/vmi/vmi_to_vpto_group_broadcast_load_e2b_b16_stride_invalid.pto b/test/lit/vmi/vmi_to_vpto_group_broadcast_load_e2b_b16_stride_invalid.pto new file mode 100644 index 0000000000..02640d136e --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_broadcast_load_e2b_b16_stride_invalid.pto @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. 
You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_group_broadcast_load_e2b_b16_stride_invalid( + %src: !pto.ptr, %off: index, %stride: index) + -> !pto.vmi.vreg<256xbf16, #pto.vmi.layout> { + %out = pto.vmi.group_broadcast_load %src[%off], %stride + {num_groups = 8} + : !pto.ptr + -> !pto.vmi.vreg<256xbf16, #pto.vmi.layout> + return %out : !pto.vmi.vreg<256xbf16, #pto.vmi.layout> + } +} + +// CHECK: VMI-UNSUPPORTED: +// CHECK: pto.vmi.group_broadcast_load requires either the E2B packet form +// CHECK: fallback lowering requires constant unit source_group_stride diff --git a/test/lit/vmi/vmi_to_vpto_group_broadcast_s32_deint2_small_group.pto b/test/lit/vmi/vmi_to_vpto_group_broadcast_s32_deint2_small_group.pto new file mode 100644 index 0000000000..7b19a1254a --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_broadcast_s32_deint2_small_group.pto @@ -0,0 +1,29 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_group_broadcast_s32_deint2_small_group( + %source: !pto.vmi.vreg<4xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %broadcast = pto.vmi.group_broadcast %source {num_groups = 4} + : !pto.vmi.vreg<4xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%broadcast) + : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_group_broadcast_s32_deint2_small_group( +// CHECK-COUNT-2: pto.vselr +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_group_broadcast_slots8.pto b/test/lit/vmi/vmi_to_vpto_group_broadcast_slots8.pto new file mode 100644 index 0000000000..6f23678fea --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_broadcast_slots8.pto @@ -0,0 +1,43 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_group_broadcast_slots8( + %source: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %out = pto.vmi.group_broadcast %source + {num_groups = 128} + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<1024xf32, #pto.vmi.layout> + %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7, + %p8, %p9, %p10, %p11, %p12, %p13, %p14, %p15 = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<1024xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7, + %p8, %p9, %p10, %p11, %p12, %p13, %p14, %p15 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_group_broadcast_slots8( +// CHECK-COUNT-16: pto.vselr +// CHECK-NOT: pto.vcadd +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_group_broadcast_slots8_support.pto b/test/lit/vmi/vmi_to_vpto_group_broadcast_slots8_support.pto new file mode 100644 index 0000000000..c8268869a3 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_broadcast_slots8_support.pto @@ -0,0 +1,41 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_group_broadcast_slots8_support( + %source: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %out = pto.vmi.group_broadcast %source {num_groups = 128} + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<1024xf32, #pto.vmi.layout> + %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7, + %p8, %p9, %p10, %p11, %p12, %p13, %p14, %p15 = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<1024xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, 
!pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7, + %p8, %p9, %p10, %p11, %p12, %p13, %p14, %p15 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_group_broadcast_slots8_support( +// CHECK-COUNT-16: pto.vselr +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_group_broadcast_vselr.pto b/test/lit/vmi/vmi_to_vpto_group_broadcast_vselr.pto new file mode 100644 index 0000000000..3b2452545c --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_broadcast_vselr.pto @@ -0,0 +1,42 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_group_broadcast_vselr( + %source: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %out = pto.vmi.group_broadcast %source {num_groups = 128} + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<1024xf32, #pto.vmi.layout> + %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7, + %p8, %p9, %p10, %p11, %p12, %p13, %p14, %p15 = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<1024xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7, + %p8, %p9, %p10, %p11, %p12, %p13, %p14, %p15 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_group_broadcast_vselr( +// CHECK-COUNT-16: pto.vselr +// CHECK-NOT: pto.vcadd +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_group_load_support.pto b/test/lit/vmi/vmi_to_vpto_group_load_support.pto new file mode 100644 index 0000000000..1af77958af --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_load_support.pto @@ -0,0 +1,37 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_group_load_support( + %source: !pto.ptr, + %row_stride: index) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %c0 = arith.constant 0 : index + %out = pto.vmi.group_load %source[%c0], %row_stride {num_groups = 2} + : !pto.ptr -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> + %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7 = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<512xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_group_load_support( +// CHECK-COUNT-8: pto.vlds +// 
CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_group_ops.pto b/test/lit/vmi/vmi_to_vpto_group_ops.pto new file mode 100644 index 0000000000..5abb0de3ab --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_ops.pto @@ -0,0 +1,39 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_group_ops( + %src: !pto.ptr, + %dst: !pto.ptr, + %row_stride: index, + %mask: !pto.vmi.mask<512xb32, #pto.vmi.layout>) { + %c0 = arith.constant 0 : index + %v = pto.vmi.group_load %src[%c0], %row_stride + {num_groups = 2} + : !pto.ptr -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> + %r = pto.vmi.group_reduce_addf %v, %mask {num_groups = 2, reassoc} + : !pto.vmi.vreg<512xf32, #pto.vmi.layout>, + !pto.vmi.mask<512xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<2xf32, #pto.vmi.layout> + pto.vmi.group_store %r, %dst[%c0], %row_stride {num_groups = 2} + : !pto.vmi.vreg<2xf32, #pto.vmi.layout>, !pto.ptr + return + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_group_ops( +// CHECK-COUNT-8: pto.vlds +// CHECK: pto.vcgadd +// CHECK: pto.vselr +// CHECK-COUNT-7: pto.vcgadd +// CHECK-COUNT-2: pto.vsts {{.*}} {dist = "1PT_B32"} +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_group_reduce_legacy_slots_invalid.pto b/test/lit/vmi/vmi_to_vpto_group_reduce_legacy_slots_invalid.pto new file mode 100644 index 0000000000..1287a859c6 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_reduce_legacy_slots_invalid.pto @@ -0,0 +1,27 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @legacy_group_slots_without_explicit_slots( + %source: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> { + // CHECK: pto.vmi.group_reduce_addf lowers through pto.vcgadd + // CHECK-SAME: stable group_reduce_add layout support currently requires result layout slots=8 or slots=1 + %out = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.mask<64xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<8xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> + return %part : !pto.vreg<64xf32> + } +} diff --git a/test/lit/vmi/vmi_to_vpto_group_reduce_partial_slots8.pto b/test/lit/vmi/vmi_to_vpto_group_reduce_partial_slots8.pto new file mode 100644 index 0000000000..d2886fa5aa --- /dev/null +++ 
b/test/lit/vmi/vmi_to_vpto_group_reduce_partial_slots8.pto @@ -0,0 +1,94 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_group_reduce_f16_s64_g4( + %source: !pto.vmi.vreg<256xf16, #pto.vmi.layout>, + %mask: !pto.vmi.mask<256xb16, #pto.vmi.layout>, + %dst: !pto.ptr, %off: index) { + %c1 = arith.constant 1 : index + %out = pto.vmi.group_reduce_addf %source, %mask + {num_groups = 4, reassoc} + : !pto.vmi.vreg<256xf16, #pto.vmi.layout>, + !pto.vmi.mask<256xb16, #pto.vmi.layout> + -> !pto.vmi.vreg<4xf16, #pto.vmi.layout> + pto.vmi.group_store %out, %dst[%off], %c1 {num_groups = 4} + : !pto.vmi.vreg<4xf16, #pto.vmi.layout>, + !pto.ptr + return + } + + func.func @vmi_to_vpto_group_reduce_f16_s64_g8( + %source: !pto.vmi.vreg<512xf16, #pto.vmi.layout>, + %mask: !pto.vmi.mask<512xb16, #pto.vmi.layout>, + %dst: !pto.ptr, %off: index) { + %c1 = arith.constant 1 : index + %out = pto.vmi.group_reduce_addf %source, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<512xf16, #pto.vmi.layout>, + !pto.vmi.mask<512xb16, #pto.vmi.layout> + -> !pto.vmi.vreg<8xf16, #pto.vmi.layout> + pto.vmi.group_store %out, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xf16, #pto.vmi.layout>, + !pto.ptr + return + } + + func.func @vmi_to_vpto_group_reduce_f16_s64_g12( + %source: !pto.vmi.vreg<768xf16, 
#pto.vmi.layout>, + %mask: !pto.vmi.mask<768xb16, #pto.vmi.layout>, + %dst: !pto.ptr, %off: index) { + %c1 = arith.constant 1 : index + %out = pto.vmi.group_reduce_addf %source, %mask + {num_groups = 12, reassoc} + : !pto.vmi.vreg<768xf16, #pto.vmi.layout>, + !pto.vmi.mask<768xb16, #pto.vmi.layout> + -> !pto.vmi.vreg<12xf16, #pto.vmi.layout> + pto.vmi.group_store %out, %dst[%off], %c1 {num_groups = 12} + : !pto.vmi.vreg<12xf16, #pto.vmi.layout>, + !pto.ptr + return + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_group_reduce_f16_s64_g4( +// CHECK-DAG: %[[SLOT4:.*]] = pto.pge_b16 "PAT_VL4" : !pto.mask +// CHECK-COUNT-4: pto.vcgadd +// CHECK-COUNT-3: pto.vadd {{.*}}, {{.*}}, %[[SLOT4]] +// CHECK: %[[STORE4:.*]] = pto.pge_b16 "PAT_VL4" : !pto.mask +// CHECK: pto.vsts {{.*}}, {{.*}}, %[[STORE4]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_group_reduce_f16_s64_g8( +// CHECK-DAG: %[[SLOT8:.*]] = pto.pge_b16 "PAT_VL8" : !pto.mask +// CHECK-COUNT-4: pto.vcgadd +// CHECK-COUNT-3: pto.vadd {{.*}}, {{.*}}, %[[SLOT8]] +// CHECK: %[[STORE8:.*]] = pto.pge_b16 "PAT_VL8" : !pto.mask +// CHECK: pto.vsts {{.*}}, {{.*}}, %[[STORE8]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_group_reduce_f16_s64_g12( +// CHECK: %[[SLOT8_12:.*]] = pto.pge_b16 "PAT_VL8" : !pto.mask +// CHECK-COUNT-4: pto.vcgadd +// CHECK-COUNT-3: pto.vadd {{.*}}, {{.*}}, %[[SLOT8_12]] +// CHECK: %[[SLOT4_12:.*]] = pto.pge_b16 "PAT_VL4" : !pto.mask +// CHECK-COUNT-4: pto.vcgadd +// CHECK-COUNT-3: pto.vadd {{.*}}, {{.*}}, %[[SLOT4_12]] +// CHECK: %[[STORE8_12:.*]] = pto.pge_b16 "PAT_VL8" : !pto.mask +// CHECK: pto.vsts {{.*}}, {{.*}}, %[[STORE8_12]] +// CHECK: %[[STORE4_12:.*]] = pto.pge_b16 "PAT_VL4" : !pto.mask +// CHECK: pto.vsts {{.*}}, {{.*}}, %[[STORE4_12]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_group_reduce_s256_broadcast.pto b/test/lit/vmi/vmi_to_vpto_group_reduce_s256_broadcast.pto new file mode 100644 index 0000000000..cb790d5f91 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_reduce_s256_broadcast.pto @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_group_reduce_s256_broadcast( + %source: !pto.vmi.vreg<512xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<512xb32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %sum = pto.vmi.group_reduce_addf %source, %mask {num_groups = 2, reassoc} + : !pto.vmi.vreg<512xf32, #pto.vmi.layout>, + !pto.vmi.mask<512xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<2xf32, #pto.vmi.layout> + %broadcast = pto.vmi.group_broadcast %sum {num_groups = 2} + : !pto.vmi.vreg<2xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> + %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7 = "pto.vmi.unpack"(%broadcast) + : (!pto.vmi.vreg<512xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, 
%p1, %p2, %p3, %p4, %p5, %p6, %p7 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_group_reduce_s256_broadcast( +// CHECK: pto.vcgadd +// CHECK: pto.vadd +// CHECK: pto.vsel +// CHECK: pto.vdup {{.*}} {position = "LOWEST"} +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_group_reduce_s64.pto b/test/lit/vmi/vmi_to_vpto_group_reduce_s64.pto new file mode 100644 index 0000000000..83b7ee9ae7 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_reduce_s64.pto @@ -0,0 +1,45 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_group_reduce_s64( + %source: !pto.vmi.vreg<512xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<512xb32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %out = pto.vmi.group_reduce_addf %source, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<512xf32, #pto.vmi.layout>, + !pto.vmi.mask<512xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> + %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7 = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<8xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_group_reduce_s64( +// CHECK-DAG: %[[VL1:.*]] = pto.pge_b32 "PAT_VL1" +// CHECK: pto.vcgadd +// CHECK: pto.vadd +// CHECK: pto.vsel {{.*}}, {{.*}}, %[[VL1]] +// CHECK: pto.vcgadd +// CHECK: pto.vsel {{.*}}, {{.*}}, %[[VL1]] +// CHECK: return {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_group_reduce_s64_support.pto b/test/lit/vmi/vmi_to_vpto_group_reduce_s64_support.pto new file mode 100644 index 0000000000..aba242c03f --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_reduce_s64_support.pto @@ -0,0 +1,39 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_group_reduce_s64_support( + %source: !pto.vmi.vreg<512xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<512xb32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %out = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<512xf32, #pto.vmi.layout>, + !pto.vmi.mask<512xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> + %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7 = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<8xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_group_reduce_s64_support( +// CHECK-COUNT-8: pto.vcgadd +// CHECK: pto.vsel +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_group_reduce_slots8.pto b/test/lit/vmi/vmi_to_vpto_group_reduce_slots8.pto new file mode 100644 index 0000000000..c8d5a85757 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_reduce_slots8.pto @@ -0,0 +1,34 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_group_reduce_slots8( + %source: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> { + %out = pto.vmi.group_reduce_addf %source, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.mask<64xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<8xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> + return %part : !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_group_reduce_slots8( +// CHECK: %[[OUT:.*]] = pto.vcgadd %arg0, %arg1 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[OUT]] +// CHECK-NOT: pto.vcadd +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_group_reduce_slots8_support.pto b/test/lit/vmi/vmi_to_vpto_group_reduce_slots8_support.pto new file mode 100644 index 0000000000..d88c1cf1ad --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_reduce_slots8_support.pto @@ -0,0 +1,32 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_group_reduce_slots8_support( + %source: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> { + %out = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.mask<64xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<8xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> + return %part : !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_group_reduce_slots8_support( +// CHECK: pto.vcgadd +// CHECK-NOT: pto.vcadd +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_group_reduce_typed.pto b/test/lit/vmi/vmi_to_vpto_group_reduce_typed.pto new file mode 100644 index 0000000000..88a8598a82 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_reduce_typed.pto @@ -0,0 +1,80 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_group_reduce_addf_f16_vlane( + %source: !pto.vmi.vreg<128xf16, #pto.vmi.layout>, + %mask: !pto.vmi.mask<128xb16, #pto.vmi.layout>) + -> !pto.vreg<128xf16> { + %out = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf16, #pto.vmi.layout>, + !pto.vmi.mask<128xb16, #pto.vmi.layout> + -> !pto.vmi.vreg<8xf16, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<8xf16, #pto.vmi.layout>) + -> !pto.vreg<128xf16> + return %part : !pto.vreg<128xf16> + } + + func.func @vmi_group_reduce_addi_i16_storage_to_i32_vlane( + %source: !pto.vmi.vreg<128xi16, #pto.vmi.layout>, + %mask: !pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> !pto.vreg<64xi32> { + %wide = pto.vmi.extsi %source + : !pto.vmi.vreg<128xi16, #pto.vmi.layout> + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + %out = pto.vmi.group_reduce_addi %wide, %mask {num_groups = 8} + : !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> 
!pto.vmi.vreg<8xi32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<8xi32, #pto.vmi.layout>) + -> !pto.vreg<64xi32> + return %part : !pto.vreg<64xi32> + } + + func.func @vmi_group_reduce_addi_i32_two_vlane( + %source: !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> !pto.vreg<64xi32> { + %out = pto.vmi.group_reduce_addi %source, %mask {num_groups = 8} + : !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<8xi32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<8xi32, #pto.vmi.layout>) + -> !pto.vreg<64xi32> + return %part : !pto.vreg<64xi32> + } +} + +// CHECK-LABEL: func.func @vmi_group_reduce_addf_f16_vlane( +// CHECK: %[[OUT:.*]] = pto.vcgadd %arg0, %arg1 : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> +// CHECK: return %[[OUT]] + +// CHECK-LABEL: func.func @vmi_group_reduce_addi_i16_storage_to_i32_vlane( +// CHECK: %[[EVEN:.*]] = pto.vcvt %arg0, {{.*}} {part = "EVEN"} : !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<64xi32> +// CHECK: %[[ODD:.*]] = pto.vcvt %arg0, {{.*}} {part = "ODD"} : !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<64xi32> +// CHECK: %[[S0:.*]] = pto.vcgadd %[[EVEN]], %arg1 +// CHECK: %[[S1:.*]] = pto.vcgadd %[[ODD]], %arg2 +// CHECK: %[[SUM:.*]] = pto.vadd %[[S0]], %[[S1]] +// CHECK: return %[[SUM]] + +// CHECK-LABEL: func.func @vmi_group_reduce_addi_i32_two_vlane( +// CHECK: %[[MASK:.*]] = pto.pge_b32 "PAT_VL8" : !pto.mask +// CHECK: %[[SLO:.*]] = pto.vcgadd %arg0, %arg2 +// CHECK: %[[SHI:.*]] = pto.vcgadd %arg1, %arg3 +// CHECK: %[[SUM:.*]] = pto.vadd %[[SLO]], %[[SHI]], %[[MASK]] +// CHECK: return %[[SUM]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_group_reduce_vcgadd.pto b/test/lit/vmi/vmi_to_vpto_group_reduce_vcgadd.pto new file mode 100644 index 0000000000..c9e77a1b67 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_reduce_vcgadd.pto @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_group_reduce_vcgadd( + %source: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> { + %out = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.mask<64xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<8xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> + return %part : !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_group_reduce_vcgadd( +// CHECK: %[[OUT:.*]] = pto.vcgadd %arg0, %arg1 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[OUT]] +// CHECK-NOT: pto.vcadd +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_group_reduce_vcgadd_multichunk.pto b/test/lit/vmi/vmi_to_vpto_group_reduce_vcgadd_multichunk.pto new file mode 100644 index 0000000000..580ab5f4e5 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_reduce_vcgadd_multichunk.pto @@ -0,0 +1,45 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_group_reduce_vcgadd_multichunk( + %source: !pto.vmi.vreg<1024xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<1024xb32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %out = pto.vmi.group_reduce_addf %source, %mask {num_groups = 128, reassoc} + : !pto.vmi.vreg<1024xf32, #pto.vmi.layout>, + !pto.vmi.mask<1024xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7, + %p8, %p9, %p10, %p11, %p12, %p13, %p14, %p15 = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, 
!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7, + %p8, %p9, %p10, %p11, %p12, %p13, %p14, %p15 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_group_reduce_vcgadd_multichunk( +// CHECK-COUNT-16: pto.vcgadd +// CHECK-NOT: pto.vcadd +// CHECK-NOT: pto.vselr +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_group_slot_load.pto b/test/lit/vmi/vmi_to_vpto_group_slot_load.pto new file mode 100644 index 0000000000..9a2d80feb1 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_slot_load.pto @@ -0,0 +1,151 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_group_slot_load_slots8( + %src: !pto.ptr, %off: index) -> !pto.vreg<64xf32> { + %c1 = arith.constant 1 : index + %out = pto.vmi.group_slot_load %src[%off], %c1 + {num_groups = 8} + : !pto.ptr + -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<8xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> + return %part : !pto.vreg<64xf32> + } + + func.func @vmi_to_vpto_group_slot_load_slots1( + %src: !pto.ptr, %off: index) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %c8 = arith.constant 8 : index + %out = pto.vmi.group_slot_load %src[%off], %c8 + {num_groups = 8} + : !pto.ptr + -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> + %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7 = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<8xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + + func.func @vmi_to_vpto_group_slot_load_slots8_store( + %src: !pto.ptr, %dst: !pto.ptr, %off: index) { + %c1 = arith.constant 1 : index + %out = pto.vmi.group_slot_load %src[%off], %c1 + {num_groups = 8} + : !pto.ptr + -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> + pto.vmi.group_store %out, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xf32, #pto.vmi.layout>, + !pto.ptr + return + } + + func.func @vmi_to_vpto_group_slot_load_u8_scale_broadcast( + %src: !pto.ptr, %dst: !pto.ptr, %off: index) { + %c1 = arith.constant 1 : index + %c23_i32 = arith.constant 23 : i32 + %scale_u8 = pto.vmi.group_slot_load %src[%off], %c1 {num_groups = 8} + : 
!pto.ptr + -> !pto.vmi.vreg<8xui8, #pto.vmi.layout> + %scale_u32 = pto.vmi.extui %scale_u8 + : !pto.vmi.vreg<8xui8, #pto.vmi.layout> + -> !pto.vmi.vreg<8xui32, #pto.vmi.layout> + %scale_i32 = pto.vmi.bitcast %scale_u32 + : !pto.vmi.vreg<8xui32, #pto.vmi.layout> + -> !pto.vmi.vreg<8xi32, #pto.vmi.layout> + %shift = pto.vmi.broadcast %c23_i32 + : i32 -> !pto.vmi.vreg<8xi32, #pto.vmi.layout> + %bits = pto.vmi.shli %scale_i32, %shift + : !pto.vmi.vreg<8xi32, #pto.vmi.layout>, + !pto.vmi.vreg<8xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<8xi32, #pto.vmi.layout> + %scale = pto.vmi.bitcast %bits + : !pto.vmi.vreg<8xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> + %vec = pto.vmi.group_broadcast %scale {num_groups = 8} + : !pto.vmi.vreg<8xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + pto.vmi.store %vec, %dst[%off] + : !pto.vmi.vreg<256xf32, #pto.vmi.layout>, + !pto.ptr + return + } + + func.func @vmi_to_vpto_group_slot_load_u16_broadcast( + %src: !pto.ptr, %dst: !pto.ptr, %off: index) { + %c1 = arith.constant 1 : index + %scale_u16 = pto.vmi.group_slot_load %src[%off], %c1 {num_groups = 8} + : !pto.ptr + -> !pto.vmi.vreg<8xui16, #pto.vmi.layout> + %scale_u32 = pto.vmi.extui %scale_u16 + : !pto.vmi.vreg<8xui16, #pto.vmi.layout> + -> !pto.vmi.vreg<8xui32, #pto.vmi.layout> + %vec = pto.vmi.group_broadcast %scale_u32 {num_groups = 8} + : !pto.vmi.vreg<8xui32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xui32, #pto.vmi.layout> + pto.vmi.store %vec, %dst[%off] + : !pto.vmi.vreg<256xui32, #pto.vmi.layout>, + !pto.ptr + return + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_group_slot_load_slots8( +// CHECK: %[[MASK:.*]] = pto.pge_b32 "PAT_VL8" : !pto.mask +// CHECK: %[[BASE:.*]] = pto.addptr %arg0, %arg1 : -> +// CHECK: %[[OUT:.*]] = pto.vsldb %[[BASE]], {{.*}}, {{.*}}, %[[MASK]] : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[OUT]] + +// CHECK-LABEL: func.func @vmi_to_vpto_group_slot_load_slots1( +// CHECK-COUNT-8: 
pto.vsldb + +// CHECK-LABEL: func.func @vmi_to_vpto_group_slot_load_slots8_store( +// CHECK: %[[LOAD_MASK:.*]] = pto.pge_b32 "PAT_VL8" : !pto.mask +// CHECK: %[[BASE:.*]] = pto.addptr %arg0, %arg2 : -> +// CHECK: %[[OUT:.*]] = pto.vsldb %[[BASE]], {{.*}}, {{.*}}, %[[LOAD_MASK]] +// CHECK: %[[STORE_MASK:.*]] = pto.pge_b32 "PAT_VL8" : !pto.mask +// CHECK: pto.vsts %[[OUT]], %arg1[%arg2], %[[STORE_MASK]] : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_group_slot_load_u8_scale_broadcast( +// CHECK: pto.pge_b8 "PAT_VL8" : !pto.mask +// CHECK: pto.vsldb {{.*}} : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<256xui8> +// CHECK-DAG: pto.vcvt {{.*}} {part = "P0"} : !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<64xui32> +// CHECK-DAG: pto.vcvt {{.*}} {part = "P1"} : !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<64xui32> +// CHECK-DAG: pto.vcvt {{.*}} {part = "P2"} : !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<64xui32> +// CHECK-DAG: pto.vcvt {{.*}} {part = "P3"} : !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<64xui32> +// CHECK: pto.vselr {{.*}} : !pto.vreg<64xui32>, !pto.vreg<64xi32> -> !pto.vreg<64xui32> +// CHECK: pto.vsel {{.*}} : !pto.vreg<64xui32>, !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<64xui32> +// CHECK: pto.vshl +// CHECK: pto.vselr +// CHECK: pto.vsts {{.*}} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_group_slot_load_u16_broadcast( +// CHECK: pto.pge_b16 "PAT_VL8" : !pto.mask +// CHECK: pto.vsldb {{.*}} : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<128xui16> +// CHECK-DAG: pto.vcvt {{.*}} {part = "EVEN"} : !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<64xui32> +// CHECK-DAG: pto.vcvt {{.*}} {part = "ODD"} : !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<64xui32> +// CHECK: pto.vselr {{.*}} : !pto.vreg<64xui32>, !pto.vreg<64xi32> -> !pto.vreg<64xui32> +// CHECK: pto.vsel {{.*}} : !pto.vreg<64xui32>, !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<64xui32> +// CHECK: pto.vsts {{.*}} : !pto.vreg<64xui32>, !pto.ptr, !pto.mask +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_group_slot_load_nonunit_slots8_invalid.pto b/test/lit/vmi/vmi_to_vpto_group_slot_load_nonunit_slots8_invalid.pto new file mode 100644 index 0000000000..e16ca46fa3 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_slot_load_nonunit_slots8_invalid.pto @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_group_slot_load_nonunit_slots8_invalid( + %src: !pto.ptr, %off: index, %stride: index) + -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> { + %out = pto.vmi.group_slot_load %src[%off], %stride + {num_groups = 8} + : !pto.ptr + -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> + return %out : !pto.vmi.vreg<8xf32, #pto.vmi.layout> + } +} + +// CHECK: VMI-UNSUPPORTED: +// CHECK: pto.vmi.group_slot_load requires explicit group_slots result layout +// CHECK: slots=8 group_slot_load requires constant unit source_group_stride diff --git a/test/lit/vmi/vmi_to_vpto_group_slot_load_support.pto b/test/lit/vmi/vmi_to_vpto_group_slot_load_support.pto new file mode 100644 index 0000000000..44ea0d5e54 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_slot_load_support.pto @@ -0,0 +1,30 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_group_slot_load_support( + %src: !pto.ptr, %off: index) -> !pto.vreg<64xf32> { + %c1 = arith.constant 1 : index + %out = pto.vmi.group_slot_load %src[%off], %c1 {num_groups = 8} + : !pto.ptr + -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<8xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> + return %part : !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_group_slot_load_support( +// CHECK: pto.pge_b32 "PAT_VL8" +// CHECK: pto.vsldb +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_group_slot_truncf_slots1.pto b/test/lit/vmi/vmi_to_vpto_group_slot_truncf_slots1.pto new file mode 100644 index 0000000000..3b87c4c684 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_slot_truncf_slots1.pto @@ -0,0 +1,38 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_group_slot_truncf_slots1( + %source: !pto.vmi.vreg<8xf32, #pto.vmi.layout>) + -> (!pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, + !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, + !pto.vreg<128xf16>, !pto.vreg<128xf16>) { + %narrow = pto.vmi.truncf %source + : !pto.vmi.vreg<8xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<8xf16, #pto.vmi.layout> + %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7 = "pto.vmi.unpack"(%narrow) + : (!pto.vmi.vreg<8xf16, #pto.vmi.layout>) + -> (!pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, + !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, + !pto.vreg<128xf16>, !pto.vreg<128xf16>) + return %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7 + : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, + !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, + !pto.vreg<128xf16>, !pto.vreg<128xf16> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_group_slot_truncf_slots1( +// CHECK-DAG: %[[VL1:.*]] = pto.pge_b32 "PAT_VL1" +// CHECK-COUNT-8: pto.vcvt {{.*}}, %[[VL1]] {part = "EVEN", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +// CHECK: return {{.*}} : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16> +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_group_slot_truncf_slots1_support.pto b/test/lit/vmi/vmi_to_vpto_group_slot_truncf_slots1_support.pto new file mode 100644 index 0000000000..5b44773a33 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_slot_truncf_slots1_support.pto @@ -0,0 +1,36 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_group_slot_truncf_slots1_support( + %source: !pto.vmi.vreg<8xf32, #pto.vmi.layout>) + -> (!pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, + !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, + !pto.vreg<128xf16>, !pto.vreg<128xf16>) { + %narrow = pto.vmi.truncf %source + : !pto.vmi.vreg<8xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<8xf16, #pto.vmi.layout> + %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7 = "pto.vmi.unpack"(%narrow) + : (!pto.vmi.vreg<8xf16, #pto.vmi.layout>) + -> (!pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, + !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, + !pto.vreg<128xf16>, !pto.vreg<128xf16>) + return %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7 + : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, + !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, + !pto.vreg<128xf16>, !pto.vreg<128xf16> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_group_slot_truncf_slots1_support( +// CHECK-COUNT-8: pto.vcvt +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_group_store_slots1_1pt.pto b/test/lit/vmi/vmi_to_vpto_group_store_slots1_1pt.pto new file mode 100644 index 0000000000..3dc813a4f2 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_store_slots1_1pt.pto @@ -0,0 +1,28 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_group_store_slots1_1pt( + %value: !pto.vmi.vreg<8xf32, #pto.vmi.layout>, + %dst: !pto.ptr, + %off: index) { + %c2 = arith.constant 2 : index + pto.vmi.group_store %value, %dst[%off], %c2 {num_groups = 8} + : !pto.vmi.vreg<8xf32, #pto.vmi.layout>, + !pto.ptr + return + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_group_store_slots1_1pt( +// CHECK-COUNT-8: pto.vsts {{.*}} {dist = "1PT_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_group_store_slots1_unit_stride_alignment.pto b/test/lit/vmi/vmi_to_vpto_group_store_slots1_unit_stride_alignment.pto new file mode 100644 index 0000000000..d5a55bd75e --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_store_slots1_unit_stride_alignment.pto @@ -0,0 +1,47 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @aligned_unit_stride_group_store( + %value: !pto.vmi.vreg<8xf32, #pto.vmi.layout>, + %dst: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + pto.vmi.group_store %value, %dst[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xf32, #pto.vmi.layout>, + !pto.ptr + return + } + + func.func @unaligned_unit_stride_group_store( + %value: !pto.vmi.vreg<8xf32, #pto.vmi.layout>, + %dst: !pto.ptr, + %row: index) { + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %off = arith.muli %row, %c2 : index + pto.vmi.group_store %value, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xf32, #pto.vmi.layout>, + !pto.ptr + return + } +} + +// CHECK-LABEL: func.func @aligned_unit_stride_group_store( +// CHECK-COUNT-8: pto.vdup +// CHECK: pto.vsts {{.*}} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +// CHECK-NOT: dist = "1PT_B32" + +// CHECK-LABEL: func.func @unaligned_unit_stride_group_store( +// CHECK-NOT: pto.vdup +// CHECK-COUNT-8: pto.vsts {{.*}} {dist = "1PT_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_group_store_slots8_nonunit_invalid.pto b/test/lit/vmi/vmi_to_vpto_group_store_slots8_nonunit_invalid.pto new file mode 100644 index 0000000000..c09de38e51 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_store_slots8_nonunit_invalid.pto @@ -0,0 +1,26 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_group_store_slots8_nonunit_invalid( + %value: !pto.vmi.vreg<8xf32, #pto.vmi.layout>, + %dst: !pto.ptr, + %off: index, + %row_stride: index) { + pto.vmi.group_store %value, %dst[%off], %row_stride {num_groups = 8} + : !pto.vmi.vreg<8xf32, #pto.vmi.layout>, + !pto.ptr + return + } +} + +// CHECK: VMI-UNSUPPORTED: +// CHECK-SAME: pto.vmi.group_store +// CHECK-SAME: slots=8 group_store currently requires constant unit row_stride diff --git a/test/lit/vmi/vmi_to_vpto_group_store_slots8_packed_byte.pto b/test/lit/vmi/vmi_to_vpto_group_store_slots8_packed_byte.pto new file mode 100644 index 0000000000..4f0055bb4f --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_store_slots8_packed_byte.pto @@ -0,0 +1,43 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_group_store_slots8_i32_to_u8( + %value: !pto.vmi.vreg<32xi32, #pto.vmi.layout>, + %dst: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + pto.vmi.group_store %value, %dst[%off], %c1 {num_groups = 32} + : !pto.vmi.vreg<32xi32, #pto.vmi.layout>, + !pto.ptr + return + } + + func.func @vmi_to_vpto_group_store_slots8_i32_to_u8_padded( + %value: !pto.vmi.vreg<8xi32, #pto.vmi.layout>, + %dst: !pto.ptr) { + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + pto.vmi.group_store %value, %dst[%c32], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xi32, #pto.vmi.layout>, + !pto.ptr + return + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_group_store_slots8_i32_to_u8( +// CHECK-COUNT-1: pto.vsts {{.*}} {dist = "PK4_B32"} : !pto.vreg<64xi32>, !pto.ptr, !pto.mask +// CHECK-LABEL: func.func @vmi_to_vpto_group_store_slots8_i32_to_u8_padded( +// CHECK: pto.vpack {{.*}} "LOWER" : !pto.vreg<64xi32> -> !pto.vreg<128xui16> +// CHECK: pto.vpack {{.*}} "LOWER" : !pto.vreg<128xui16> -> !pto.vreg<256xui8> +// CHECK: pto.vsts {{.*}} {dist = "NORM_B8"} : !pto.vreg<256xui8>, !pto.ptr, !pto.mask +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_integer_cast_reduce.pto b/test/lit/vmi/vmi_to_vpto_integer_cast_reduce.pto new file mode 100644 index 0000000000..4f844e469a --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_integer_cast_reduce.pto @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_extsi_i8_to_i32_group_reduce( + %source: !pto.vmi.vreg<256xi8>, + %mask: !pto.vmi.mask<256xb32, #pto.vmi.layout>) + -> !pto.vmi.vreg<8xi32, #pto.vmi.layout> { + %wide = pto.vmi.extsi %source + : !pto.vmi.vreg<256xi8> + -> !pto.vmi.vreg<256xi32, #pto.vmi.layout> + %sum = pto.vmi.group_reduce_addi %wide, %mask {num_groups = 8} + : !pto.vmi.vreg<256xi32, #pto.vmi.layout>, + !pto.vmi.mask<256xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<8xi32, #pto.vmi.layout> + return %sum : !pto.vmi.vreg<8xi32, #pto.vmi.layout> + } +} + +// CHECK-LABEL: func.func @vmi_extsi_i8_to_i32_group_reduce( +// CHECK: %[[P0:.*]] = pto.vcvt %arg0, {{.*}} {part = "P0"} : !pto.vreg<256xi8>, !pto.mask -> !pto.vreg<64xi32> +// CHECK: %[[P1:.*]] = pto.vcvt %arg0, {{.*}} {part = "P1"} : !pto.vreg<256xi8>, !pto.mask -> !pto.vreg<64xi32> +// CHECK: %[[P2:.*]] = pto.vcvt %arg0, {{.*}} {part = "P2"} : !pto.vreg<256xi8>, !pto.mask -> !pto.vreg<64xi32> +// CHECK: %[[P3:.*]] = 
pto.vcvt %arg0, {{.*}} {part = "P3"} : !pto.vreg<256xi8>, !pto.mask -> !pto.vreg<64xi32> +// CHECK: %[[S0:.*]] = pto.vcgadd %[[P0]] +// CHECK: %[[S1:.*]] = pto.vcgadd %[[P1]] +// CHECK: %[[S2:.*]] = pto.vcgadd %[[P2]] +// CHECK: %[[S3:.*]] = pto.vcgadd %[[P3]] +// CHECK: %[[A01:.*]] = pto.vadd %[[S0]], %[[S1]] +// CHECK: %[[A23:.*]] = pto.vadd %[[S2]], %[[S3]] +// CHECK: %[[SUM:.*]] = pto.vadd %[[A01]], %[[A23]] +// CHECK: return %[[SUM]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_integer_casts.pto b/test/lit/vmi/vmi_to_vpto_integer_casts.pto new file mode 100644 index 0000000000..d65b028c70 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_integer_casts.pto @@ -0,0 +1,201 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_extui_u8_to_u32( + %input: !pto.vmi.vreg<256xui8, #pto.vmi.layout>) + -> (!pto.vreg<64xui32>, !pto.vreg<64xui32>, + !pto.vreg<64xui32>, !pto.vreg<64xui32>) { + %wide = pto.vmi.extui %input + : !pto.vmi.vreg<256xui8, #pto.vmi.layout> + -> !pto.vmi.vreg<256xui32, #pto.vmi.layout> + %p0, %p1, %p2, %p3 = "pto.vmi.unpack"(%wide) + : (!pto.vmi.vreg<256xui32, #pto.vmi.layout>) + -> (!pto.vreg<64xui32>, !pto.vreg<64xui32>, + !pto.vreg<64xui32>, !pto.vreg<64xui32>) + return %p0, %p1, %p2, %p3 + : !pto.vreg<64xui32>, !pto.vreg<64xui32>, + !pto.vreg<64xui32>, !pto.vreg<64xui32> + } + + func.func @vmi_to_vpto_extui_u8_to_u16( + %input: !pto.vmi.vreg<256xui8, #pto.vmi.layout>) + -> (!pto.vreg<128xui16>, !pto.vreg<128xui16>) { + %wide = pto.vmi.extui %input + : !pto.vmi.vreg<256xui8, #pto.vmi.layout> + -> !pto.vmi.vreg<256xui16, #pto.vmi.layout> + %even, %odd = "pto.vmi.unpack"(%wide) + : (!pto.vmi.vreg<256xui16, #pto.vmi.layout>) + -> (!pto.vreg<128xui16>, !pto.vreg<128xui16>) + return %even, %odd : !pto.vreg<128xui16>, !pto.vreg<128xui16> + } + + func.func @vmi_to_vpto_trunci_i32_to_ui8( + %wide: !pto.vmi.vreg<256xi32, #pto.vmi.layout>) + -> !pto.vreg<256xui8> { + %narrow = pto.vmi.trunci %wide + : !pto.vmi.vreg<256xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xui8, #pto.vmi.layout> + %p = "pto.vmi.unpack"(%narrow) + : (!pto.vmi.vreg<256xui8, #pto.vmi.layout>) + -> !pto.vreg<256xui8> + return %p : !pto.vreg<256xui8> + } + + func.func @vmi_to_vpto_trunci_i32_d4_to_ui16_d2( + %wide: !pto.vmi.vreg<256xi32, #pto.vmi.layout>) + -> (!pto.vreg<128xui16>, !pto.vreg<128xui16>) { + %narrow = pto.vmi.trunci %wide + : !pto.vmi.vreg<256xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xui16, #pto.vmi.layout> + %low, %high = "pto.vmi.unpack"(%narrow) + : (!pto.vmi.vreg<256xui16, #pto.vmi.layout>) + -> (!pto.vreg<128xui16>, !pto.vreg<128xui16>) + return %low, %high : !pto.vreg<128xui16>, !pto.vreg<128xui16> 
+ } + + func.func @vmi_to_vpto_fptosi_f32_to_i32( + %input: !pto.vmi.vreg<256xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xi32>, !pto.vreg<64xi32>, + !pto.vreg<64xi32>, !pto.vreg<64xi32>) { + %wide = pto.vmi.fptosi %input + : !pto.vmi.vreg<256xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xi32, #pto.vmi.layout> + %p0, %p1, %p2, %p3 = "pto.vmi.unpack"(%wide) + : (!pto.vmi.vreg<256xi32, #pto.vmi.layout>) + -> (!pto.vreg<64xi32>, !pto.vreg<64xi32>, + !pto.vreg<64xi32>, !pto.vreg<64xi32>) + return %p0, %p1, %p2, %p3 + : !pto.vreg<64xi32>, !pto.vreg<64xi32>, + !pto.vreg<64xi32>, !pto.vreg<64xi32> + } + + func.func @vmi_to_vpto_group_slot_trunci_i32_to_ui8( + %wide: !pto.vmi.vreg<8xi32, #pto.vmi.layout>, + %dst: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %narrow = pto.vmi.trunci %wide + : !pto.vmi.vreg<8xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<8xui8, #pto.vmi.layout> + pto.vmi.group_store %narrow, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xui8, #pto.vmi.layout>, + !pto.ptr + return + } + + func.func @vmi_to_vpto_group_slot_trunci_slots8_i32_to_ui8( + %wide: !pto.vmi.vreg<8xi32, #pto.vmi.layout>, + %dst: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %narrow = pto.vmi.trunci %wide + : !pto.vmi.vreg<8xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<8xui8, #pto.vmi.layout> + pto.vmi.group_store %narrow, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xui8, #pto.vmi.layout>, + !pto.ptr + return + } + + func.func @vmi_to_vpto_group_slot_trunci_slots8_i32_to_ui8_lane_stride( + %wide: !pto.vmi.vreg<8xi32, #pto.vmi.layout>, + %dst: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %narrow = pto.vmi.trunci %wide + : !pto.vmi.vreg<8xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<8xui8, #pto.vmi.layout> + pto.vmi.group_store %narrow, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xui8, #pto.vmi.layout>, + !pto.ptr + return + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_extui_u8_to_u32( +// CHECK-SAME: %[[INPUT:.*]]: 
!pto.vreg<256xui8> +// CHECK: %[[MASK:.*]] = pto.pset_b8 "PAT_ALL" : !pto.mask +// CHECK: pto.vcvt %[[INPUT]], %[[MASK]] {part = "P0"} : !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<64xui32> +// CHECK: pto.vcvt %[[INPUT]], %[[MASK]] {part = "P1"} : !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<64xui32> +// CHECK: pto.vcvt %[[INPUT]], %[[MASK]] {part = "P2"} : !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<64xui32> +// CHECK: pto.vcvt %[[INPUT]], %[[MASK]] {part = "P3"} : !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<64xui32> +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_extui_u8_to_u16( +// CHECK-SAME: %[[INPUT:.*]]: !pto.vreg<256xui8> +// CHECK: %[[MASK:.*]] = pto.pset_b8 "PAT_ALL" : !pto.mask +// CHECK: pto.vcvt %[[INPUT]], %[[MASK]] {part = "EVEN"} : !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<128xui16> +// CHECK: pto.vcvt %[[INPUT]], %[[MASK]] {part = "ODD"} : !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<128xui16> +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_trunci_i32_to_ui8( +// CHECK: %[[P0:.*]] = pto.vcvt {{.*}}, {{.*}} {part = "P0", sat = "SAT"} : !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<256xui8> +// CHECK: %[[P1:.*]] = pto.vcvt {{.*}}, {{.*}} {part = "P1", sat = "SAT"} : !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<256xui8> +// CHECK: %[[P2:.*]] = pto.vcvt {{.*}}, {{.*}} {part = "P2", sat = "SAT"} : !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<256xui8> +// CHECK: %[[P3:.*]] = pto.vcvt {{.*}}, {{.*}} {part = "P3", sat = "SAT"} : !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<256xui8> +// CHECK: %[[M01:.*]] = pto.vor %[[P0]], %[[P1]] +// CHECK: %[[M012:.*]] = pto.vor %[[M01]], %[[P2]] +// CHECK: pto.vor %[[M012]], %[[P3]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_trunci_i32_d4_to_ui16_d2( +// CHECK-SAME: %[[P0:.*]]: !pto.vreg<64xi32>, %[[P1:.*]]: !pto.vreg<64xi32>, %[[P2:.*]]: !pto.vreg<64xi32>, %[[P3:.*]]: !pto.vreg<64xi32> +// CHECK: %[[MASK:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: %[[R0_EVEN:.*]] = pto.vcvt %[[P0]], %[[MASK]] {part = "EVEN", sat = "SAT"} : !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<128xui16> +// CHECK: %[[R0_ODD:.*]] = pto.vcvt %[[P1]], %[[MASK]] {part = "ODD", sat = "SAT"} : !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<128xui16> +// CHECK: pto.vor %[[R0_EVEN]], %[[R0_ODD]] +// CHECK: %[[R1_EVEN:.*]] = pto.vcvt %[[P2]], %[[MASK]] {part = "EVEN", sat = "SAT"} : !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<128xui16> +// CHECK: %[[R1_ODD:.*]] = pto.vcvt %[[P3]], %[[MASK]] {part = "ODD", sat = "SAT"} : !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<128xui16> +// CHECK: pto.vor %[[R1_EVEN]], %[[R1_ODD]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_fptosi_f32_to_i32( +// CHECK: pto.vcvt {{.*}} {rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xi32> +// CHECK: pto.vcvt {{.*}} {rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xi32> +// CHECK: pto.vcvt {{.*}} {rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xi32> +// CHECK: pto.vcvt {{.*}} {rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xi32> +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_group_slot_trunci_i32_to_ui8( +// CHECK: pto.vcvt {{.*}} {part = "P0", sat = "SAT"} : !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<256xui8> +// CHECK: pto.vcvt {{.*}} {part = "P0", sat = "SAT"} : !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<256xui8> +// CHECK: pto.vsts +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_group_slot_trunci_slots8_i32_to_ui8( +// CHECK: pto.pge_b32 "PAT_VL8" +// CHECK: pto.vcvt {{.*}} {part = "P0", sat = "SAT"} : !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<256xui8> +// CHECK: pto.vsts +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_group_slot_trunci_slots8_i32_to_ui8_lane_stride( +// CHECK-NOT: pto.vcvt +// CHECK-NOT: pto.vpack +// CHECK: pto.vsts {{.*}} {dist = "PK4_B32"} : !pto.vreg<64xi32>, !pto.ptr, !pto.mask +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_iota.pto b/test/lit/vmi/vmi_to_vpto_iota.pto new file mode 100644 index 0000000000..a46f767b59 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_iota.pto @@ -0,0 +1,120 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_iota_i32_asc(%base: i32) + -> (!pto.vreg<64xi32>, !pto.vreg<64xi32>) { + %value = pto.vmi.iota %base + : i32 -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<128xi32, #pto.vmi.layout>) + -> (!pto.vreg<64xi32>, !pto.vreg<64xi32>) + return %p0, %p1 : !pto.vreg<64xi32>, !pto.vreg<64xi32> + } + + func.func @vmi_to_vpto_iota_i32_desc(%base: i32) + -> (!pto.vreg<64xi32>, !pto.vreg<64xi32>) { + %value = pto.vmi.iota %base {order = "DESC"} + : i32 -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<128xi32, #pto.vmi.layout>) + -> (!pto.vreg<64xi32>, !pto.vreg<64xi32>) + return %p0, %p1 : !pto.vreg<64xi32>, !pto.vreg<64xi32> + } + + func.func @vmi_to_vpto_iota_i32_deint2_asc(%base: i32) + -> (!pto.vreg<64xi32>, !pto.vreg<64xi32>) { + %value = pto.vmi.iota %base + : i32 -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<128xi32, #pto.vmi.layout>) + -> (!pto.vreg<64xi32>, !pto.vreg<64xi32>) + return %p0, %p1 : !pto.vreg<64xi32>, !pto.vreg<64xi32> + } + + func.func @vmi_to_vpto_iota_i16_asc(%base: i16) + -> !pto.vreg<128xi16> { + %value = pto.vmi.iota %base + : i16 -> !pto.vmi.vreg<128xi16, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<128xi16, #pto.vmi.layout>) + -> !pto.vreg<128xi16> + return %part : !pto.vreg<128xi16> + } + + func.func @vmi_to_vpto_iota_f16_deint2_asc(%base: f16) + -> (!pto.vreg<128xf16>, !pto.vreg<128xf16>) { + %value = pto.vmi.iota %base + : f16 -> !pto.vmi.vreg<256xf16, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<256xf16, #pto.vmi.layout>) + -> (!pto.vreg<128xf16>, !pto.vreg<128xf16>) + return %p0, %p1 : !pto.vreg<128xf16>, !pto.vreg<128xf16> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_iota_i32_asc( +// CHECK: %[[C64:.*]] = arith.constant 64 : i32 +// CHECK: 
%[[P0:.*]] = pto.vci %arg0 : i32 -> !pto.vreg<64xi32> +// CHECK: %[[B1:.*]] = arith.addi %arg0, %[[C64]] : i32 +// CHECK: %[[P1:.*]] = pto.vci %[[B1]] : i32 -> !pto.vreg<64xi32> +// CHECK: return %[[P0]], %[[P1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_iota_i32_desc( +// CHECK: %[[C64:.*]] = arith.constant 64 : i32 +// CHECK: %[[P0:.*]] = pto.vci %arg0 {order = "DESC"} : i32 -> !pto.vreg<64xi32> +// CHECK: %[[B1:.*]] = arith.subi %arg0, %[[C64]] : i32 +// CHECK: %[[P1:.*]] = pto.vci %[[B1]] {order = "DESC"} : i32 -> !pto.vreg<64xi32> +// CHECK: return %[[P0]], %[[P1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_iota_i32_deint2_asc( +// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : i32 +// CHECK-DAG: %[[FACTOR:.*]] = arith.constant 2 : i32 +// CHECK-DAG: %[[PART1:.*]] = arith.constant 1 : i32 +// CHECK: %[[LOCAL0:.*]] = pto.vci %[[ZERO]] : i32 -> !pto.vreg<64xi32> +// CHECK: %[[SCALED0:.*]] = pto.vmuls %[[LOCAL0]], %[[FACTOR]], +// CHECK: %[[P0:.*]] = pto.vadds %[[SCALED0]], %arg0, +// CHECK: %[[LOCAL1:.*]] = pto.vci %[[ZERO]] : i32 -> !pto.vreg<64xi32> +// CHECK: %[[SCALED1:.*]] = pto.vmuls %[[LOCAL1]], %[[FACTOR]], +// CHECK: %[[BASE1:.*]] = arith.addi %arg0, %[[PART1]] : i32 +// CHECK: %[[P1:.*]] = pto.vadds %[[SCALED1]], %[[BASE1]], +// CHECK: return %[[P0]], %[[P1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_iota_i16_asc( +// CHECK: %[[P16:.*]] = pto.vci %arg0 : i16 -> !pto.vreg<128xi16> +// CHECK: return %[[P16]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_iota_f16_deint2_asc( +// CHECK-DAG: %[[ZERO16:.*]] = arith.constant 0.000000e+00 : f16 +// CHECK-DAG: %[[FACTOR16:.*]] = arith.constant 2.000000e+00 : f16 +// CHECK-DAG: %[[PART16_1:.*]] = arith.constant 1.000000e+00 : f16 +// CHECK: %[[LOCAL16_0:.*]] = pto.vci %[[ZERO16]] : f16 -> !pto.vreg<128xf16> +// CHECK: %[[SCALED16_0:.*]] = pto.vmuls %[[LOCAL16_0]], %[[FACTOR16]], +// CHECK: %[[P16_0:.*]] = pto.vadds %[[SCALED16_0]], %arg0, +// CHECK: %[[LOCAL16_1:.*]] = pto.vci %[[ZERO16]] : f16 -> !pto.vreg<128xf16> +// CHECK: %[[SCALED16_1:.*]] = pto.vmuls %[[LOCAL16_1]], %[[FACTOR16]], +// CHECK: %[[BASE16_1:.*]] = arith.addf %arg0, %[[PART16_1]] : f16 +// CHECK: %[[P16_1:.*]] = pto.vadds %[[SCALED16_1]], %[[BASE16_1]], +// CHECK: return %[[P16_0]], %[[P16_1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_iota_tail.pto b/test/lit/vmi/vmi_to_vpto_iota_tail.pto new file mode 100644 index 0000000000..7ba8a31f11 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_iota_tail.pto @@ -0,0 +1,57 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_iota_contiguous_tail(%base: i32) + -> (!pto.vreg<64xi32>, !pto.vreg<64xi32>) { + %value = pto.vmi.iota %base + : i32 -> !pto.vmi.vreg<100xi32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<100xi32, #pto.vmi.layout>) + -> (!pto.vreg<64xi32>, !pto.vreg<64xi32>) + return %p0, %p1 : !pto.vreg<64xi32>, !pto.vreg<64xi32> + } + + func.func @vmi_to_vpto_iota_deint2_tail(%base: i32) + -> (!pto.vreg<64xi32>, !pto.vreg<64xi32>, + !pto.vreg<64xi32>, !pto.vreg<64xi32>) { + %value = pto.vmi.iota %base + : i32 -> !pto.vmi.vreg<130xi32, #pto.vmi.layout> + %p0, %p1, %p2, %p3 = "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<130xi32, #pto.vmi.layout>) + -> (!pto.vreg<64xi32>, !pto.vreg<64xi32>, + !pto.vreg<64xi32>, !pto.vreg<64xi32>) + return %p0, %p1, %p2, %p3 + : !pto.vreg<64xi32>, !pto.vreg<64xi32>, + !pto.vreg<64xi32>, !pto.vreg<64xi32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_iota_contiguous_tail( +// CHECK: %[[C64:.*]] = arith.constant 64 : i32 +// CHECK: %[[P0:.*]] = pto.vci %arg0 : i32 -> !pto.vreg<64xi32> +// CHECK: %[[B1:.*]] = arith.addi %arg0, %[[C64]] : i32 +// CHECK: %[[P1:.*]] = pto.vci %[[B1]] : i32 -> !pto.vreg<64xi32> +// CHECK: return %[[P0]], %[[P1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_iota_deint2_tail( +// CHECK-DAG: %[[C128:.*]] = arith.constant 128 : i32 +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : i32 +// CHECK-DAG: %[[C129:.*]] = arith.constant 129 : i32 +// CHECK: %[[BASE128:.*]] = arith.addi %arg0, %[[C128]] : i32 +// CHECK: %[[BASE1:.*]] = arith.addi %arg0, %[[C1]] : i32 +// CHECK: %[[BASE129:.*]] = arith.addi %arg0, %[[C129]] : i32 +// CHECK: return {{.*}}, {{.*}}, {{.*}}, {{.*}} +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_load_deint.pto b/test/lit/vmi/vmi_to_vpto_load_deint.pto new file mode 100644 index 0000000000..0f3c3f825a --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_load_deint.pto @@ -0,0 +1,49 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_load_deint2(%src: !pto.ptr, %offset: index) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %value = pto.vmi.load %src[%offset] + : !pto.ptr -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + + func.func @vmi_to_vpto_load_deint4(%src: !pto.ptr, %offset: index) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %value = pto.vmi.load %src[%offset] + : !pto.ptr -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + %p0, %p1, %p2, %p3 = "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<256xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1, %p2, %p3 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_load_deint2( +// CHECK: %[[P0:.*]], 
%[[P1:.*]] = pto.vldsx2 %arg0[%arg1], "DINTLV_B32" +// CHECK: return %[[P0]], %[[P1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_load_deint4( +// CHECK: %[[E0:.*]], %[[O0:.*]] = pto.vldsx2 %arg0[%arg1], "DINTLV_B32" +// CHECK: %[[E1:.*]], %[[O1:.*]] = pto.vldsx2 %arg0[{{.*}}], "DINTLV_B32" +// CHECK: %[[P0:.*]], %[[P2:.*]] = pto.vdintlv %[[E0]], %[[E1]] +// CHECK: %[[P1:.*]], %[[P3:.*]] = pto.vdintlv %[[O0]], %[[O1]] +// CHECK: return %[[P0]], %[[P1]], %[[P2]], %[[P3]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_load_deint_multichunk.pto b/test/lit/vmi/vmi_to_vpto_load_deint_multichunk.pto new file mode 100644 index 0000000000..200a1af04e --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_load_deint_multichunk.pto @@ -0,0 +1,67 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_load_deint2_multichunk( + %src: !pto.ptr, %offset: index) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %value = pto.vmi.load %src[%offset] + : !pto.ptr -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + %p0, %p1, %p2, %p3 = "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<256xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1, %p2, %p3 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + + func.func @vmi_to_vpto_load_deint4_multichunk( + %src: !pto.ptr, %offset: index) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %value = pto.vmi.load %src[%offset] + : !pto.ptr -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> + %p0_0, %p0_1, %p1_0, %p1_1, %p2_0, %p2_1, %p3_0, %p3_1 = + "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<512xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0_0, %p0_1, %p1_0, %p1_1, %p2_0, %p2_1, %p3_0, %p3_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_load_deint2_multichunk( +// CHECK: %[[P0_0:.*]], %[[P1_0:.*]] = pto.vldsx2 %arg0[%arg1], "DINTLV_B32" +// CHECK: %[[P0_1:.*]], %[[P1_1:.*]] = pto.vldsx2 %arg0[{{.*}}], "DINTLV_B32" +// CHECK: return %[[P0_0]], %[[P0_1]], %[[P1_0]], %[[P1_1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_load_deint4_multichunk( +// CHECK: %[[E0_0:.*]], %[[O0_0:.*]] = pto.vldsx2 %arg0[%arg1], "DINTLV_B32" +// CHECK: %[[E1_0:.*]], %[[O1_0:.*]] = pto.vldsx2 %arg0[{{.*}}], "DINTLV_B32" +// CHECK: %[[P0_0:.*]], %[[P2_0:.*]] = pto.vdintlv %[[E0_0]], %[[E1_0]] +// CHECK: %[[P1_0:.*]], %[[P3_0:.*]] = pto.vdintlv %[[O0_0]], %[[O1_0]] +// CHECK: %[[E0_1:.*]], %[[O0_1:.*]] = pto.vldsx2 %arg0[{{.*}}], "DINTLV_B32" +// CHECK: %[[E1_1:.*]], %[[O1_1:.*]] = pto.vldsx2 %arg0[{{.*}}], "DINTLV_B32" +// CHECK: %[[P0_1:.*]], %[[P2_1:.*]] = pto.vdintlv %[[E0_1]], %[[E1_1]] +// CHECK: %[[P1_1:.*]], %[[P3_1:.*]] = pto.vdintlv %[[O0_1]], %[[O1_1]] +// CHECK: return %[[P0_0]], %[[P0_1]], %[[P1_0]], %[[P1_1]], %[[P2_0]], %[[P2_1]], %[[P3_0]], %[[P3_1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_load_nonfull.pto b/test/lit/vmi/vmi_to_vpto_load_nonfull.pto new file mode 100644 index 0000000000..edb8f88cf3 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_load_nonfull.pto @@ -0,0 +1,27 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_load_nonfull( + %src: !pto.ptr, %offset: index) + -> (!pto.vreg<64xf32>) { + %value = pto.vmi.load %src[%offset] + : !pto.ptr -> !pto.vmi.vreg<4xf32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<4xf32, #pto.vmi.layout>) -> !pto.vreg<64xf32> + return %part : !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_load_nonfull( +// CHECK: pto.vlds %arg0[%arg1] : !pto.ptr -> !pto.vreg<64xf32> +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_load_nonfull_memref.pto b/test/lit/vmi/vmi_to_vpto_load_nonfull_memref.pto new file mode 100644 index 0000000000..d1e1f27c94 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_load_nonfull_memref.pto @@ -0,0 +1,31 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_load_nonfull_memref(%src: memref<100xf32>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %c0 = arith.constant 0 : index + %value = pto.vmi.load %src[%c0] + : memref<100xf32> -> !pto.vmi.vreg<100xf32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_load_nonfull_memref( +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index +// CHECK: pto.vlds %arg0[%[[C0]]] : memref<100xf32> -> !pto.vreg<64xf32> +// CHECK: pto.vlds %arg0[%[[C64]]] : memref<100xf32> -> !pto.vreg<64xf32> +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_load_safe_tail_memref.pto b/test/lit/vmi/vmi_to_vpto_load_safe_tail_memref.pto new file mode 100644 index 0000000000..40bbe153c0 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_load_safe_tail_memref.pto @@ -0,0 +1,54 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_load_safe_tail_memref(%src: memref<128xf32>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %c0 = arith.constant 0 : index + %value = pto.vmi.load %src[%c0] + : memref<128xf32> -> !pto.vmi.vreg<100xf32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + + func.func @vmi_to_vpto_load_safe_tail_memref_nonzero_offset(%src: memref<132xf32>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %c4 = arith.constant 4 : index + %value = pto.vmi.load %src[%c4] + : memref<132xf32> -> !pto.vmi.vreg<100xf32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + +} + +// CHECK-LABEL: func.func @vmi_to_vpto_load_safe_tail_memref( +// CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[P0:.*]] = pto.vlds %arg0[%[[C0]]] : memref<128xf32> -> !pto.vreg<64xf32> +// CHECK: %[[P1:.*]] = pto.vlds %arg0[%[[C64]]] : memref<128xf32> -> !pto.vreg<64xf32> +// CHECK: return %[[P0]], %[[P1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_load_safe_tail_memref_nonzero_offset( +// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index +// CHECK-DAG: %[[C68:.*]] = arith.constant 68 : index +// CHECK: %[[P0:.*]] = pto.vlds %arg0[%[[C4]]] : memref<132xf32> -> !pto.vreg<64xf32> +// CHECK: %[[P1:.*]] = pto.vlds %arg0[%[[C68]]] : memref<132xf32> -> !pto.vreg<64xf32> +// CHECK: return %[[P0]], %[[P1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_load_safe_tail_memref_negative_offset.pto b/test/lit/vmi/vmi_to_vpto_load_safe_tail_memref_negative_offset.pto new file mode 100644 index 0000000000..b444c3d1a8 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_load_safe_tail_memref_negative_offset.pto @@ -0,0 +1,29 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_load_safe_tail_memref_negative_offset(%src: memref<132xf32>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %cm1 = arith.constant -1 : index + %value = pto.vmi.load %src[%cm1] + : memref<132xf32> -> !pto.vmi.vreg<100xf32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_load_safe_tail_memref_negative_offset( +// CHECK: pto.vlds +// CHECK: pto.vlds +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_load_store_contiguous.pto b/test/lit/vmi/vmi_to_vpto_load_store_contiguous.pto new file mode 100644 index 0000000000..891cb20567 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_load_store_contiguous.pto @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_load_store_contiguous( + %src: !pto.ptr, %dst: !pto.ptr, %offset: index) { + %value = pto.vmi.load %src[%offset] + : !pto.ptr -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + pto.vmi.store %value, %dst[%offset] + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, !pto.ptr + return + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_load_store_contiguous( +// CHECK: %[[C64_LOAD:.*]] = arith.constant 64 : index +// CHECK: %[[L0:.*]] = pto.vlds %arg0[%arg2] : !pto.ptr -> !pto.vreg<64xf32> +// CHECK: %[[OFF1_LOAD:.*]] = arith.addi %arg2, %[[C64_LOAD]] : index +// CHECK: %[[L1:.*]] = pto.vlds %arg0[%[[OFF1_LOAD]]] : !pto.ptr -> !pto.vreg<64xf32> +// CHECK: %[[M0:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: pto.vsts %[[L0]], %arg1[%arg2], %[[M0]] : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +// CHECK: %[[M1:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: pto.vsts %[[L1]], %arg1[{{.*}}], %[[M1]] : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +// CHECK-NOT: pto.vmi. 
+// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_mask_logic.pto b/test/lit/vmi/vmi_to_vpto_mask_logic.pto new file mode 100644 index 0000000000..cf220cc6a4 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_mask_logic.pto @@ -0,0 +1,126 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_mask_logic( + %a: !pto.vmi.vreg<128xf32>, + %b: !pto.vmi.vreg<128xf32>, + %c: !pto.vmi.vreg<128xf32>) + -> (!pto.vmi.mask<128xpred>, !pto.vmi.mask<128xpred>, + !pto.vmi.mask<128xpred>, !pto.vmi.mask<128xpred>) { + %lt = pto.vmi.cmpf "olt", %a, %b + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.mask<128xpred> + %gt = pto.vmi.cmpf "ogt", %a, %c + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.mask<128xpred> + %and = pto.vmi.mask_and %lt, %gt + : !pto.vmi.mask<128xpred>, !pto.vmi.mask<128xpred> + -> !pto.vmi.mask<128xpred> + %or = pto.vmi.mask_or %lt, %gt + : !pto.vmi.mask<128xpred>, !pto.vmi.mask<128xpred> + -> !pto.vmi.mask<128xpred> + %xor = pto.vmi.mask_xor %lt, %gt + : !pto.vmi.mask<128xpred>, !pto.vmi.mask<128xpred> + -> !pto.vmi.mask<128xpred> + %not = pto.vmi.mask_not %lt + : !pto.vmi.mask<128xpred> -> !pto.vmi.mask<128xpred> + return %and, %or, %xor, %not + : !pto.vmi.mask<128xpred>, 
!pto.vmi.mask<128xpred>, + !pto.vmi.mask<128xpred>, !pto.vmi.mask<128xpred> + } + + func.func @vmi_to_vpto_mask_logic_b8( + %lhs: !pto.vmi.mask<256xb8, #pto.vmi.layout>, + %rhs: !pto.vmi.mask<256xb8, #pto.vmi.layout>) + -> (!pto.vmi.mask<256xb8, #pto.vmi.layout>, + !pto.vmi.mask<256xb8, #pto.vmi.layout>, + !pto.vmi.mask<256xb8, #pto.vmi.layout>, + !pto.vmi.mask<256xb8, #pto.vmi.layout>) { + %and = pto.vmi.mask_and %lhs, %rhs + : !pto.vmi.mask<256xb8, #pto.vmi.layout>, + !pto.vmi.mask<256xb8, #pto.vmi.layout> + -> !pto.vmi.mask<256xb8, #pto.vmi.layout> + %or = pto.vmi.mask_or %lhs, %rhs + : !pto.vmi.mask<256xb8, #pto.vmi.layout>, + !pto.vmi.mask<256xb8, #pto.vmi.layout> + -> !pto.vmi.mask<256xb8, #pto.vmi.layout> + %xor = pto.vmi.mask_xor %lhs, %rhs + : !pto.vmi.mask<256xb8, #pto.vmi.layout>, + !pto.vmi.mask<256xb8, #pto.vmi.layout> + -> !pto.vmi.mask<256xb8, #pto.vmi.layout> + %not = pto.vmi.mask_not %lhs + : !pto.vmi.mask<256xb8, #pto.vmi.layout> + -> !pto.vmi.mask<256xb8, #pto.vmi.layout> + return %and, %or, %xor, %not + : !pto.vmi.mask<256xb8, #pto.vmi.layout>, + !pto.vmi.mask<256xb8, #pto.vmi.layout>, + !pto.vmi.mask<256xb8, #pto.vmi.layout>, + !pto.vmi.mask<256xb8, #pto.vmi.layout> + } + + func.func @vmi_to_vpto_mask_logic_b16( + %lhs: !pto.vmi.mask<128xb16, #pto.vmi.layout>, + %rhs: !pto.vmi.mask<128xb16, #pto.vmi.layout>) + -> (!pto.vmi.mask<128xb16, #pto.vmi.layout>, + !pto.vmi.mask<128xb16, #pto.vmi.layout>, + !pto.vmi.mask<128xb16, #pto.vmi.layout>, + !pto.vmi.mask<128xb16, #pto.vmi.layout>) { + %and = pto.vmi.mask_and %lhs, %rhs + : !pto.vmi.mask<128xb16, #pto.vmi.layout>, + !pto.vmi.mask<128xb16, #pto.vmi.layout> + -> !pto.vmi.mask<128xb16, #pto.vmi.layout> + %or = pto.vmi.mask_or %lhs, %rhs + : !pto.vmi.mask<128xb16, #pto.vmi.layout>, + !pto.vmi.mask<128xb16, #pto.vmi.layout> + -> !pto.vmi.mask<128xb16, #pto.vmi.layout> + %xor = pto.vmi.mask_xor %lhs, %rhs + : !pto.vmi.mask<128xb16, #pto.vmi.layout>, + !pto.vmi.mask<128xb16, #pto.vmi.layout> + -> 
!pto.vmi.mask<128xb16, #pto.vmi.layout> + %not = pto.vmi.mask_not %lhs + : !pto.vmi.mask<128xb16, #pto.vmi.layout> + -> !pto.vmi.mask<128xb16, #pto.vmi.layout> + return %and, %or, %xor, %not + : !pto.vmi.mask<128xb16, #pto.vmi.layout>, + !pto.vmi.mask<128xb16, #pto.vmi.layout>, + !pto.vmi.mask<128xb16, #pto.vmi.layout>, + !pto.vmi.mask<128xb16, #pto.vmi.layout> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_mask_logic( +// CHECK-SAME: -> (!pto.mask, !pto.mask, !pto.mask, !pto.mask, !pto.mask, !pto.mask, !pto.mask, !pto.mask) +// CHECK-DAG: %[[AND0:.*]] = pto.pand {{.*}} : !pto.mask, !pto.mask, !pto.mask -> !pto.mask +// CHECK-DAG: %[[AND1:.*]] = pto.pand {{.*}} : !pto.mask, !pto.mask, !pto.mask -> !pto.mask +// CHECK-DAG: %[[OR0:.*]] = pto.por {{.*}} : !pto.mask, !pto.mask, !pto.mask -> !pto.mask +// CHECK-DAG: %[[OR1:.*]] = pto.por {{.*}} : !pto.mask, !pto.mask, !pto.mask -> !pto.mask +// CHECK-DAG: %[[XOR0:.*]] = pto.pxor {{.*}} : !pto.mask, !pto.mask, !pto.mask -> !pto.mask +// CHECK-DAG: %[[XOR1:.*]] = pto.pxor {{.*}} : !pto.mask, !pto.mask, !pto.mask -> !pto.mask +// CHECK-DAG: %[[NOT0:.*]] = pto.pnot {{.*}} : !pto.mask, !pto.mask -> !pto.mask +// CHECK-DAG: %[[NOT1:.*]] = pto.pnot {{.*}} : !pto.mask, !pto.mask -> !pto.mask +// CHECK: return %[[AND0]], %[[AND1]], %[[OR0]], %[[OR1]], %[[XOR0]], %[[XOR1]], %[[NOT0]], %[[NOT1]] +// CHECK-LABEL: func.func @vmi_to_vpto_mask_logic_b8( +// CHECK-SAME: -> (!pto.mask, !pto.mask, !pto.mask, !pto.mask) +// CHECK-DAG: %[[AND_B8:.*]] = pto.pand {{.*}} : !pto.mask, !pto.mask, !pto.mask -> !pto.mask +// CHECK-DAG: %[[OR_B8:.*]] = pto.por {{.*}} : !pto.mask, !pto.mask, !pto.mask -> !pto.mask +// CHECK-DAG: %[[XOR_B8:.*]] = pto.pxor {{.*}} : !pto.mask, !pto.mask, !pto.mask -> !pto.mask +// CHECK-DAG: %[[NOT_B8:.*]] = pto.pnot {{.*}} : !pto.mask, !pto.mask -> !pto.mask +// CHECK: return %[[AND_B8]], %[[OR_B8]], %[[XOR_B8]], %[[NOT_B8]] +// CHECK-LABEL: func.func @vmi_to_vpto_mask_logic_b16( +// CHECK-SAME: -> (!pto.mask, 
!pto.mask, !pto.mask, !pto.mask) +// CHECK-DAG: %[[AND_B16:.*]] = pto.pand {{.*}} : !pto.mask, !pto.mask, !pto.mask -> !pto.mask +// CHECK-DAG: %[[OR_B16:.*]] = pto.por {{.*}} : !pto.mask, !pto.mask, !pto.mask -> !pto.mask +// CHECK-DAG: %[[XOR_B16:.*]] = pto.pxor {{.*}} : !pto.mask, !pto.mask, !pto.mask -> !pto.mask +// CHECK-DAG: %[[NOT_B16:.*]] = pto.pnot {{.*}} : !pto.mask, !pto.mask -> !pto.mask +// CHECK: return %[[AND_B16]], %[[OR_B16]], %[[XOR_B16]], %[[NOT_B16]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_masked_load.pto b/test/lit/vmi/vmi_to_vpto_masked_load.pto new file mode 100644 index 0000000000..bc46f591ac --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_masked_load.pto @@ -0,0 +1,36 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_masked_load( + %src: !pto.ptr, + %offset: index, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>, + %passthru: !pto.vmi.vreg<64xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> { + %out = pto.vmi.masked_load %src[%offset], %mask, %passthru + : !pto.ptr, + !pto.vmi.mask<64xb32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<64xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> + return %part : !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_masked_load( +// CHECK: %[[LOAD:.*]] = pto.vlds %arg0[%arg1] : !pto.ptr -> !pto.vreg<64xf32> +// CHECK: %[[OUT:.*]] = pto.vsel %[[LOAD]], %arg3, %arg2 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[OUT]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_masked_load_nonfull_invalid.pto b/test/lit/vmi/vmi_to_vpto_masked_load_nonfull_invalid.pto new file mode 100644 index 0000000000..9b79049c1a --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_masked_load_nonfull_invalid.pto @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_masked_load_nonfull_invalid( + %src: !pto.ptr, + %offset: index, + %mask: !pto.vmi.mask<4xb32, #pto.vmi.layout>, + %passthru: !pto.vmi.vreg<4xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> { + %out = pto.vmi.masked_load %src[%offset], %mask, %passthru + : !pto.ptr, + !pto.vmi.mask<4xb32, #pto.vmi.layout>, + !pto.vmi.vreg<4xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<4xf32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<4xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> + return %part : !pto.vreg<64xf32> + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.masked_load direct lowering requires a supported memory source, contiguous result/passthru/mask layouts +// CHECK-SAME: safe-read proof requires constant index offset +// CHECK-SAME: fallback decision: partial/tail read needs a scratch, guarded, or true masked/non-faulting load fallback +// CHECK-SAME: target true masked/non-faulting load is unavailable diff --git a/test/lit/vmi/vmi_to_vpto_masked_load_safe_tail_memref.pto b/test/lit/vmi/vmi_to_vpto_masked_load_safe_tail_memref.pto new file mode 100644 index 0000000000..d4b9f23c23 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_masked_load_safe_tail_memref.pto @@ -0,0 +1,69 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_masked_load_safe_tail_memref( + %src: memref<128xf32>, + %mask: !pto.vmi.mask<100xb32, #pto.vmi.layout>, + %passthru: !pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %c0 = arith.constant 0 : index + %out = pto.vmi.masked_load %src[%c0], %mask, %passthru + : memref<128xf32>, + !pto.vmi.mask<100xb32, #pto.vmi.layout>, + !pto.vmi.vreg<100xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<100xf32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + + func.func @vmi_to_vpto_masked_load_safe_tail_memref_nonzero_offset( + %src: memref<132xf32>, + %mask: !pto.vmi.mask<100xb32, #pto.vmi.layout>, + %passthru: !pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %c4 = arith.constant 4 : index + %out = pto.vmi.masked_load %src[%c4], %mask, %passthru + : memref<132xf32>, + !pto.vmi.mask<100xb32, #pto.vmi.layout>, + !pto.vmi.vreg<100xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<100xf32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_masked_load_safe_tail_memref( +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index +// CHECK: %[[L0:.*]] = pto.vlds %arg0[%[[C0]]] : memref<128xf32> -> !pto.vreg<64xf32> +// CHECK: %[[O0:.*]] = pto.vsel %[[L0]], %arg3, %arg1 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: %[[L1:.*]] = pto.vlds %arg0[%[[C64]]] : memref<128xf32> -> !pto.vreg<64xf32> +// CHECK: %[[O1:.*]] = pto.vsel %[[L1]], %arg4, %arg2 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> 
!pto.vreg<64xf32> +// CHECK: return %[[O0]], %[[O1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_masked_load_safe_tail_memref_nonzero_offset( +// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index +// CHECK-DAG: %[[C68:.*]] = arith.constant 68 : index +// CHECK: %[[L0:.*]] = pto.vlds %arg0[%[[C4]]] : memref<132xf32> -> !pto.vreg<64xf32> +// CHECK: %[[O0:.*]] = pto.vsel %[[L0]], %arg3, %arg1 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: %[[L1:.*]] = pto.vlds %arg0[%[[C68]]] : memref<132xf32> -> !pto.vreg<64xf32> +// CHECK: %[[O1:.*]] = pto.vsel %[[L1]], %arg4, %arg2 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[O0]], %[[O1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_masked_load_safe_tail_memref_negative_offset_invalid.pto b/test/lit/vmi/vmi_to_vpto_masked_load_safe_tail_memref_negative_offset_invalid.pto new file mode 100644 index 0000000000..ab22618d3e --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_masked_load_safe_tail_memref_negative_offset_invalid.pto @@ -0,0 +1,31 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_masked_load_safe_tail_memref_negative_offset_invalid( + %src: memref<132xf32>, + %mask: !pto.vmi.mask<100xb32, #pto.vmi.layout>, + %passthru: !pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %cm1 = arith.constant -1 : index + %out = pto.vmi.masked_load %src[%cm1], %mask, %passthru + : memref<132xf32>, + !pto.vmi.mask<100xb32, #pto.vmi.layout>, + !pto.vmi.vreg<100xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<100xf32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.masked_load direct lowering requires a supported memory source, contiguous result/passthru/mask layouts +// CHECK-SAME: safe-read proof requires non-negative offset diff --git a/test/lit/vmi/vmi_to_vpto_masked_store.pto b/test/lit/vmi/vmi_to_vpto_masked_store.pto new file mode 100644 index 0000000000..01e8d53d89 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_masked_store.pto @@ -0,0 +1,38 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_masked_store_contiguous( + %value: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<128xb32, #pto.vmi.layout>, + %dst: !pto.ptr, + %offset: index) { + pto.vmi.masked_store %value, %dst[%offset], %mask + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.ptr, + !pto.vmi.mask<128xb32, #pto.vmi.layout> + return + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_masked_store_contiguous( +// CHECK-SAME: %[[V0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[V1:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[M0:[^,]+]]: !pto.mask +// CHECK-SAME: %[[M1:[^,]+]]: !pto.mask +// CHECK-SAME: %[[DST:[^,]+]]: !pto.ptr +// CHECK-SAME: %[[OFF:[^)]+]]: index +// CHECK: %[[C64:.*]] = arith.constant 64 : index +// CHECK: pto.vsts %[[V0]], %[[DST]][%[[OFF]]], %[[M0]] : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +// CHECK: %[[OFF1:.*]] = arith.addi %[[OFF]], %[[C64]] : index +// CHECK: pto.vsts %[[V1]], %[[DST]][%[[OFF1]]], %[[M1]] : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_masked_store_deint_tail.pto b/test/lit/vmi/vmi_to_vpto_masked_store_deint_tail.pto new file mode 100644 index 0000000000..e874e8d90d --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_masked_store_deint_tail.pto @@ -0,0 +1,42 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_masked_store_deint_tail( + %value: !pto.vmi.vreg<4xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<4xb32, #pto.vmi.layout>, + %dst: !pto.ptr, + %offset: index) { + pto.vmi.masked_store %value, %dst[%offset], %mask + : !pto.vmi.vreg<4xf32, #pto.vmi.layout>, + !pto.ptr, + !pto.vmi.mask<4xb32, #pto.vmi.layout> + return + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_masked_store_deint_tail( +// CHECK-SAME: %[[V0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[V1:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[M0:[^,]+]]: !pto.mask +// CHECK-SAME: %[[M1:[^,]+]]: !pto.mask +// CHECK-SAME: %[[DST:[^,]+]]: !pto.ptr +// CHECK-SAME: %[[OFF:[^)]+]]: index +// CHECK: %[[C4:.*]] = arith.constant 4 : i32 +// CHECK: %[[LOW:.*]], %[[HIGH:.*]] = pto.vintlv %[[V0]], %[[V1]] +// CHECK: %[[USER:.*]], %{{.*}} = pto.pintlv_b32 %[[M0]], %[[M1]] +// CHECK: %[[TAIL:.*]], %{{.*}} = pto.plt_b32 %[[C4]] : i32 -> !pto.mask, i32 +// CHECK: %[[ALL:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: %[[MASK:.*]] = pto.pand %[[USER]], %[[TAIL]], %[[ALL]] +// CHECK: pto.vsts %[[LOW]], %[[DST]][%[[OFF]]], %[[MASK]] +// CHECK-NOT: pto.vsts %[[HIGH]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_masked_store_nonfull_invalid.pto b/test/lit/vmi/vmi_to_vpto_masked_store_nonfull_invalid.pto new file mode 100644 index 0000000000..375f44c894 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_masked_store_nonfull_invalid.pto @@ -0,0 +1,26 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. 
You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_masked_store_nonfull_invalid( + %value: !pto.vmi.vreg<129xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<129xb32, #pto.vmi.layout>, + %dst: !pto.ptr, + %offset: index) { + pto.vmi.masked_store %value, %dst[%offset], %mask + : !pto.vmi.vreg<129xf32, #pto.vmi.layout>, + !pto.ptr, + !pto.vmi.mask<129xb32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.masked_store requires either full physical chunks or contiguous tail-store value/mask layout +// CHECK-SAME: requires every deinterleaved part to have the same physical chunk count diff --git a/test/lit/vmi/vmi_to_vpto_masked_store_tail.pto b/test/lit/vmi/vmi_to_vpto_masked_store_tail.pto new file mode 100644 index 0000000000..361277c4fd --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_masked_store_tail.pto @@ -0,0 +1,40 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_masked_store_tail( + %value: !pto.vmi.vreg<100xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<100xb32, #pto.vmi.layout>, + %dst: !pto.ptr, + %offset: index) { + pto.vmi.masked_store %value, %dst[%offset], %mask + : !pto.vmi.vreg<100xf32, #pto.vmi.layout>, + !pto.ptr, + !pto.vmi.mask<100xb32, #pto.vmi.layout> + return + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_masked_store_tail( +// CHECK-SAME: %[[V0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[V1:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[M0:[^,]+]]: !pto.mask +// CHECK-SAME: %[[M1:[^,]+]]: !pto.mask +// CHECK-SAME: %[[DST:[^,]+]]: !pto.ptr +// CHECK-SAME: %[[OFF:[^)]+]]: index +// CHECK: %[[C36:.*]] = arith.constant 36 : i32 +// CHECK: pto.vsts %[[V0]], %[[DST]][%[[OFF]]], %[[M0]] +// CHECK: %[[TAIL:.*]], %{{.*}} = pto.plt_b32 %[[C36]] : i32 -> !pto.mask, i32 +// CHECK: %[[ALL:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: %[[COMBINED:.*]] = pto.pand %[[M1]], %[[TAIL]], %[[ALL]] : !pto.mask, !pto.mask, !pto.mask -> !pto.mask +// CHECK: pto.vsts %[[V1]], %[[DST]][{{.*}}], %[[COMBINED]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_math_element_type_invalid.pto b/test/lit/vmi/vmi_to_vpto_math_element_type_invalid.pto new file mode 100644 index 0000000000..1102d992ef --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_math_element_type_invalid.pto @@ -0,0 +1,131 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -split-input-file -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_addf_f8_invalid( + %lhs: !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout>, + %rhs: !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout>) { + %out = pto.vmi.addf %lhs, %rhs + : !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout>, + !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.addf direct lowering requires f16/bf16/f32 element type +// CHECK-SAME: requires f16/bf16/f32 element type for direct VPTO lowering + +// ----- + +module { + func.func @vmi_to_vpto_divf_bf16_invalid( + %lhs: !pto.vmi.vreg<128xbf16, #pto.vmi.layout>, + %rhs: !pto.vmi.vreg<128xbf16, #pto.vmi.layout>) { + %out = pto.vmi.divf %lhs, %rhs + : !pto.vmi.vreg<128xbf16, #pto.vmi.layout>, + !pto.vmi.vreg<128xbf16, #pto.vmi.layout> + -> !pto.vmi.vreg<128xbf16, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.divf direct lowering requires f16/f32 element type +// CHECK-SAME: requires f16/f32 element type for direct VPTO lowering + +// ----- + +module { + func.func @vmi_to_vpto_sqrt_bf16_invalid( + %source: !pto.vmi.vreg<128xbf16, #pto.vmi.layout>) { + %out = pto.vmi.sqrt %source + : !pto.vmi.vreg<128xbf16, #pto.vmi.layout> + -> !pto.vmi.vreg<128xbf16, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.sqrt direct lowering requires f16/f32 element type +// CHECK-SAME: requires f16/f32 element type for direct VPTO lowering + +// ----- + +module { + func.func @vmi_to_vpto_exp_f8_invalid( + %source: !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout>) { + 
%out = pto.vmi.exp %source + : !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.exp direct lowering requires f16/f32 element type +// CHECK-SAME: requires f16/f32 element type for direct VPTO lowering + +// ----- + +module { + func.func @vmi_to_vpto_negf_bf16_invalid( + %source: !pto.vmi.vreg<128xbf16, #pto.vmi.layout>) { + %out = pto.vmi.negf %source + : !pto.vmi.vreg<128xbf16, #pto.vmi.layout> + -> !pto.vmi.vreg<128xbf16, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.negf direct lowering requires f16/f32 element type +// CHECK-SAME: requires f16/f32 element type for direct VPTO lowering + +// ----- + +module { + func.func @vmi_to_vpto_ln_bf16_invalid( + %source: !pto.vmi.vreg<128xbf16, #pto.vmi.layout>) { + %out = pto.vmi.ln %source + : !pto.vmi.vreg<128xbf16, #pto.vmi.layout> + -> !pto.vmi.vreg<128xbf16, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.ln direct lowering requires f16/f32 element type +// CHECK-SAME: requires f16/f32 element type for direct VPTO lowering + +// ----- + +module { + func.func @vmi_to_vpto_absf_bf16_invalid( + %source: !pto.vmi.vreg<128xbf16, #pto.vmi.layout>) { + %out = pto.vmi.absf %source + : !pto.vmi.vreg<128xbf16, #pto.vmi.layout> + -> !pto.vmi.vreg<128xbf16, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.absf direct lowering requires f16/f32 element type +// CHECK-SAME: requires f16/f32 element type for direct VPTO lowering + +// ----- + +module { + func.func @vmi_to_vpto_absi_unsigned_invalid( + %source: !pto.vmi.vreg<128xui16, #pto.vmi.layout>) { + %out = pto.vmi.absi %source + : !pto.vmi.vreg<128xui16, #pto.vmi.layout> + -> !pto.vmi.vreg<128xui16, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.absi direct lowering requires signless/signed i8/i16/i32 element type +// CHECK-SAME: requires 
signless/signed i8/i16/i32 element type for direct VPTO lowering diff --git a/test/lit/vmi/vmi_to_vpto_memory_space_invalid.pto b/test/lit/vmi/vmi_to_vpto_memory_space_invalid.pto new file mode 100644 index 0000000000..a1749cffe1 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_memory_space_invalid.pto @@ -0,0 +1,97 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -split-input-file -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_load_gm_unsupported(%src: !pto.ptr, %offset: index) { + %value = pto.vmi.load %src[%offset] + : !pto.ptr -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.load direct lowering requires a supported memory source +// CHECK-SAME: source is GM-backed +// CHECK-SAME: requires UB-backed memory + +// ----- + +module { + func.func @vmi_masked_load_gm_unsupported( + %src: !pto.ptr, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>, + %passthru: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %offset: index) { + %value = pto.vmi.masked_load %src[%offset], %mask, %passthru + : !pto.ptr, !pto.vmi.mask<64xb32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.masked_load direct lowering requires a supported memory source, contiguous result/passthru/mask layouts +// 
CHECK-SAME: source is GM-backed +// CHECK-SAME: requires UB-backed memory + +// ----- + +module { + func.func @vmi_expand_load_gm_unsupported( + %src: !pto.ptr, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>, + %passthru: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %offset: index) { + %value = pto.vmi.expand_load %src[%offset], %mask, %passthru + : !pto.ptr, !pto.vmi.mask<64xb32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.expand_load direct lowering is currently supported +// CHECK-SAME: source is GM-backed +// CHECK-SAME: requires UB-backed memory + +// ----- + +module { + func.func @vmi_store_gm_unsupported( + %value: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %dst: !pto.ptr, + %offset: index) { + pto.vmi.store %value, %dst[%offset] + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, !pto.ptr + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.store requires an 8/16/32-bit predicate-maskable element type +// CHECK-SAME: with UB-backed destination +// CHECK-SAME: destination is GM-backed + +// ----- + +module { + func.func @vmi_masked_store_gm_unsupported( + %value: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>, + %dst: !pto.ptr, + %offset: index) { + pto.vmi.masked_store %value, %dst[%offset], %mask + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.ptr, !pto.vmi.mask<64xb32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.masked_store requires either full physical chunks or contiguous tail-store value/mask layout +// CHECK-SAME: with UB-backed destination +// CHECK-SAME: destination is GM-backed diff --git a/test/lit/vmi/vmi_to_vpto_memory_x2_widths.pto b/test/lit/vmi/vmi_to_vpto_memory_x2_widths.pto new file mode 100644 index 0000000000..98d92a3262 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_memory_x2_widths.pto @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei 
Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_load_deint2_f16( + %src: !pto.ptr, %offset: index) + -> (!pto.vreg<128xf16>, !pto.vreg<128xf16>) { + %value = pto.vmi.load %src[%offset] + : !pto.ptr -> !pto.vmi.vreg<256xf16, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<256xf16, #pto.vmi.layout>) + -> (!pto.vreg<128xf16>, !pto.vreg<128xf16>) + return %p0, %p1 : !pto.vreg<128xf16>, !pto.vreg<128xf16> + } + + func.func @vmi_to_vpto_store_deint2_i8( + %value: !pto.vmi.vreg<512xi8, #pto.vmi.layout>, + %dst: !pto.ptr, %offset: index) { + pto.vmi.store %value, %dst[%offset] + : !pto.vmi.vreg<512xi8, #pto.vmi.layout>, !pto.ptr + return + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_load_deint2_f16( +// CHECK: %[[P0:.*]], %[[P1:.*]] = pto.vldsx2 %arg0[%arg1], "DINTLV_B16" +// CHECK: return %[[P0]], %[[P1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_store_deint2_i8( +// CHECK: %[[MASK:.*]] = pto.pset_b8 "PAT_ALL" +// CHECK: pto.vstsx2 %arg0, %arg1, %arg2[%arg3], "INTLV_B8", %[[MASK]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_memref_layout_invalid.pto b/test/lit/vmi/vmi_to_vpto_memref_layout_invalid.pto new file mode 100644 index 0000000000..ba91366ff2 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_memref_layout_invalid.pto @@ -0,0 +1,144 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -split-input-file -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_load_strided_memref_unsupported( + %src: memref<128xf32, strided<[2], offset: 0>>) { + %c0 = arith.constant 0 : index + %value = pto.vmi.load %src[%c0] + : memref<128xf32, strided<[2], offset: 0>> + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.load direct lowering requires a supported memory source +// CHECK-SAME: source memref layout is non-identity +// CHECK-SAME: contiguous identity lane-to-address maps + +// ----- + +module { + func.func @vmi_load_memref_subview_unsupported(%src: memref<128xf32>) { + %c0 = arith.constant 0 : index + %view = memref.subview %src[%c0] [64] [1] + : memref<128xf32> to memref<64xf32, strided<[1], offset: ?>> + %value = pto.vmi.load %view[%c0] + : memref<64xf32, strided<[1], offset: ?>> + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.load direct lowering requires a supported memory 
source +// CHECK-SAME: source memref layout is non-identity +// CHECK-SAME: contiguous identity lane-to-address maps +// CHECK-SAME: memref.subview requires normalized base/offset/stride lane-to-address planning + +// ----- + +module { + func.func @vmi_masked_load_strided_memref_unsupported( + %src: memref<128xf32, strided<[2], offset: 0>>, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>, + %passthru: !pto.vmi.vreg<64xf32, #pto.vmi.layout>) { + %c0 = arith.constant 0 : index + %value = pto.vmi.masked_load %src[%c0], %mask, %passthru + : memref<128xf32, strided<[2], offset: 0>>, + !pto.vmi.mask<64xb32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.masked_load direct lowering requires a supported memory source, contiguous result/passthru/mask layouts +// CHECK-SAME: source memref layout is non-identity +// CHECK-SAME: contiguous identity lane-to-address maps + +// ----- + +module { + func.func @vmi_expand_load_strided_memref_unsupported( + %src: memref<128xf32, strided<[2], offset: 0>>, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>, + %passthru: !pto.vmi.vreg<64xf32, #pto.vmi.layout>) { + %c0 = arith.constant 0 : index + %value = pto.vmi.expand_load %src[%c0], %mask, %passthru + : memref<128xf32, strided<[2], offset: 0>>, + !pto.vmi.mask<64xb32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.expand_load direct lowering is currently supported +// CHECK-SAME: source memref layout is non-identity +// CHECK-SAME: contiguous identity lane-to-address maps + +// ----- + +module { + func.func @vmi_store_strided_memref_unsupported( + %value: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %dst: memref<128xf32, strided<[2], offset: 0>>) { + %c0 = arith.constant 0 : index + pto.vmi.store %value, %dst[%c0] + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, 
+ memref<128xf32, strided<[2], offset: 0>> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.store requires an 8/16/32-bit predicate-maskable element type +// CHECK-SAME: destination memref layout is non-identity +// CHECK-SAME: contiguous identity lane-to-address maps + +// ----- + +module { + func.func @vmi_store_memref_subview_unsupported( + %value: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %dst: memref<128xf32>) { + %c0 = arith.constant 0 : index + %view = memref.subview %dst[%c0] [64] [1] + : memref<128xf32> to memref<64xf32, strided<[1], offset: ?>> + pto.vmi.store %value, %view[%c0] + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + memref<64xf32, strided<[1], offset: ?>> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.store requires an 8/16/32-bit predicate-maskable element type +// CHECK-SAME: destination memref layout is non-identity +// CHECK-SAME: contiguous identity lane-to-address maps +// CHECK-SAME: memref.subview requires normalized base/offset/stride lane-to-address planning + +// ----- + +module { + func.func @vmi_masked_store_strided_memref_unsupported( + %value: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>, + %dst: memref<128xf32, strided<[2], offset: 0>>) { + %c0 = arith.constant 0 : index + pto.vmi.masked_store %value, %dst[%c0], %mask + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + memref<128xf32, strided<[2], offset: 0>>, + !pto.vmi.mask<64xb32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.masked_store requires either full physical chunks or contiguous tail-store value/mask layout +// CHECK-SAME: destination memref layout is non-identity +// CHECK-SAME: contiguous identity lane-to-address maps diff --git a/test/lit/vmi/vmi_to_vpto_min_max.pto b/test/lit/vmi/vmi_to_vpto_min_max.pto new file mode 100644 index 0000000000..eeefc6ee94 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_min_max.pto @@ -0,0 +1,39 @@ +// Copyright (c) 2026 Huawei Technologies 
Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_min_max( + %lhs: !pto.vmi.vreg<128xf32>, + %rhs: !pto.vmi.vreg<128xf32>) + -> (!pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32>) { + %min = pto.vmi.minf %lhs, %rhs + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + %max = pto.vmi.maxf %lhs, %rhs + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return %min, %max : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_min_max( +// CHECK-SAME: %[[LHS0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[LHS1:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[RHS0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[RHS1:[^)]+]]: !pto.vreg<64xf32> +// CHECK-SAME: -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) +// CHECK-DAG: %[[MIN0:.*]] = pto.vmin %[[LHS0]], %[[RHS0]], {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-DAG: %[[MIN1:.*]] = pto.vmin %[[LHS1]], %[[RHS1]], {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-DAG: %[[MAX0:.*]] = pto.vmax %[[LHS0]], %[[RHS0]], {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-DAG: %[[MAX1:.*]] = pto.vmax %[[LHS1]], %[[RHS1]], {{.*}} : 
!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[MIN0]], %[[MIN1]], %[[MAX0]], %[[MAX1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_negf.pto b/test/lit/vmi/vmi_to_vpto_negf.pto new file mode 100644 index 0000000000..1aafa02c9a --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_negf.pto @@ -0,0 +1,29 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_negf(%a: !pto.vmi.vreg<128xf32>) + -> !pto.vmi.vreg<128xf32> { + %neg = pto.vmi.negf %a + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + return %neg : !pto.vmi.vreg<128xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_negf( +// CHECK-SAME: %[[A0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[A1:[^)]+]]: !pto.vreg<64xf32> +// CHECK-SAME: -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) +// CHECK-DAG: %[[NEG0:.*]] = pto.vneg %[[A0]], {{.*}} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-DAG: %[[NEG1:.*]] = pto.vneg %[[A1]], {{.*}} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[NEG0]], %[[NEG1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_pack_unpack.pto b/test/lit/vmi/vmi_to_vpto_pack_unpack.pto new file mode 100644 index 0000000000..e4caa3cc0a --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_pack_unpack.pto @@ -0,0 +1,46 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_unpack( + %v: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %p0, %p1 = "pto.vmi.unpack"(%v) + : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + + func.func @vmi_to_vpto_pack_unpack( + %p0: !pto.vreg<64xf32>, + %p1: !pto.vreg<64xf32>) -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %v = "pto.vmi.pack"(%p0, %p1) + : (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %q0, %q1 = "pto.vmi.unpack"(%v) + : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %q0, %q1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_unpack( +// CHECK-SAME: !pto.vreg<64xf32> +// CHECK-SAME: !pto.vreg<64xf32> +// CHECK-SAME: -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) +// CHECK: return +// CHECK-SAME: !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK-LABEL: 
func.func @vmi_to_vpto_pack_unpack( +// CHECK: return +// CHECK-SAME: !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK-NOT: pto.vmi.pack +// CHECK-NOT: pto.vmi.unpack +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_quant_dequant.pto b/test/lit/vmi/vmi_to_vpto_quant_dequant.pto new file mode 100644 index 0000000000..e6dad5963a --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_quant_dequant.pto @@ -0,0 +1,312 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_dequant_matrix_f16_to_f32( + %src: !pto.ptr, + %scale: f32, + %dst: !pto.ptr, + %rows: index, + %full_blocks: index, + %tail: index, + %src_stride: index, + %dst_stride: index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %has_tail = arith.cmpi ne, %tail, %c0 : index + scf.for %row = %c0 to %rows step %c1 { + %src_row = arith.muli %row, %src_stride : index + %dst_row = arith.muli %row, %dst_stride : index + scf.for %block = %c0 to %full_blocks step %c1 { + %block_offset = arith.muli %block, %c128 : index + %src_offset = arith.addi %src_row, %block_offset : index + %dst_offset = arith.addi %dst_row, %block_offset : index + %packed = pto.vmi.load %src[%src_offset] + : !pto.ptr -> !pto.vmi.vreg<128xf16> + %wide = pto.vmi.extf %packed + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %scale_vec = pto.vmi.broadcast %scale + : f32 -> !pto.vmi.vreg<128xf32> + %dequant = pto.vmi.mulf %wide, %scale_vec + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + pto.vmi.store %dequant, %dst[%dst_offset] + : !pto.vmi.vreg<128xf32>, !pto.ptr + } + scf.if %has_tail { + %tail_offset = arith.muli %full_blocks, %c128 : index + %src_offset = arith.addi %src_row, %tail_offset : index + %dst_offset = arith.addi %dst_row, %tail_offset : index + %tail_mask = pto.vmi.create_mask %tail + : index -> !pto.vmi.mask<128xpred> + %packed = pto.vmi.load %src[%src_offset] + : !pto.ptr -> !pto.vmi.vreg<128xf16> + %wide = pto.vmi.extf %packed + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %scale_vec = pto.vmi.broadcast %scale + : f32 -> !pto.vmi.vreg<128xf32> + %dequant = pto.vmi.mulf %wide, %scale_vec + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + pto.vmi.masked_store %dequant, %dst[%dst_offset], %tail_mask + : !pto.vmi.vreg<128xf32>, !pto.ptr, + 
!pto.vmi.mask<128xpred> + } + } + return + } + + func.func @vmi_to_vpto_quant_matrix_f32_to_f16( + %src: !pto.ptr, + %inv_scale: f32, + %dst: !pto.ptr, + %rows: index, + %full_blocks: index, + %tail: index, + %src_stride: index, + %dst_stride: index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %has_tail = arith.cmpi ne, %tail, %c0 : index + scf.for %row = %c0 to %rows step %c1 { + %src_row = arith.muli %row, %src_stride : index + %dst_row = arith.muli %row, %dst_stride : index + scf.for %block = %c0 to %full_blocks step %c1 { + %block_offset = arith.muli %block, %c128 : index + %src_offset = arith.addi %src_row, %block_offset : index + %dst_offset = arith.addi %dst_row, %block_offset : index + %wide = pto.vmi.load %src[%src_offset] + : !pto.ptr -> !pto.vmi.vreg<128xf32> + %scale_vec = pto.vmi.broadcast %inv_scale + : f32 -> !pto.vmi.vreg<128xf32> + %scaled = pto.vmi.mulf %wide, %scale_vec + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + %packed = pto.vmi.truncf %scaled + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> + pto.vmi.store %packed, %dst[%dst_offset] + : !pto.vmi.vreg<128xf16>, !pto.ptr + } + scf.if %has_tail { + %tail_offset = arith.muli %full_blocks, %c128 : index + %src_offset = arith.addi %src_row, %tail_offset : index + %dst_offset = arith.addi %dst_row, %tail_offset : index + %tail_mask = pto.vmi.create_mask %tail + : index -> !pto.vmi.mask<128xpred> + %wide = pto.vmi.load %src[%src_offset] + : !pto.ptr -> !pto.vmi.vreg<128xf32> + %scale_vec = pto.vmi.broadcast %inv_scale + : f32 -> !pto.vmi.vreg<128xf32> + %scaled = pto.vmi.mulf %wide, %scale_vec + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + %packed = pto.vmi.truncf %scaled + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> + pto.vmi.masked_store %packed, %dst[%dst_offset], %tail_mask + : !pto.vmi.vreg<128xf16>, !pto.ptr, + !pto.vmi.mask<128xpred> + } + } + return + 
} + + func.func @vmi_to_vpto_dequant_matrix_fp8_to_f32( + %src: !pto.ptr, + %scale: f32, + %dst: !pto.ptr, + %rows: index, + %full_blocks: index, + %tail: index, + %src_stride: index, + %dst_stride: index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %has_tail = arith.cmpi ne, %tail, %c0 : index + scf.for %row = %c0 to %rows step %c1 { + %src_row = arith.muli %row, %src_stride : index + %dst_row = arith.muli %row, %dst_stride : index + scf.for %block = %c0 to %full_blocks step %c1 { + %block_offset = arith.muli %block, %c256 : index + %src_offset = arith.addi %src_row, %block_offset : index + %dst_offset = arith.addi %dst_row, %block_offset : index + %packed = pto.vmi.load %src[%src_offset] + : !pto.ptr -> !pto.vmi.vreg<256xf8E4M3FN> + %wide = pto.vmi.extf %packed + : !pto.vmi.vreg<256xf8E4M3FN> -> !pto.vmi.vreg<256xf32> + %scale_vec = pto.vmi.broadcast %scale + : f32 -> !pto.vmi.vreg<256xf32> + %dequant = pto.vmi.mulf %wide, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + pto.vmi.store %dequant, %dst[%dst_offset] + : !pto.vmi.vreg<256xf32>, !pto.ptr + } + scf.if %has_tail { + %tail_offset = arith.muli %full_blocks, %c256 : index + %src_offset = arith.addi %src_row, %tail_offset : index + %dst_offset = arith.addi %dst_row, %tail_offset : index + %tail_mask = pto.vmi.create_mask %tail + : index -> !pto.vmi.mask<256xpred> + %packed = pto.vmi.load %src[%src_offset] + : !pto.ptr -> !pto.vmi.vreg<256xf8E4M3FN> + %wide = pto.vmi.extf %packed + : !pto.vmi.vreg<256xf8E4M3FN> -> !pto.vmi.vreg<256xf32> + %scale_vec = pto.vmi.broadcast %scale + : f32 -> !pto.vmi.vreg<256xf32> + %dequant = pto.vmi.mulf %wide, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + pto.vmi.masked_store %dequant, %dst[%dst_offset], %tail_mask + : !pto.vmi.vreg<256xf32>, !pto.ptr, + !pto.vmi.mask<256xpred> + } + } + return + } + + func.func 
@vmi_to_vpto_quant_matrix_f32_to_fp8( + %src: !pto.ptr, + %inv_scale: f32, + %dst: !pto.ptr, + %rows: index, + %full_blocks: index, + %tail: index, + %src_stride: index, + %dst_stride: index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %has_tail = arith.cmpi ne, %tail, %c0 : index + scf.for %row = %c0 to %rows step %c1 { + %src_row = arith.muli %row, %src_stride : index + %dst_row = arith.muli %row, %dst_stride : index + scf.for %block = %c0 to %full_blocks step %c1 { + %block_offset = arith.muli %block, %c256 : index + %src_offset = arith.addi %src_row, %block_offset : index + %dst_offset = arith.addi %dst_row, %block_offset : index + %wide = pto.vmi.load %src[%src_offset] + : !pto.ptr -> !pto.vmi.vreg<256xf32> + %scale_vec = pto.vmi.broadcast %inv_scale + : f32 -> !pto.vmi.vreg<256xf32> + %scaled = pto.vmi.mulf %wide, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %packed = pto.vmi.truncf %scaled + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E4M3FN> + pto.vmi.store %packed, %dst[%dst_offset] + : !pto.vmi.vreg<256xf8E4M3FN>, !pto.ptr + } + scf.if %has_tail { + %tail_offset = arith.muli %full_blocks, %c256 : index + %src_offset = arith.addi %src_row, %tail_offset : index + %dst_offset = arith.addi %dst_row, %tail_offset : index + %tail_mask = pto.vmi.create_mask %tail + : index -> !pto.vmi.mask<256xpred> + %wide = pto.vmi.load %src[%src_offset] + : !pto.ptr -> !pto.vmi.vreg<256xf32> + %scale_vec = pto.vmi.broadcast %inv_scale + : f32 -> !pto.vmi.vreg<256xf32> + %scaled = pto.vmi.mulf %wide, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %packed = pto.vmi.truncf %scaled + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E4M3FN> + pto.vmi.masked_store %packed, %dst[%dst_offset], %tail_mask + : !pto.vmi.vreg<256xf8E4M3FN>, !pto.ptr, + !pto.vmi.mask<256xpred> + } + } + return + } +} + +// CHECK-LABEL: func.func 
@vmi_to_vpto_dequant_matrix_f16_to_f32( +// CHECK-SAME: %[[DSRC:[^,]+]]: !pto.ptr +// CHECK-SAME: %[[SCALE:[^,]+]]: f32 +// CHECK-SAME: %[[DDST:[^,]+]]: !pto.ptr +// CHECK: scf.for +// CHECK: scf.for +// CHECK: pto.vlds {{.*}} : !pto.ptr -> !pto.vreg<128xf16> +// CHECK: pto.vcvt {{.*}} {part = "EVEN"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vcvt {{.*}} {part = "ODD"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vmul {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vintlv {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK: pto.vsts {{.*}} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +// CHECK: scf.if +// CHECK: pto.plt_b32 {{.*}} : i32 -> !pto.mask, i32 +// CHECK: pto.vsts {{.*}} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + +// CHECK-LABEL: func.func @vmi_to_vpto_quant_matrix_f32_to_f16( +// CHECK-SAME: %[[QSRC:[^,]+]]: !pto.ptr +// CHECK-SAME: %[[INV_SCALE:[^,]+]]: f32 +// CHECK-SAME: %[[QDST:[^,]+]]: !pto.ptr +// CHECK: scf.for +// CHECK: scf.for +// CHECK: pto.vldsx2 {{.*}} "DINTLV_B32" : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK: pto.vmul {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vcvt {{.*}} {part = "EVEN", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +// CHECK: pto.vcvt {{.*}} {part = "ODD", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +// CHECK: pto.vor {{.*}} : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> +// CHECK: pto.vsts {{.*}} : !pto.vreg<128xf16>, !pto.ptr, !pto.mask +// CHECK: scf.if +// CHECK: pto.vcvt {{.*}} {part = "EVEN", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +// CHECK: pto.vcvt {{.*}} {part = "ODD", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +// CHECK: pto.vor {{.*}} : 
!pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> +// CHECK: pto.ppack {{.*}} : !pto.mask -> !pto.mask +// CHECK: pto.ppack {{.*}} : !pto.mask -> !pto.mask +// CHECK: pto.por {{.*}} : !pto.mask, !pto.mask, !pto.mask -> !pto.mask +// CHECK: pto.vsts {{.*}} : !pto.vreg<128xf16>, !pto.ptr, !pto.mask + +// CHECK-LABEL: func.func @vmi_to_vpto_dequant_matrix_fp8_to_f32( +// CHECK-SAME: %[[FSRC:[^,]+]]: !pto.ptr +// CHECK-SAME: %[[FSCALE:[^,]+]]: f32 +// CHECK-SAME: %[[FDST:[^,]+]]: !pto.ptr +// CHECK: scf.for +// CHECK: scf.for +// CHECK: pto.vlds {{.*}} : !pto.ptr -> !pto.vreg<256xf8E4M3FN> +// CHECK: pto.vcvt {{.*}} {part = "P0"} : !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vcvt {{.*}} {part = "P1"} : !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vcvt {{.*}} {part = "P2"} : !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vcvt {{.*}} {part = "P3"} : !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vmul {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vintlv +// CHECK: pto.vsts {{.*}} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +// CHECK: scf.if +// CHECK: pto.plt_b32 {{.*}} : i32 -> !pto.mask, i32 +// CHECK: pto.vsts {{.*}} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + +// CHECK-LABEL: func.func @vmi_to_vpto_quant_matrix_f32_to_fp8( +// CHECK-SAME: %[[FQSRC:[^,]+]]: !pto.ptr +// CHECK-SAME: %[[FINV_SCALE:[^,]+]]: f32 +// CHECK-SAME: %[[FQDST:[^,]+]]: !pto.ptr +// CHECK: scf.for +// CHECK: scf.for +// CHECK: pto.vldsx2 {{.*}} "DINTLV_B32" : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK: pto.vldsx2 {{.*}} "DINTLV_B32" : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK: pto.vdintlv +// CHECK: pto.vdintlv +// CHECK: pto.vmul {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vcvt {{.*}} {part = "P0", rnd = "R", sat = "SAT"} : 
!pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8E4M3FN> +// CHECK: pto.vsts {{.*}} : !pto.vreg<256xf8E4M3FN>, !pto.ptr, !pto.mask +// CHECK: scf.if +// CHECK: pto.vcvt {{.*}} {part = "P0", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8E4M3FN> +// CHECK: pto.vcvt {{.*}} {part = "P1", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8E4M3FN> +// CHECK: pto.vcvt {{.*}} {part = "P2", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8E4M3FN> +// CHECK: pto.vcvt {{.*}} {part = "P3", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8E4M3FN> +// CHECK: pto.vor {{.*}} : !pto.vreg<256xf8E4M3FN>, !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<256xf8E4M3FN> +// CHECK: pto.ppack {{.*}} : !pto.mask -> !pto.mask +// CHECK: pto.ppack {{.*}} : !pto.mask -> !pto.mask +// CHECK: pto.vsts {{.*}} : !pto.vreg<256xf8E4M3FN>, !pto.ptr, !pto.mask +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_quant_fp8.pto b/test/lit/vmi/vmi_to_vpto_quant_fp8.pto new file mode 100644 index 0000000000..93c822a9f3 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_quant_fp8.pto @@ -0,0 +1,43 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_quant_matrix_f32_to_fp8( + %src: !pto.ptr, + %inv_scale: f32, + %dst: !pto.ptr, + %offset: index) { + %wide = pto.vmi.load %src[%offset] + : !pto.ptr -> !pto.vmi.vreg<256xf32> + %scale_vec = pto.vmi.broadcast %inv_scale + : f32 -> !pto.vmi.vreg<256xf32> + %scaled = pto.vmi.mulf %wide, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %packed = pto.vmi.truncf %scaled + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E4M3FN> + pto.vmi.store %packed, %dst[%offset] + : !pto.vmi.vreg<256xf8E4M3FN>, !pto.ptr + return + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_quant_matrix_f32_to_fp8( +// CHECK-COUNT-2: pto.vldsx2 {{.*}} "DINTLV_B32" +// CHECK: pto.vdintlv +// CHECK: pto.vcvt {{.*}} {part = "P0", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8E4M3FN> +// CHECK: pto.vcvt {{.*}} {part = "P1", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8E4M3FN> +// CHECK: pto.vcvt {{.*}} {part = "P2", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8E4M3FN> +// CHECK: pto.vcvt {{.*}} {part = "P3", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8E4M3FN> +// CHECK: pto.vor +// CHECK: pto.vsts {{.*}} : !pto.vreg<256xf8E4M3FN>, !pto.ptr, !pto.mask +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_reduce_addf.pto b/test/lit/vmi/vmi_to_vpto_reduce_addf.pto new file mode 100644 index 0000000000..6f2fadfdba --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_reduce_addf.pto @@ -0,0 +1,36 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. 
You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_reduce_addf( + %source: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %init: !pto.vmi.vreg<1xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> { + %out = pto.vmi.reduce_addf %source, %init, %mask {reassoc} + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<1xf32, #pto.vmi.layout>, + !pto.vmi.mask<64xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<1xf32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<1xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> + return %part : !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_reduce_addf( +// CHECK: %[[FIRST:.*]] = pto.pge_b32 "PAT_VL1" : !pto.mask +// CHECK: %[[REDUCED:.*]] = pto.vcadd %arg0, %arg2 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: %[[OUT:.*]] = pto.vadd %[[REDUCED]], %arg1, %[[FIRST]] : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[OUT]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_reduce_addf_f16.pto b/test/lit/vmi/vmi_to_vpto_reduce_addf_f16.pto new file mode 100644 index 0000000000..fc4ebdc92a --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_reduce_addf_f16.pto @@ -0,0 +1,41 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). 
+// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_reduce_addf_f16( + %source: !pto.vmi.vreg<128xf16, #pto.vmi.layout>, + %init: !pto.vmi.vreg<1xf16, #pto.vmi.layout>, + %mask: !pto.vmi.mask<128xb16, #pto.vmi.layout>) + -> !pto.vreg<128xf16> { + %out = pto.vmi.reduce_addf %source, %init, %mask {reassoc} + : !pto.vmi.vreg<128xf16, #pto.vmi.layout>, + !pto.vmi.vreg<1xf16, #pto.vmi.layout>, + !pto.vmi.mask<128xb16, #pto.vmi.layout> + -> !pto.vmi.vreg<1xf16, #pto.vmi.layout> + %p = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<1xf16, #pto.vmi.layout>) + -> !pto.vreg<128xf16> + return %p : !pto.vreg<128xf16> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_reduce_addf_f16( +// CHECK-SAME: %arg0: !pto.vreg<128xf16> +// CHECK-SAME: %arg1: !pto.vreg<128xf16> +// CHECK-SAME: %arg2: !pto.mask +// CHECK: %[[LANE0:.*]] = pto.pge_b16 "PAT_VL1" : !pto.mask +// CHECK: %[[REDUCED:.*]] = pto.vcadd %arg0, %arg2 +// CHECK-SAME: !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> +// CHECK: %[[OUT:.*]] = pto.vadd %[[REDUCED]], %arg1, %[[LANE0]] +// CHECK-SAME: !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> +// CHECK: return %[[OUT]] : !pto.vreg<128xf16> +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_reduce_addf_multichunk.pto b/test/lit/vmi/vmi_to_vpto_reduce_addf_multichunk.pto new file mode 100644 index 0000000000..0389c17e25 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_reduce_addf_multichunk.pto @@ -0,0 +1,38 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_reduce_addf_multichunk( + %source: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %init: !pto.vmi.vreg<1xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> { + %out = pto.vmi.reduce_addf %source, %init, %mask {reassoc} + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<1xf32, #pto.vmi.layout>, + !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<1xf32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<1xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> + return %part : !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_reduce_addf_multichunk( +// CHECK: %[[FIRST:.*]] = pto.pge_b32 "PAT_VL1" : !pto.mask +// CHECK: %[[RED0:.*]] = pto.vcadd %arg0, %arg3 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: %[[ACC0:.*]] = pto.vadd %[[RED0]], %arg2, %[[FIRST]] : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: %[[RED1:.*]] 
= pto.vcadd %arg1, %arg4 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: %[[ACC1:.*]] = pto.vadd %[[RED1]], %[[ACC0]], %[[FIRST]] : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[ACC1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_reduce_addi.pto b/test/lit/vmi/vmi_to_vpto_reduce_addi.pto new file mode 100644 index 0000000000..fd6c461b2c --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_reduce_addi.pto @@ -0,0 +1,36 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_reduce_addi( + %source: !pto.vmi.vreg<64xi32, #pto.vmi.layout>, + %init: !pto.vmi.vreg<1xi32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) + -> !pto.vreg<64xi32> { + %out = pto.vmi.reduce_addi %source, %init, %mask + : !pto.vmi.vreg<64xi32, #pto.vmi.layout>, + !pto.vmi.vreg<1xi32, #pto.vmi.layout>, + !pto.vmi.mask<64xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<1xi32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<1xi32, #pto.vmi.layout>) + -> !pto.vreg<64xi32> + return %part : !pto.vreg<64xi32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_reduce_addi( +// CHECK: %[[FIRST:.*]] = pto.pge_b32 "PAT_VL1" : !pto.mask +// CHECK: %[[REDUCED:.*]] = pto.vcadd %arg0, %arg2 : !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> +// CHECK: %[[OUT:.*]] = pto.vadd %[[REDUCED]], %arg1, %[[FIRST]] : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> +// CHECK: return %[[OUT]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_reduce_addi_i16_invalid.pto b/test/lit/vmi/vmi_to_vpto_reduce_addi_i16_invalid.pto new file mode 100644 index 0000000000..466374c65c --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_reduce_addi_i16_invalid.pto @@ -0,0 +1,26 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_reduce_addi_i16_invalid( + %source: !pto.vmi.vreg<128xi16, #pto.vmi.layout>, + %init: !pto.vmi.vreg<1xi16, #pto.vmi.layout>, + %mask: !pto.vmi.mask<128xb16, #pto.vmi.layout>) { + %out = pto.vmi.reduce_addi %source, %init, %mask + : !pto.vmi.vreg<128xi16, #pto.vmi.layout>, + !pto.vmi.vreg<1xi16, #pto.vmi.layout>, + !pto.vmi.mask<128xb16, #pto.vmi.layout> + -> !pto.vmi.vreg<1xi16, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.reduce_addi lowers through pto.vcadd only +// CHECK-SAME: currently supports only 32-bit integer elements diff --git a/test/lit/vmi/vmi_to_vpto_reduce_addi_multichunk.pto b/test/lit/vmi/vmi_to_vpto_reduce_addi_multichunk.pto new file mode 100644 index 0000000000..8275a80790 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_reduce_addi_multichunk.pto @@ -0,0 +1,38 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_reduce_addi_multichunk( + %source: !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + %init: !pto.vmi.vreg<1xi32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> !pto.vreg<64xi32> { + %out = pto.vmi.reduce_addi %source, %init, %mask + : !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + !pto.vmi.vreg<1xi32, #pto.vmi.layout>, + !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<1xi32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<1xi32, #pto.vmi.layout>) + -> !pto.vreg<64xi32> + return %part : !pto.vreg<64xi32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_reduce_addi_multichunk( +// CHECK: %[[FIRST:.*]] = pto.pge_b32 "PAT_VL1" : !pto.mask +// CHECK: %[[RED0:.*]] = pto.vcadd %arg0, %arg3 : !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> +// CHECK: %[[ACC0:.*]] = pto.vadd %[[RED0]], %arg2, %[[FIRST]] : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> +// CHECK: %[[RED1:.*]] = pto.vcadd %arg1, %arg4 : !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> +// CHECK: %[[ACC1:.*]] = pto.vadd %[[RED1]], %[[ACC0]], %[[FIRST]] : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> +// CHECK: return %[[ACC1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_reduce_maxf_multichunk.pto b/test/lit/vmi/vmi_to_vpto_reduce_maxf_multichunk.pto new file mode 100644 index 0000000000..51782e8462 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_reduce_maxf_multichunk.pto @@ -0,0 +1,65 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_reduce_maxf_multichunk( + %source: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %init: !pto.vmi.vreg<1xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> { + %out = pto.vmi.reduce_maxf %source, %init, %mask + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<1xf32, #pto.vmi.layout>, + !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<1xf32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<1xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> + return %part : !pto.vreg<64xf32> + } + + func.func @vmi_to_vpto_reduce_minf_multichunk( + %source: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %init: !pto.vmi.vreg<1xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> { + %out = pto.vmi.reduce_minf %source, %init, %mask + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<1xf32, #pto.vmi.layout>, + !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<1xf32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<1xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> + return %part : !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_reduce_maxf_multichunk( +// CHECK: %[[FIRST:.*]] = pto.pge_b32 "PAT_VL1" : !pto.mask +// CHECK: %[[RED0:.*]] = pto.vcmax %arg0, %arg3 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: %[[ACC0:.*]] = pto.vmax %[[RED0]], %arg2, %[[FIRST]] : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: %[[RED1:.*]] = pto.vcmax %arg1, %arg4 : !pto.vreg<64xf32>, !pto.mask -> 
!pto.vreg<64xf32> +// CHECK: %[[ACC1:.*]] = pto.vmax %[[RED1]], %[[ACC0]], %[[FIRST]] : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[ACC1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_reduce_minf_multichunk( +// CHECK: %[[FIRST:.*]] = pto.pge_b32 "PAT_VL1" : !pto.mask +// CHECK: %[[RED0:.*]] = pto.vcmin %arg0, %arg3 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: %[[ACC0:.*]] = pto.vmin %[[RED0]], %arg2, %[[FIRST]] : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: %[[RED1:.*]] = pto.vcmin %arg1, %arg4 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: %[[ACC1:.*]] = pto.vmin %[[RED1]], %[[ACC0]], %[[FIRST]] : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[ACC1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_reduce_maxf_tail_invalid.pto b/test/lit/vmi/vmi_to_vpto_reduce_maxf_tail_invalid.pto new file mode 100644 index 0000000000..a926b48d70 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_reduce_maxf_tail_invalid.pto @@ -0,0 +1,29 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_reduce_maxf_tail_invalid( + %source: !pto.vmi.vreg<65xf32, #pto.vmi.layout>, + %init: !pto.vmi.vreg<1xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<65xb32, #pto.vmi.layout>) { + %out = pto.vmi.reduce_maxf %source, %init, %mask + : !pto.vmi.vreg<65xf32, #pto.vmi.layout>, + !pto.vmi.vreg<1xf32, #pto.vmi.layout>, + !pto.vmi.mask<65xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<1xf32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<1xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> + return + } +} + +// CHECK: VMI{{.*}} pto.vmi.reduce_maxf lowers through pto.vcmax only +// CHECK-SAME: requires full source physical chunks diff --git a/test/lit/vmi/vmi_to_vpto_reduce_minf.pto b/test/lit/vmi/vmi_to_vpto_reduce_minf.pto new file mode 100644 index 0000000000..96a70a03f3 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_reduce_minf.pto @@ -0,0 +1,36 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_reduce_minf( + %source: !pto.vmi.vreg<128xf16, #pto.vmi.layout>, + %init: !pto.vmi.vreg<1xf16, #pto.vmi.layout>, + %mask: !pto.vmi.mask<128xb16, #pto.vmi.layout>) + -> !pto.vreg<128xf16> { + %out = pto.vmi.reduce_minf %source, %init, %mask + : !pto.vmi.vreg<128xf16, #pto.vmi.layout>, + !pto.vmi.vreg<1xf16, #pto.vmi.layout>, + !pto.vmi.mask<128xb16, #pto.vmi.layout> + -> !pto.vmi.vreg<1xf16, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<1xf16, #pto.vmi.layout>) + -> !pto.vreg<128xf16> + return %part : !pto.vreg<128xf16> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_reduce_minf( +// CHECK: %[[FIRST:.*]] = pto.pge_b16 "PAT_VL1" : !pto.mask +// CHECK: %[[REDUCED:.*]] = pto.vcmin %arg0, %arg2 : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> +// CHECK: %[[OUT:.*]] = pto.vmin %[[REDUCED]], %arg1, %[[FIRST]] : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> +// CHECK: return %[[OUT]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_reduce_shape_invalid.pto b/test/lit/vmi/vmi_to_vpto_reduce_shape_invalid.pto new file mode 100644 index 0000000000..1b2cf33ffa --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_reduce_shape_invalid.pto @@ -0,0 +1,85 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -split-input-file -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_reduce_addi_tail_invalid( + %source: !pto.vmi.vreg<32xi32, #pto.vmi.layout>, + %init: !pto.vmi.vreg<1xi32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<32xb32, #pto.vmi.layout>) { + %out = pto.vmi.reduce_addi %source, %init, %mask + : !pto.vmi.vreg<32xi32, #pto.vmi.layout>, + !pto.vmi.vreg<1xi32, #pto.vmi.layout>, + !pto.vmi.mask<32xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<1xi32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.reduce_addi lowers through pto.vcadd only +// CHECK-SAME: requires full source physical chunks +// CHECK-SAME: found padding lane in physical chunk + +// ----- + +module { + func.func @vmi_to_vpto_reduce_addf_deint_invalid( + %source: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %init: !pto.vmi.vreg<1xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) { + %out = pto.vmi.reduce_addf %source, %init, %mask {reassoc} + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<1xf32, #pto.vmi.layout>, + !pto.vmi.mask<64xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<1xf32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.reduce_addf lowers through pto.vcadd only +// CHECK-SAME: requires contiguous source, init, mask, and result layouts + +// ----- + +module { + func.func @vmi_to_vpto_reduce_minf_tail_invalid( + %source: !pto.vmi.vreg<64xf16, #pto.vmi.layout>, + %init: !pto.vmi.vreg<1xf16, #pto.vmi.layout>, + %mask: !pto.vmi.mask<64xb16, #pto.vmi.layout>) { + %out = pto.vmi.reduce_minf %source, %init, %mask + : !pto.vmi.vreg<64xf16, #pto.vmi.layout>, + !pto.vmi.vreg<1xf16, #pto.vmi.layout>, + !pto.vmi.mask<64xb16, #pto.vmi.layout> + -> !pto.vmi.vreg<1xf16, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.reduce_minf lowers through pto.vcmin only 
+// CHECK-SAME: requires full source physical chunks +// CHECK-SAME: found padding lane in physical chunk + +// ----- + +module { + func.func @vmi_to_vpto_reduce_maxf_deint_invalid( + %source: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %init: !pto.vmi.vreg<1xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) { + %out = pto.vmi.reduce_maxf %source, %init, %mask + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<1xf32, #pto.vmi.layout>, + !pto.vmi.mask<64xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<1xf32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.reduce_maxf lowers through pto.vcmax only +// CHECK-SAME: requires contiguous source, init, mask, and result layouts diff --git a/test/lit/vmi/vmi_to_vpto_relu_element_type_invalid.pto b/test/lit/vmi/vmi_to_vpto_relu_element_type_invalid.pto new file mode 100644 index 0000000000..ab4f204979 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_relu_element_type_invalid.pto @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_relu_bf16_invalid( + %source: !pto.vmi.vreg<128xbf16, #pto.vmi.layout>) { + %relu = pto.vmi.relu %source + : !pto.vmi.vreg<128xbf16, #pto.vmi.layout> + -> !pto.vmi.vreg<128xbf16, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.relu direct lowering requires physical vreg parts with b8/b16/b32 predicate masks and f16/f32 element type +// CHECK-SAME: pto.vrelu direct lowering supports only f16/f32 VMI floating-point element types diff --git a/test/lit/vmi/vmi_to_vpto_scatter.pto b/test/lit/vmi/vmi_to_vpto_scatter.pto new file mode 100644 index 0000000000..4f898e3571 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_scatter.pto @@ -0,0 +1,31 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_scatter( + %value: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %dst: !pto.ptr, + %indices: !pto.vmi.vreg<64xi32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) { + pto.vmi.scatter %value, %dst[%indices], %mask + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.ptr, + !pto.vmi.vreg<64xi32, #pto.vmi.layout>, + !pto.vmi.mask<64xb32, #pto.vmi.layout> + return + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_scatter( +// CHECK: pto.vscatter %arg0, %arg1, %arg2, %arg3 : !pto.vreg<64xf32>, !pto.ptr, !pto.vreg<64xi32>, !pto.mask +// CHECK: return +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_scf_for.pto b/test/lit/vmi/vmi_to_vpto_scf_for.pto new file mode 100644 index 0000000000..253432b6dc --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_scf_for.pto @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_scf_for(%a: !pto.vmi.vreg<128xf16>) + -> !pto.vmi.vreg<128xf32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %init = pto.vmi.extf %a + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %result = scf.for %i = %c0 to %c2 step %c1 + iter_args(%acc = %init) -> (!pto.vmi.vreg<128xf32>) { + %next = pto.vmi.addf %acc, %acc + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + scf.yield %next : !pto.vmi.vreg<128xf32> + } + return %result : !pto.vmi.vreg<128xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_scf_for( +// CHECK-SAME: %[[A:[^)]+]]: !pto.vreg<128xf16> +// CHECK-SAME: -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) +// CHECK-DAG: %[[P0:.*]] = pto.vcvt %[[A]] +// CHECK-DAG: %[[P1:.*]] = pto.vcvt %[[A]] +// CHECK: %[[RESULT:.*]]:2 = scf.for +// CHECK-SAME: iter_args(%[[ACC0:.*]] = %[[P0]], %[[ACC1:.*]] = %[[P1]]) +// CHECK-SAME: -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) +// CHECK: %[[N0:.*]] = pto.vadd %[[ACC0]], %[[ACC0]] +// CHECK: %[[N1:.*]] = pto.vadd %[[ACC1]], %[[ACC1]] +// CHECK: scf.yield %[[N0]], %[[N1]] : !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK: return %[[RESULT]]#0, %[[RESULT]]#1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_scf_if.pto b/test/lit/vmi/vmi_to_vpto_scf_if.pto new file mode 100644 index 0000000000..dcc7497ee4 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_scf_if.pto @@ -0,0 +1,57 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_scf_if( + %cond: i1, + %a: !pto.vmi.vreg<128xf16>, + %b: !pto.vmi.vreg<128xf16>) -> !pto.vmi.vreg<128xf32> { + %value, %mask = scf.if %cond + -> (!pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred>) { + %ea = pto.vmi.extf %a + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %cmpa = pto.vmi.cmpf "olt", %ea, %ea + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.mask<128xpred> + scf.yield %ea, %cmpa : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + } else { + %eb = pto.vmi.extf %b + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %cmpb = pto.vmi.cmpf "olt", %eb, %eb + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.mask<128xpred> + scf.yield %eb, %cmpb : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + } + %selected = pto.vmi.select %mask, %value, %value + : !pto.vmi.mask<128xpred>, !pto.vmi.vreg<128xf32>, + !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + return %selected : !pto.vmi.vreg<128xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_scf_if( +// CHECK-SAME: %[[COND:[^,]+]]: i1 +// CHECK-SAME: %[[A:[^,]+]]: !pto.vreg<128xf16> +// CHECK-SAME: %[[B:[^)]+]]: !pto.vreg<128xf16> +// CHECK-SAME: -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) +// CHECK: %[[IF:.*]]:4 = scf.if %[[COND]] -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask, !pto.mask) +// CHECK: pto.vcvt %[[A]] +// CHECK: pto.vcmp {{.*}}, {{.*}}, {{.*}}, "lt" +// CHECK: scf.yield {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask, !pto.mask +// CHECK: else +// CHECK: pto.vcvt %[[B]] +// CHECK: pto.vcmp {{.*}}, {{.*}}, {{.*}}, "lt" +// CHECK: 
scf.yield {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask, !pto.mask +// CHECK: pto.vsel %[[IF]]#0, %[[IF]]#0, %[[IF]]#2 +// CHECK: pto.vsel %[[IF]]#1, %[[IF]]#1, %[[IF]]#3 +// CHECK: return {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_shli.pto b/test/lit/vmi/vmi_to_vpto_shli.pto new file mode 100644 index 0000000000..eb5fa7d64d --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_shli.pto @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_shli( + %value: !pto.vmi.vreg<256xi16>, + %amount: !pto.vmi.vreg<256xi16>) -> !pto.vmi.vreg<256xi16> { + %shifted = pto.vmi.shli %value, %amount + : !pto.vmi.vreg<256xi16>, !pto.vmi.vreg<256xi16> + -> !pto.vmi.vreg<256xi16> + return %shifted : !pto.vmi.vreg<256xi16> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_shli( +// CHECK-SAME: %[[VALUE0:[^,]+]]: !pto.vreg<128xi16> +// CHECK-SAME: %[[VALUE1:[^,]+]]: !pto.vreg<128xi16> +// CHECK-SAME: %[[AMOUNT0:[^,]+]]: !pto.vreg<128xi16> +// CHECK-SAME: %[[AMOUNT1:[^)]+]]: !pto.vreg<128xi16> +// CHECK-SAME: -> (!pto.vreg<128xi16>, !pto.vreg<128xi16>) +// CHECK-DAG: %[[SHL0:.*]] = pto.vshl %[[VALUE0]], %[[AMOUNT0]], {{.*}} : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> +// CHECK-DAG: %[[SHL1:.*]] = pto.vshl %[[VALUE1]], %[[AMOUNT1]], {{.*}} : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> +// CHECK: return %[[SHL0]], %[[SHL1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_shrui.pto b/test/lit/vmi/vmi_to_vpto_shrui.pto new file mode 100644 index 0000000000..46ccbf8d86 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_shrui.pto @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_shrui( + %value: !pto.vmi.vreg<256xui16>, + %amount: !pto.vmi.vreg<256xui16>) -> !pto.vmi.vreg<256xui16> { + %shifted = pto.vmi.shrui %value, %amount + : !pto.vmi.vreg<256xui16>, !pto.vmi.vreg<256xui16> + -> !pto.vmi.vreg<256xui16> + return %shifted : !pto.vmi.vreg<256xui16> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_shrui( +// CHECK-SAME: %[[VALUE0:[^,]+]]: !pto.vreg<128xui16> +// CHECK-SAME: %[[VALUE1:[^,]+]]: !pto.vreg<128xui16> +// CHECK-SAME: %[[AMOUNT0:[^,]+]]: !pto.vreg<128xui16> +// CHECK-SAME: %[[AMOUNT1:[^)]+]]: !pto.vreg<128xui16> +// CHECK-SAME: -> (!pto.vreg<128xui16>, !pto.vreg<128xui16>) +// CHECK-DAG: %[[SHR0:.*]] = pto.vshr %[[VALUE0]], %[[AMOUNT0]], {{.*}} : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> +// CHECK-DAG: %[[SHR1:.*]] = pto.vshr %[[VALUE1]], %[[AMOUNT1]], {{.*}} : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> +// CHECK: return %[[SHR0]], %[[SHR1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_shuffle_forwarding.pto b/test/lit/vmi/vmi_to_vpto_shuffle_forwarding.pto new file mode 100644 index 0000000000..dc237c02dc --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_shuffle_forwarding.pto @@ -0,0 +1,159 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_shuffle_identity( + %src: !pto.vmi.vreg<128xf32>) -> !pto.vmi.vreg<128xf32> { + %out = "pto.vmi.shuffle"(%src) { + indices = array + } : (!pto.vmi.vreg<128xf32>) -> !pto.vmi.vreg<128xf32> + return %out : !pto.vmi.vreg<128xf32> + } + + func.func @vmi_shuffle_second_chunk( + %src: !pto.vmi.vreg<128xf32>) -> !pto.vmi.vreg<64xf32> { + %out = "pto.vmi.shuffle"(%src) { + indices = array + } : (!pto.vmi.vreg<128xf32>) -> !pto.vmi.vreg<64xf32> + return %out : !pto.vmi.vreg<64xf32> + } + + func.func @vmi_shuffle_tail_prefix( + %src: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<4xf32, #pto.vmi.layout> { + %out = "pto.vmi.shuffle"(%src) { + indices = array + } : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<4xf32, #pto.vmi.layout> + return %out : !pto.vmi.vreg<4xf32, #pto.vmi.layout> + } + + func.func @vmi_shuffle_chunk_swap( + %src: !pto.vmi.vreg<128xf32>) -> !pto.vmi.vreg<128xf32> { + %out = "pto.vmi.shuffle"(%src) { + indices = array + } : (!pto.vmi.vreg<128xf32>) -> !pto.vmi.vreg<128xf32> + return %out : !pto.vmi.vreg<128xf32> + } + + func.func @vmi_shuffle_reverse_one_chunk( + %src: !pto.vmi.vreg<64xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> { + %out = "pto.vmi.shuffle"(%src) { + indices = array + } : (!pto.vmi.vreg<64xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + return %out : !pto.vmi.vreg<64xf32, #pto.vmi.layout> + } + + func.func @vmi_shuffle_deint2_identity( + %src: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> { + %out = "pto.vmi.shuffle"(%src) { + indices = array + } : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + return %out : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + } +} + +// CHECK-LABEL: func.func 
@vmi_shuffle_identity( +// CHECK-SAME: %[[D0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[D1:[^)]+]]: !pto.vreg<64xf32> +// CHECK-NEXT: return %[[D0]], %[[D1]] : !pto.vreg<64xf32>, !pto.vreg<64xf32> + +// CHECK-LABEL: func.func @vmi_shuffle_second_chunk( +// CHECK-SAME: %{{[^,]+}}: !pto.vreg<64xf32> +// CHECK-SAME: %[[D1:[^)]+]]: !pto.vreg<64xf32> +// CHECK-NEXT: return %[[D1]] : !pto.vreg<64xf32> + +// CHECK-LABEL: func.func @vmi_shuffle_tail_prefix( +// CHECK-SAME: %[[S0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %{{[^)]+}}: !pto.vreg<64xf32> +// CHECK-NEXT: return %[[S0]] : !pto.vreg<64xf32> + +// CHECK-LABEL: func.func @vmi_shuffle_chunk_swap( +// CHECK-SAME: %[[S0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[S1:[^)]+]]: !pto.vreg<64xf32> +// CHECK-NEXT: return %[[S1]], %[[S0]] : !pto.vreg<64xf32>, !pto.vreg<64xf32> + +// CHECK-LABEL: func.func @vmi_shuffle_reverse_one_chunk( +// CHECK-SAME: %[[SRC:[^)]+]]: !pto.vreg<64xf32> +// CHECK-DAG: %[[C63:.*]] = arith.constant 63 : i32 +// CHECK: %[[IDX:.*]] = pto.vci %[[C63]] {order = "DESC"} : i32 -> !pto.vreg<64xi32> +// CHECK: %[[OUT:.*]] = pto.vselr %[[SRC]], %[[IDX]] : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> +// CHECK-NEXT: return %[[OUT]] : !pto.vreg<64xf32> + +// CHECK-LABEL: func.func @vmi_shuffle_deint2_identity( +// CHECK-SAME: %[[P0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[P1:[^)]+]]: !pto.vreg<64xf32> +// CHECK-NEXT: return %[[P0]], %[[P1]] : !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_shuffle_lane0_splat.pto b/test/lit/vmi/vmi_to_vpto_shuffle_lane0_splat.pto new file mode 100644 index 0000000000..264b7b6a6a --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_shuffle_lane0_splat.pto @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_shuffle_lane0_splat( + %src: !pto.vmi.vreg<1xf32>) -> !pto.vmi.vreg<128xf32> { + %out = "pto.vmi.shuffle"(%src) { + indices = array + } : (!pto.vmi.vreg<1xf32>) -> !pto.vmi.vreg<128xf32> + return %out : !pto.vmi.vreg<128xf32> + } +} + +// CHECK-LABEL: func.func @vmi_shuffle_lane0_splat( +// CHECK: %[[MASK0:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: %[[DUP0:.*]] = pto.vdup %arg0, %[[MASK0]] {position = "LOWEST"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: %[[MASK1:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: %[[DUP1:.*]] = pto.vdup %arg0, %[[MASK1]] {position = "LOWEST"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[DUP0]], %[[DUP1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_stable_gather_masked_load_todo_invalid.pto b/test/lit/vmi/vmi_to_vpto_stable_gather_masked_load_todo_invalid.pto new file mode 100644 index 0000000000..6e89595596 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_stable_gather_masked_load_todo_invalid.pto @@ -0,0 +1,29 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto='enable-stable-gather-masked-load=true' 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_stable_gather_masked_load_todo( + %src: !pto.ptr, + %offset: index, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>, + %passthru: !pto.vmi.vreg<64xf32, #pto.vmi.layout>) { + %out = pto.vmi.masked_load %src[%offset], %mask, %passthru + : !pto.ptr, + !pto.vmi.mask<64xb32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + pto.vmi.store %out, %src[%offset] + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.ptr + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.masked_load stable VGATHER-based lowering is reserved for strict masked/tail loads but is not implemented yet diff --git a/test/lit/vmi/vmi_to_vpto_store_deint.pto b/test/lit/vmi/vmi_to_vpto_store_deint.pto new file mode 100644 index 0000000000..cafebbf14d --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_store_deint.pto @@ -0,0 +1,64 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_store_deint2( + %value: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %dst: !pto.ptr, %offset: index) { + pto.vmi.store %value, %dst[%offset] + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, !pto.ptr + return + } + + func.func @vmi_to_vpto_store_deint4( + %value: !pto.vmi.vreg<256xf32, #pto.vmi.layout>, + %dst: !pto.ptr, %offset: index) { + pto.vmi.store %value, %dst[%offset] + : !pto.vmi.vreg<256xf32, #pto.vmi.layout>, !pto.ptr + return + } + + func.func @vmi_to_vpto_store_deint2_multichunk( + %value: !pto.vmi.vreg<256xf32, #pto.vmi.layout>, + %dst: !pto.ptr, %offset: index) { + pto.vmi.store %value, %dst[%offset] + : !pto.vmi.vreg<256xf32, #pto.vmi.layout>, !pto.ptr + return + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_store_deint2( +// CHECK: %[[MASK:.*]] = pto.pset_b32 "PAT_ALL" +// CHECK: pto.vstsx2 %arg0, %arg1, %arg2[%arg3], "INTLV_B32", %[[MASK]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_store_deint4( +// CHECK: %[[A0:.*]], %[[A1:.*]] = pto.vintlv %arg0, %arg2 +// CHECK: %[[B0:.*]], %[[B1:.*]] = pto.vintlv %arg1, %arg3 +// CHECK: %[[D0:.*]], %[[D1:.*]] = pto.vintlv %[[A0]], %[[B0]] +// CHECK: %[[D2:.*]], %[[D3:.*]] = pto.vintlv %[[A1]], %[[B1]] +// CHECK: pto.vsts %[[D0]], %arg4[%arg5] +// CHECK: pto.vsts %[[D1]], %arg4[{{.*}}] +// CHECK: pto.vsts %[[D2]], %arg4[{{.*}}] +// CHECK: pto.vsts %[[D3]], %arg4[{{.*}}] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_store_deint2_multichunk( +// CHECK: %[[MASK0:.*]] = pto.pset_b32 "PAT_ALL" +// CHECK: pto.vstsx2 %arg0, %arg2, %arg4[%arg5], "INTLV_B32", %[[MASK0]] +// CHECK: %[[MASK1:.*]] = pto.pset_b32 "PAT_ALL" +// CHECK: pto.vstsx2 %arg1, %arg3, %arg4[{{.*}}], "INTLV_B32", %[[MASK1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_store_deint_invalid.pto b/test/lit/vmi/vmi_to_vpto_store_deint_invalid.pto new file mode 100644 index 0000000000..e1068f813c --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_store_deint_invalid.pto @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_store_deint_invalid( + %value: !pto.vmi.vreg<129xf32, #pto.vmi.layout>, + %dst: !pto.ptr, %offset: index) { + pto.vmi.store %value, %dst[%offset] + : !pto.vmi.vreg<129xf32, #pto.vmi.layout>, !pto.ptr + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.store requires an 8/16/32-bit predicate-maskable element type +// CHECK-SAME: requires every deinterleaved part to have the same physical chunk count diff --git a/test/lit/vmi/vmi_to_vpto_store_deint_tail.pto b/test/lit/vmi/vmi_to_vpto_store_deint_tail.pto new file mode 100644 index 0000000000..653d9b6f33 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_store_deint_tail.pto @@ -0,0 +1,35 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_store_deint_tail( + %value: !pto.vmi.vreg<4xf32, #pto.vmi.layout>, + %dst: !pto.ptr, + %offset: index) { + pto.vmi.store %value, %dst[%offset] + : !pto.vmi.vreg<4xf32, #pto.vmi.layout>, + !pto.ptr + return + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_store_deint_tail( +// CHECK-SAME: %[[P0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[P1:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[DST:[^,]+]]: !pto.ptr +// CHECK-SAME: %[[OFF:[^)]+]]: index +// CHECK: %[[C4:.*]] = arith.constant 4 : i32 +// CHECK: %[[LOW:.*]], %[[HIGH:.*]] = pto.vintlv %[[P0]], %[[P1]] +// CHECK: %[[MASK:.*]], %{{.*}} = pto.plt_b32 %[[C4]] : i32 -> !pto.mask, i32 +// CHECK: pto.vsts %[[LOW]], %[[DST]][%[[OFF]]], %[[MASK]] +// CHECK-NOT: pto.vsts %[[HIGH]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_store_tail.pto b/test/lit/vmi/vmi_to_vpto_store_tail.pto new file mode 100644 index 0000000000..34058b925c --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_store_tail.pto @@ -0,0 +1,29 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_store_tail( + %value: !pto.vmi.vreg<100xf32, #pto.vmi.layout>, + %dst: !pto.ptr, %offset: index) { + pto.vmi.store %value, %dst[%offset] + : !pto.vmi.vreg<100xf32, #pto.vmi.layout>, !pto.ptr + return + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_store_tail( +// CHECK: %[[C36:.*]] = arith.constant 36 : i32 +// CHECK: %[[FULL_MASK:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: pto.vsts %arg0, %arg2[%arg3], %[[FULL_MASK]] +// CHECK: %[[TAIL_MASK:.*]], %{{.*}} = pto.plt_b32 %[[C36]] : i32 -> !pto.mask, i32 +// CHECK: pto.vsts %arg1, %arg2[{{.*}}], %[[TAIL_MASK]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_store_width_invalid.pto b/test/lit/vmi/vmi_to_vpto_store_width_invalid.pto new file mode 100644 index 0000000000..34e75012b2 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_store_width_invalid.pto @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -split-input-file -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_store_f64_unsupported( + %value: !pto.vmi.vreg<32xf64, #pto.vmi.layout>, + %dst: memref<32xf64>, + %offset: index) { + pto.vmi.store %value, %dst[%offset] + : !pto.vmi.vreg<32xf64, #pto.vmi.layout>, memref<32xf64> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.store requires an 8/16/32-bit predicate-maskable element type +// CHECK-SAME: requires an 8/16/32-bit element type diff --git a/test/lit/vmi/vmi_to_vpto_stride_load.pto b/test/lit/vmi/vmi_to_vpto_stride_load.pto new file mode 100644 index 0000000000..d30ce58a1f --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_stride_load.pto @@ -0,0 +1,35 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_stride_load( + %src: !pto.ptr, + %offset: index, + %mask: !pto.vmi.mask<64xb8, #pto.vmi.layout>) + -> !pto.vreg<256xf8E4M3FN> { + %c1 = arith.constant 1 : i16 + %out = pto.vmi.stride_load %src[%offset], %c1, %c1, %mask + : !pto.ptr, i16, i16, + !pto.vmi.mask<64xb8, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf8E4M3FN, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<64xf8E4M3FN, #pto.vmi.layout>) + -> !pto.vreg<256xf8E4M3FN> + return %part : !pto.vreg<256xf8E4M3FN> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_stride_load( +// CHECK: %[[BASE:.*]] = pto.addptr %arg0, %arg1 : -> +// CHECK: %[[LOAD:.*]] = pto.vsldb %[[BASE]], %c1{{[^,]*}}, %c1{{[^,]*}}, %arg2 : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<256xf8E4M3FN> +// CHECK: return %[[LOAD]] : !pto.vreg<256xf8E4M3FN> +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_stride_store.pto b/test/lit/vmi/vmi_to_vpto_stride_store.pto new file mode 100644 index 0000000000..f581ed1bc7 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_stride_store.pto @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_stride_store( + %value: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %dst: !pto.ptr, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) { + %c0 = arith.constant 0 : index + %c2 = arith.constant 2 : i16 + %c4 = arith.constant 4 : i16 + pto.vmi.stride_store %value, %dst[%c0], %c2, %c4, %mask + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.ptr, i16, i16, + !pto.vmi.mask<64xb32, #pto.vmi.layout> + return + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_stride_store( +// CHECK: %[[BASE:.*]] = pto.addptr %arg1, %c0 : -> +// CHECK: %{{.*}} = pto.vsstb %arg0, %[[BASE]], %c2{{[^,]*}}, %c4{{[^,]*}}, %arg2 : !pto.vreg<64xf32>, !pto.ptr, i16, i16, !pto.mask -> !pto.ptr +// CHECK: return +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_sub_mul.pto b/test/lit/vmi/vmi_to_vpto_sub_mul.pto new file mode 100644 index 0000000000..d76a6bfd3c --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_sub_mul.pto @@ -0,0 +1,60 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_subf_mulf( + %a: !pto.vmi.vreg<128xf16>, + %b: !pto.vmi.vreg<128xf16>) + -> (!pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32>) { + %wa = pto.vmi.extf %a + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %wb = pto.vmi.extf %b + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %diff = pto.vmi.subf %wa, %wb + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + %prod = pto.vmi.mulf %wa, %wb + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return %diff, %prod : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + } + + func.func @vmi_to_vpto_subi_muli( + %a: !pto.vmi.vreg<128xi32>, + %b: !pto.vmi.vreg<128xi32>) + -> (!pto.vmi.vreg<128xi32>, !pto.vmi.vreg<128xi32>) { + %diff = pto.vmi.subi %a, %b + : !pto.vmi.vreg<128xi32>, !pto.vmi.vreg<128xi32> + -> !pto.vmi.vreg<128xi32> + %prod = pto.vmi.muli %a, %b + : !pto.vmi.vreg<128xi32>, !pto.vmi.vreg<128xi32> + -> !pto.vmi.vreg<128xi32> + return %diff, %prod : !pto.vmi.vreg<128xi32>, !pto.vmi.vreg<128xi32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_subf_mulf( +// CHECK-SAME: -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) +// CHECK-DAG: %[[SUB0:.*]] = pto.vsub {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-DAG: %[[SUB1:.*]] = pto.vsub {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-DAG: %[[MUL0:.*]] = pto.vmul {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-DAG: %[[MUL1:.*]] = pto.vmul {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[SUB0]], %[[SUB1]], %[[MUL0]], %[[MUL1]] + +// CHECK-LABEL: func.func @vmi_to_vpto_subi_muli( +// CHECK-SAME: -> (!pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.vreg<64xi32>) +// CHECK-DAG: %[[ISUB0:.*]] 
= pto.vsub {{.*}} : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> +// CHECK-DAG: %[[ISUB1:.*]] = pto.vsub {{.*}} : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> +// CHECK-DAG: %[[IMUL0:.*]] = pto.vmul {{.*}} : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> +// CHECK-DAG: %[[IMUL1:.*]] = pto.vmul {{.*}} : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> +// CHECK: return %[[ISUB0]], %[[ISUB1]], %[[IMUL0]], %[[IMUL1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_truncf.pto b/test/lit/vmi/vmi_to_vpto_truncf.pto new file mode 100644 index 0000000000..edac1ec223 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_truncf.pto @@ -0,0 +1,80 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_truncf_f32_to_f16( + %even: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %odd: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> !pto.vreg<128xf16> { + %wide = pto.vmi.addf %even, %odd + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %narrow = pto.vmi.truncf %wide + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> + %p = "pto.vmi.unpack"(%narrow) + : (!pto.vmi.vreg<128xf16, #pto.vmi.layout>) + -> !pto.vreg<128xf16> + return %p : !pto.vreg<128xf16> + } + + func.func @vmi_to_vpto_truncf_f32_tail_to_f16( + %wide: !pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> !pto.vreg<128xf16> { + %narrow = pto.vmi.truncf %wide + : !pto.vmi.vreg<100xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<100xf16, #pto.vmi.layout> + %p = "pto.vmi.unpack"(%narrow) + : (!pto.vmi.vreg<100xf16, #pto.vmi.layout>) + -> !pto.vreg<128xf16> + return %p : !pto.vreg<128xf16> + } + + func.func @vmi_to_vpto_truncf_f32_to_f16_multichunk( + %wide: !pto.vmi.vreg<256xf32, #pto.vmi.layout>) + -> (!pto.vreg<128xf16>, !pto.vreg<128xf16>) { + %narrow = pto.vmi.truncf %wide + : !pto.vmi.vreg<256xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf16, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%narrow) + : (!pto.vmi.vreg<256xf16, #pto.vmi.layout>) + -> (!pto.vreg<128xf16>, !pto.vreg<128xf16>) + return %p0, %p1 : !pto.vreg<128xf16>, !pto.vreg<128xf16> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_truncf_f32_to_f16( +// CHECK: %[[EVEN:.*]] = pto.vcvt {{.*}}, {{.*}} {part = "EVEN", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +// CHECK: %[[ODD:.*]] = pto.vcvt {{.*}}, {{.*}} {part = "ODD", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +// CHECK: pto.vor %[[EVEN]], %[[ODD]], {{.*}} : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> +// 
CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_truncf_f32_tail_to_f16( +// CHECK: %[[EVEN:.*]] = pto.vcvt {{.*}}, {{.*}} {part = "EVEN", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +// CHECK: %[[ODD:.*]] = pto.vcvt {{.*}}, {{.*}} {part = "ODD", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +// CHECK: pto.vor %[[EVEN]], %[[ODD]], {{.*}} : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_truncf_f32_to_f16_multichunk( +// CHECK: %[[EVEN0:.*]] = pto.vcvt {{.*}}, {{.*}} {part = "EVEN", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +// CHECK: %[[ODD0:.*]] = pto.vcvt {{.*}}, {{.*}} {part = "ODD", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +// CHECK: %[[OUT0:.*]] = pto.vor %[[EVEN0]], %[[ODD0]] +// CHECK: %[[EVEN1:.*]] = pto.vcvt {{.*}}, {{.*}} {part = "EVEN", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +// CHECK: %[[ODD1:.*]] = pto.vcvt {{.*}}, {{.*}} {part = "ODD", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +// CHECK: %[[OUT1:.*]] = pto.vor %[[EVEN1]], %[[ODD1]] +// CHECK: return %[[OUT0]], %[[OUT1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_truncf_fp8_128_lane_stride.pto b/test/lit/vmi/vmi_to_vpto_truncf_fp8_128_lane_stride.pto new file mode 100644 index 0000000000..9ad0755cc4 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_truncf_fp8_128_lane_stride.pto @@ -0,0 +1,42 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_to_vpto_truncf_fp8_128_lane_stride( + %input: !pto.vmi.vreg<128xf32>, + %dst: !pto.ptr, + %off: index) { + %packed = pto.vmi.truncf %input + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf8E4M3FN> + pto.vmi.store %packed, %dst[%off] + : !pto.vmi.vreg<128xf8E4M3FN>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_to_vpto_truncf_fp8_128_lane_stride( +// ASSIGN-SAME: %[[INPUT:.*]]: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[D2:.*]] = pto.vmi.ensure_layout %[[INPUT]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[PACKED:.*]] = pto.vmi.truncf %[[D2]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf8E4M3FN, #pto.vmi.layout> +// ASSIGN: pto.vmi.store %[[PACKED]] + +// LOWER-LABEL: func.func @vmi_to_vpto_truncf_fp8_128_lane_stride( +// LOWER: pto.vcvt {{.*}} {part = "P0", rnd = "R", sat = "SAT"} +// LOWER: pto.vcvt {{.*}} {part = "P2", rnd = "R", sat = "SAT"} +// LOWER-NOT: part = "P1" +// LOWER-NOT: part = "P3" +// LOWER: pto.vor +// LOWER: pto.vsts +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. 
+// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_truncf_unsupported_shape_invalid.pto b/test/lit/vmi/vmi_to_vpto_truncf_unsupported_shape_invalid.pto new file mode 100644 index 0000000000..324b01ea5b --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_truncf_unsupported_shape_invalid.pto @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_truncf_unsupported_shape_invalid( + %input: !pto.vmi.vreg<256xf32, #pto.vmi.layout>) { + %narrow = pto.vmi.truncf %input + : !pto.vmi.vreg<256xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf16, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.truncf supports only f32 deinterleaved=2 source parts +// CHECK-SAME: dense f16 results +// CHECK-SAME: f32 source layouts whose factor times the result lane_stride matches the fp8-like narrowing factor diff --git a/test/lit/vmi/vmi_to_vpto_trunci_i8_signed_invalid.pto b/test/lit/vmi/vmi_to_vpto_trunci_i8_signed_invalid.pto new file mode 100644 index 0000000000..c87af13167 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_trunci_i8_signed_invalid.pto @@ -0,0 +1,29 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_trunci_i32_to_i8_invalid( + %wide: !pto.vmi.vreg<256xi32, #pto.vmi.layout>) + -> !pto.vreg<256xi8> { + %narrow = pto.vmi.trunci %wide + : !pto.vmi.vreg<256xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xi8, #pto.vmi.layout> + %p = "pto.vmi.unpack"(%narrow) + : (!pto.vmi.vreg<256xi8, #pto.vmi.layout>) + -> !pto.vreg<256xi8> + return %p : !pto.vreg<256xi8> + } +} + +// CHECK: VMI-UNSUPPORTED +// CHECK: pto.vmi.trunci supports integer deinterleaved source layouts +// CHECK: 8-bit integer narrowing requires unsigned i8 result diff --git a/test/lit/vmi/vmi_to_vpto_type_arity.pto b/test/lit/vmi/vmi_to_vpto_type_arity.pto new file mode 100644 index 0000000000..e99e8e9ea0 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_type_arity.pto @@ -0,0 +1,63 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_type_arity_contiguous_partial( + %value: !pto.vmi.vreg<130xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<130xb32, #pto.vmi.layout>) { + return + } + + func.func @vmi_to_vpto_type_arity_deint4( + %value: !pto.vmi.vreg<256xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<256xb32, #pto.vmi.layout>) { + return + } + + func.func @vmi_to_vpto_type_arity_deint2_partial( + %value: !pto.vmi.vreg<130xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<130xb32, #pto.vmi.layout>) { + return + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_type_arity_contiguous_partial( +// CHECK-SAME: !pto.vreg<64xf32> +// CHECK-SAME: !pto.vreg<64xf32> +// CHECK-SAME: !pto.vreg<64xf32> +// CHECK-SAME: !pto.mask +// CHECK-SAME: !pto.mask +// CHECK-SAME: !pto.mask +// CHECK: return + +// CHECK-LABEL: func.func @vmi_to_vpto_type_arity_deint4( +// CHECK-SAME: !pto.vreg<64xf32> +// CHECK-SAME: !pto.vreg<64xf32> +// CHECK-SAME: !pto.vreg<64xf32> +// CHECK-SAME: !pto.vreg<64xf32> +// CHECK-SAME: !pto.mask +// CHECK-SAME: !pto.mask +// CHECK-SAME: !pto.mask +// CHECK-SAME: !pto.mask +// CHECK: return + +// CHECK-LABEL: func.func @vmi_to_vpto_type_arity_deint2_partial( +// CHECK-SAME: !pto.vreg<64xf32> +// CHECK-SAME: !pto.vreg<64xf32> +// CHECK-SAME: !pto.vreg<64xf32> +// CHECK-SAME: !pto.vreg<64xf32> +// CHECK-SAME: !pto.mask +// CHECK-SAME: !pto.mask +// CHECK-SAME: !pto.mask +// CHECK-SAME: !pto.mask +// CHECK: return +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast +// CHECK-NOT: pto.vmi. 
diff --git a/test/lit/vmi/vmi_to_vpto_type_attr_nested_residual_invalid.pto b/test/lit/vmi/vmi_to_vpto_type_attr_nested_residual_invalid.pto new file mode 100644 index 0000000000..afc8502caf --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_type_attr_nested_residual_invalid.pto @@ -0,0 +1,16 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module attributes { + pto.hidden_vmi_type = [{nested = !pto.vmi.vreg<128xf32, #pto.vmi.layout>}] +} { +} + +// CHECK: VMI-RESIDUAL-OP: failed to convert all VMI ops/types to VPTO diff --git a/test/lit/vmi/vmi_to_vpto_type_attr_residual_invalid.pto b/test/lit/vmi/vmi_to_vpto_type_attr_residual_invalid.pto new file mode 100644 index 0000000000..c115c1c3d8 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_type_attr_residual_invalid.pto @@ -0,0 +1,16 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module attributes { + pto.hidden_vmi_type = !pto.vmi.vreg<128xf32, #pto.vmi.layout> +} { +} + +// CHECK: VMI-RESIDUAL-OP: failed to convert all VMI ops/types to VPTO diff --git a/test/lit/vmi/vmi_to_vpto_type_only.pto b/test/lit/vmi/vmi_to_vpto_type_only.pto new file mode 100644 index 0000000000..777afaf124 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_type_only.pto @@ -0,0 +1,27 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_type_only( + %a: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %m: !pto.vmi.mask<128xb32, #pto.vmi.layout>) { + return + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_type_only( +// CHECK-SAME: !pto.vreg<64xf32> +// CHECK-SAME: !pto.vreg<64xf32> +// CHECK-SAME: !pto.mask +// CHECK-SAME: !pto.mask +// CHECK: return +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast +// CHECK-NOT: pto.vmi. 
diff --git a/test/lit/vmi/vmi_to_vpto_unary_math.pto b/test/lit/vmi/vmi_to_vpto_unary_math.pto new file mode 100644 index 0000000000..5a4419bad2 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_unary_math.pto @@ -0,0 +1,89 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_unary_math( + %value: !pto.vmi.vreg<128xf32>) + -> (!pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32>, + !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32>, + !pto.vmi.vreg<128xf32>) { + %neg = pto.vmi.negf %value + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + %sqrt = pto.vmi.sqrt %value + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + %exp = pto.vmi.exp %value + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + %ln = pto.vmi.ln %value + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + %relu = pto.vmi.relu %value + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + return %neg, %sqrt, %exp, %ln, %relu + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32>, + !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32>, + !pto.vmi.vreg<128xf32> + } + + func.func @vmi_to_vpto_absf( + %value: !pto.vmi.vreg<128xf32>) -> !pto.vmi.vreg<128xf32> { + %abs = pto.vmi.absf %value + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + return %abs : !pto.vmi.vreg<128xf32> + } + + func.func @vmi_to_vpto_absi( + 
%value: !pto.vmi.vreg<64xi32>) -> !pto.vmi.vreg<64xi32> { + %abs = pto.vmi.absi %value + : !pto.vmi.vreg<64xi32> -> !pto.vmi.vreg<64xi32> + return %abs : !pto.vmi.vreg<64xi32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_unary_math( +// CHECK-SAME: %[[V0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[V1:[^)]+]]: !pto.vreg<64xf32> +// CHECK-SAME: -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, +// CHECK-SAME: !pto.vreg<64xf32>, !pto.vreg<64xf32>, +// CHECK-SAME: !pto.vreg<64xf32>, !pto.vreg<64xf32>, +// CHECK-SAME: !pto.vreg<64xf32>, !pto.vreg<64xf32>, +// CHECK-SAME: !pto.vreg<64xf32>, !pto.vreg<64xf32>) +// CHECK-DAG: %[[NEG0:.*]] = pto.vneg %[[V0]], {{.*}} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-DAG: %[[NEG1:.*]] = pto.vneg %[[V1]], {{.*}} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-DAG: %[[SQRT0:.*]] = pto.vsqrt %[[V0]], {{.*}} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-DAG: %[[SQRT1:.*]] = pto.vsqrt %[[V1]], {{.*}} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-DAG: %[[EXP0:.*]] = pto.vexp %[[V0]], {{.*}} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-DAG: %[[EXP1:.*]] = pto.vexp %[[V1]], {{.*}} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-DAG: %[[LN0:.*]] = pto.vln %[[V0]], {{.*}} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-DAG: %[[LN1:.*]] = pto.vln %[[V1]], {{.*}} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-DAG: %[[RELU0:.*]] = pto.vrelu %[[V0]], {{.*}} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-DAG: %[[RELU1:.*]] = pto.vrelu %[[V1]], {{.*}} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[NEG0]], %[[NEG1]], %[[SQRT0]], %[[SQRT1]], %[[EXP0]], %[[EXP1]], %[[LN0]], %[[LN1]], %[[RELU0]], %[[RELU1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_absf( +// CHECK-SAME: %[[F0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[F1:[^)]+]]: !pto.vreg<64xf32> +// CHECK-SAME: -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) +// CHECK-DAG: %[[ABSF0:.*]] = pto.vabs %[[F0]], {{.*}} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-DAG: %[[ABSF1:.*]] = pto.vabs %[[F1]], {{.*}} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[ABSF0]], %[[ABSF1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_absi( +// CHECK-SAME: %[[I0:[^)]+]]: !pto.vreg<64xi32> +// CHECK-SAME: -> !pto.vreg<64xi32> +// CHECK: %[[ABSI:.*]] = pto.vabs %[[I0]], {{.*}} : !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> +// CHECK: return %[[ABSI]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_unrealized_cast_residual_invalid.pto b/test/lit/vmi/vmi_to_vpto_unrealized_cast_residual_invalid.pto new file mode 100644 index 0000000000..9bf8f25949 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_unrealized_cast_residual_invalid.pto @@ -0,0 +1,20 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_unrealized_cast_residual_invalid( + %arg0: i32) -> f32 { + %0 = builtin.unrealized_conversion_cast %arg0 + : i32 to f32 + return %0 : f32 + } +} + +// CHECK: VMI-RESIDUAL-OP: unrealized conversion cast remains after vmi-to-vpto diff --git a/test/lit/vmi/vmi_to_vpto_unsupported_op_invalid.pto b/test/lit/vmi/vmi_to_vpto_unsupported_op_invalid.pto new file mode 100644 index 0000000000..df51608b08 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_unsupported_op_invalid.pto @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_unsupported_op_invalid( + %a: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) { + %shuffled = "pto.vmi.shuffle"(%a) { + indices = array + } : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<4xf32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.shuffle requires physical chunk forwarding or lane0 splat or vci-materializable vselr indices +// CHECK-SAME: forwarding: +// CHECK-SAME: lane0 splat: +// CHECK-SAME: vselr: diff --git a/test/lit/vmi/vmi_truncf_direction_invalid.pto b/test/lit/vmi/vmi_truncf_direction_invalid.pto new file mode 100644 index 0000000000..934f1e4ba3 --- /dev/null +++ b/test/lit/vmi/vmi_truncf_direction_invalid.pto @@ -0,0 +1,19 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_truncf_direction_invalid(%source: !pto.vmi.vreg<128xf16>) { + %result = pto.vmi.truncf %source + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK: requires result element type to be narrower than source element type diff --git a/test/lit/vmi/vmi_truncf_lane_mismatch_invalid.pto b/test/lit/vmi/vmi_truncf_lane_mismatch_invalid.pto new file mode 100644 index 0000000000..56e07a9892 --- /dev/null +++ b/test/lit/vmi/vmi_truncf_lane_mismatch_invalid.pto @@ -0,0 +1,19 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_truncf_lane_mismatch_invalid(%source: !pto.vmi.vreg<64xf32>) { + %result = pto.vmi.truncf %source + : !pto.vmi.vreg<64xf32> -> !pto.vmi.vreg<128xf16> + return + } +} + +// CHECK: requires source and result logical lane counts to match diff --git a/test/lit/vmi/vmi_truncf_rounding_token_invalid.pto b/test/lit/vmi/vmi_truncf_rounding_token_invalid.pto new file mode 100644 index 0000000000..cc191cacea --- /dev/null +++ b/test/lit/vmi/vmi_truncf_rounding_token_invalid.pto @@ -0,0 +1,20 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_truncf_rounding_token_invalid( + %source: !pto.vmi.vreg<256xf32>) { + %result = pto.vmi.truncf %source {rounding = "R"} + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256x!pto.hif8> + return + } +} + +// CHECK: rounding attr must be A or H diff --git a/test/lit/vmi/vmi_truncf_rounding_unsupported_invalid.pto b/test/lit/vmi/vmi_truncf_rounding_unsupported_invalid.pto new file mode 100644 index 0000000000..0847c26a10 --- /dev/null +++ b/test/lit/vmi/vmi_truncf_rounding_unsupported_invalid.pto @@ -0,0 +1,20 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_truncf_rounding_unsupported_invalid( + %source: !pto.vmi.vreg<128xf32>) { + %result = pto.vmi.truncf %source {rounding = "H"} + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> + return + } +} + +// CHECK: rounding attr is currently only supported for f32 to !pto.hif8 truncf diff --git a/test/lit/vmi/vmi_type_attr_parse.pto b/test/lit/vmi/vmi_type_attr_parse.pto new file mode 100644 index 0000000000..41c795bcdb --- /dev/null +++ b/test/lit/vmi/vmi_type_attr_parse.pto @@ -0,0 +1,56 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module attributes { + pto.vmi_contiguous = #pto.vmi.layout, + pto.vmi_deinterleaved2 = #pto.vmi.layout, + pto.vmi_deinterleaved4 = #pto.vmi.layout, + pto.vmi_deinterleaved4_block8 = + #pto.vmi.layout, + pto.vmi_group_slots8 = #pto.vmi.layout, + pto.vmi_group_slots_partial = + #pto.vmi.layout +} { + func.func @vmi_type_attr_parse( + %surface: !pto.vmi.vreg<128xf32>, + %contiguous: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %wide2: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %wide4: !pto.vmi.vreg<256xf32, #pto.vmi.layout>, + %wide4_block8: !pto.vmi.vreg<256xf32, #pto.vmi.layout>, + %group_slots8: !pto.vmi.vreg<8xf32, #pto.vmi.layout>, + %group_slots_partial: !pto.vmi.vreg<10xf32, #pto.vmi.layout>, + %surface_mask: !pto.vmi.mask<128xpred>, + %mask_b8: !pto.vmi.mask<128xb8, #pto.vmi.layout>, + %mask_b16: !pto.vmi.mask<128xb16, #pto.vmi.layout>, + %mask_b32: !pto.vmi.mask<128xb32, #pto.vmi.layout>, + %mask_b32_block8: !pto.vmi.mask<128xb32, #pto.vmi.layout>) { + return + } +} + +// CHECK: pto.vmi_contiguous = #pto.vmi.layout +// CHECK: pto.vmi_deinterleaved2 = #pto.vmi.layout +// CHECK: pto.vmi_deinterleaved4 = #pto.vmi.layout +// CHECK: pto.vmi_deinterleaved4_block8 = #pto.vmi.layout +// CHECK: pto.vmi_group_slots8 = #pto.vmi.layout +// CHECK: pto.vmi_group_slots_partial = #pto.vmi.layout +// CHECK-LABEL: func.func @vmi_type_attr_parse( +// CHECK-SAME: %{{.*}}: !pto.vmi.vreg<128xf32> +// CHECK-SAME: %{{.*}}: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK-SAME: %{{.*}}: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK-SAME: %{{.*}}: !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// CHECK-SAME: %{{.*}}: !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// CHECK-SAME: %{{.*}}: !pto.vmi.vreg<8xf32, #pto.vmi.layout> +// CHECK-SAME: %{{.*}}: !pto.vmi.vreg<10xf32, #pto.vmi.layout> +// CHECK-SAME: %{{.*}}: !pto.vmi.mask<128xpred> +// CHECK-SAME: %{{.*}}: !pto.vmi.mask<128xb8, #pto.vmi.layout> +// CHECK-SAME: %{{.*}}: 
!pto.vmi.mask<128xb16, #pto.vmi.layout> +// CHECK-SAME: %{{.*}}: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK-SAME: %{{.*}}: !pto.vmi.mask<128xb32, #pto.vmi.layout> diff --git a/test/lit/vmi/vmi_type_element_count_invalid.pto b/test/lit/vmi/vmi_type_element_count_invalid.pto new file mode 100644 index 0000000000..a7548528c9 --- /dev/null +++ b/test/lit/vmi/vmi_type_element_count_invalid.pto @@ -0,0 +1,18 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_type_element_count_invalid( + %arg0: !pto.vmi.vreg<0xf32>) { + return + } +} + +// CHECK: expected a positive element count diff --git a/test/lit/vmi/vmi_unary_math_integer_invalid.pto b/test/lit/vmi/vmi_unary_math_integer_invalid.pto new file mode 100644 index 0000000000..8f3af3092e --- /dev/null +++ b/test/lit/vmi/vmi_unary_math_integer_invalid.pto @@ -0,0 +1,55 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -split-input-file 2>&1 | FileCheck %s + +module { + func.func @vmi_sqrt_integer_invalid(%value: !pto.vmi.vreg<128xi32>) { + %sqrt = pto.vmi.sqrt %value + : !pto.vmi.vreg<128xi32> -> !pto.vmi.vreg<128xi32> + return + } +} + +// CHECK: 'pto.vmi.sqrt' op requires floating-point-like VMI element type + +// ----- + +module { + func.func @vmi_exp_integer_invalid(%value: !pto.vmi.vreg<128xi32>) { + %exp = pto.vmi.exp %value + : !pto.vmi.vreg<128xi32> -> !pto.vmi.vreg<128xi32> + return + } +} + +// CHECK: 'pto.vmi.exp' op requires floating-point-like VMI element type + +// ----- + +module { + func.func @vmi_ln_integer_invalid(%value: !pto.vmi.vreg<128xi32>) { + %ln = pto.vmi.ln %value + : !pto.vmi.vreg<128xi32> -> !pto.vmi.vreg<128xi32> + return + } +} + +// CHECK: 'pto.vmi.ln' op requires floating-point-like VMI element type + +// ----- + +module { + func.func @vmi_relu_integer_invalid(%value: !pto.vmi.vreg<128xi32>) { + %relu = pto.vmi.relu %value + : !pto.vmi.vreg<128xi32> -> !pto.vmi.vreg<128xi32> + return + } +} + +// CHECK: 'pto.vmi.relu' op requires floating-point-like VMI element type diff --git a/test/lit/vmi/vmi_unpack_arity_invalid.pto b/test/lit/vmi/vmi_unpack_arity_invalid.pto new file mode 100644 index 0000000000..5cd224a6e6 --- /dev/null +++ b/test/lit/vmi/vmi_unpack_arity_invalid.pto @@ -0,0 +1,20 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. 
You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_unpack_arity_invalid( + %a: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) { + %p0 = "pto.vmi.unpack"(%a) + : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) -> !pto.vreg<64xf32> + return + } +} + +// CHECK: requires 2 physical parts, got 1 diff --git a/test/lit/vpto/arith_select_vpto_llvm.pto b/test/lit/vpto/arith_select_vpto_llvm.pto new file mode 100644 index 0000000000..b32a7fe0de --- /dev/null +++ b/test/lit/vpto/arith_select_vpto_llvm.pto @@ -0,0 +1,54 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: ( mkdir -p %T && ptoas --pto-arch=a5 --pto-backend=vpto %s -o %t --mlir-print-ir-after=convert-func-to-llvm 2>&1 || true ) | FileCheck %s + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @arith_select_vreg(%cond: i1, %lhs_scalar: f32, %rhs_scalar: f32, + %dst: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %lhs = pto.vdup %lhs_scalar, %mask + : f32, !pto.mask -> !pto.vreg<64xf32> + %rhs = pto.vdup %rhs_scalar, %mask + : f32, !pto.mask -> !pto.vreg<64xf32> + %chosen = arith.select %cond, %lhs, %rhs : !pto.vreg<64xf32> + pto.vsts %chosen, %dst[%c0], %mask + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } + return + } + + func.func @arith_select_mask(%cond: i1, %value: f32, + %dst: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + pto.vecscope { + %all = pto.pset_b32 "PAT_ALL" : !pto.mask + %tail = pto.pge_b32 "PAT_VL4" : !pto.mask + %chosen_mask = arith.select %cond, %all, %tail : !pto.mask + %vec = pto.vdup %value, %all + : f32, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %vec, %dst[%c0], %chosen_mask + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } + return + } +} + +// CHECK-LABEL: llvm.func @arith_select_vreg_mix_aiv +// CHECK: %[[LHS:.*]] = llvm.call @llvm.hivm.vdups{{.*}} +// CHECK: %[[RHS:.*]] = llvm.call @llvm.hivm.vdups{{.*}} +// CHECK: %[[CHOSEN:.*]] = llvm.select %arg0, %[[LHS]], %[[RHS]] : i1, vector<64xf32> +// CHECK: llvm.call @llvm.hivm.vstsx1.v64f32(%[[CHOSEN]] + +// CHECK-LABEL: llvm.func @arith_select_mask_mix_aiv +// CHECK: %[[ALL:.*]] = llvm.call @llvm.hivm.pset.b32 +// CHECK: %[[TAIL:.*]] = llvm.call @llvm.hivm.pge.b32 +// CHECK: %[[CHOSEN_MASK:.*]] = llvm.select %arg0, %[[ALL]], %[[TAIL]] : i1, vector<256xi1> +// CHECK: llvm.call @llvm.hivm.vstsx1.v64f32({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[CHOSEN_MASK]]) diff --git a/test/lit/vpto/vgather2_u16_vpto_llvm.pto 
b/test/lit/vpto/vgather2_u16_vpto_llvm.pto new file mode 100644 index 0000000000..d2e8f983de --- /dev/null +++ b/test/lit/vpto/vgather2_u16_vpto_llvm.pto @@ -0,0 +1,30 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ( mkdir -p %T && ptoas --pto-arch=a5 --pto-backend=vpto %s -o %t --mlir-print-ir-after=convert-func-to-llvm 2>&1 || true ) | FileCheck %s + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vgather2_u16(%src: !pto.ptr, + %idx: !pto.vreg<128xui16>, + %dst: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + %out = pto.vgather2 %src, %idx, %mask + : !pto.ptr, !pto.vreg<128xui16>, !pto.mask + -> !pto.vreg<128xui16> + pto.vsts %out, %dst[%c0], %mask + : !pto.vreg<128xui16>, !pto.ptr, !pto.mask + } + return + } +} + +// CHECK-LABEL: llvm.func @vgather2_u16_mix_aiv +// CHECK: %[[IDX:.*]] = llvm.bitcast %arg1 : vector<128xi16> to vector<64xi32> +// CHECK: llvm.call @llvm.hivm.vgather2.v300.v128u16(%arg0, %[[IDX]], diff --git a/test/lit/vpto/vmi_fp4_e1_packed_surface_verify_invalid.pto b/test/lit/vpto/vmi_fp4_e1_packed_surface_verify_invalid.pto new file mode 100644 index 0000000000..65cf7e6223 --- /dev/null +++ b/test/lit/vpto/vmi_fp4_e1_packed_surface_verify_invalid.pto @@ -0,0 +1,19 @@ +// Copyright (c) 2026 Huawei Technologies Co., 
Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --pto-arch=a5 --pto-backend=vpto --enable-vmi --emit-vpto %s -o - 2>&1 | FileCheck %s + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_fp4_e1_packed_surface_invalid( + %arg0: !pto.vmi.vreg<256x!pto.f4E1M2x2>) attributes {pto.kernel} { + return + } +} + +// CHECK: error: '!pto.vmi.vreg<256x!pto.f4E1M2x2>' uses a packed FP4 physical pair type as a VMI logical element type +// CHECK-SAME: packed FP4 input/output is not a supported VMI surface diff --git a/test/lit/vpto/vmi_fp4_packed_surface_verify_invalid.pto b/test/lit/vpto/vmi_fp4_packed_surface_verify_invalid.pto new file mode 100644 index 0000000000..18e8f6fd30 --- /dev/null +++ b/test/lit/vpto/vmi_fp4_packed_surface_verify_invalid.pto @@ -0,0 +1,19 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --pto-arch=a5 --pto-backend=vpto --enable-vmi --emit-vpto %s -o - 2>&1 | FileCheck %s + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_fp4_packed_surface_invalid( + %arg0: !pto.vmi.vreg<256x!pto.f4E2M1x2>) attributes {pto.kernel} { + return + } +} + +// CHECK: error: '!pto.vmi.vreg<256x!pto.f4E2M1x2>' uses a packed FP4 physical pair type as a VMI logical element type +// CHECK-SAME: packed FP4 input/output is not a supported VMI surface diff --git a/test/lit/vpto/vmi_sitofp.pto b/test/lit/vpto/vmi_sitofp.pto new file mode 100644 index 0000000000..fca0e63698 --- /dev/null +++ b/test/lit/vpto/vmi_sitofp.pto @@ -0,0 +1,50 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-vmi --emit-vpto %s -o - 2>/dev/null | FileCheck %s + +// CHECK-LABEL: func.func @vmi_sitofp_kernel +// CHECK: pto.vcvt {{.*}} {rnd = "R"} : !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xf32> + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_sitofp_kernel(%src_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c512_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %x = pto.vmi.load %ub_src[%c0] + : !pto.ptr -> !pto.vmi.vreg<64xi32> + %f = pto.vmi.sitofp %x + : !pto.vmi.vreg<64xi32> -> !pto.vmi.vreg<64xf32> + pto.vmi.store %f, %ub_dst[%c0] + : !pto.vmi.vreg<64xf32>, !pto.ptr + } + +
pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/lit/vpto/vmi_truncf_hif8.pto b/test/lit/vpto/vmi_truncf_hif8.pto new file mode 100644 index 0000000000..a638759e01 --- /dev/null +++ b/test/lit/vpto/vmi_truncf_hif8.pto @@ -0,0 +1,94 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-vmi --emit-vpto %s -o - 2>/dev/null | FileCheck %s + +// CHECK-LABEL: func.func @vmi_truncf_hif8_default_kernel +// CHECK: pto.vcvt {{.*}} {part = "P0", rnd = "A", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256x!pto.hif8> +// CHECK-NOT: part = "P1" +// CHECK: pto.vsts {{.*}} {dist = "PK4_B32"} : !pto.vreg<256x!pto.hif8>, !pto.ptr, !pto.mask +// CHECK-LABEL: func.func @vmi_truncf_hif8_hybrid_kernel +// CHECK: pto.vcvt {{.*}} {part = "P0", rnd = "H", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256x!pto.hif8> +// CHECK-NOT: part = "P1" +// CHECK: pto.vsts {{.*}} {dist = "PK4_B32"} : !pto.vreg<256x!pto.hif8>, !pto.ptr, !pto.mask + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_truncf_hif8_default_kernel(%src_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c256_i64 = arith.constant 256 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c512_i64 = arith.constant 512 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c512_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %x = pto.vmi.load %ub_src[%c0] + : !pto.ptr -> !pto.vmi.vreg<256xf32> + %h = pto.vmi.truncf %x + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256x!pto.hif8> + pto.vmi.store %h, %ub_dst[%c0] + : !pto.vmi.vreg<256x!pto.hif8>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier 
#pto.pipe + return + } + + func.func @vmi_truncf_hif8_hybrid_kernel(%src_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c256_i64 = arith.constant 256 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c512_i64 = arith.constant 512 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c512_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %x = pto.vmi.load %ub_src[%c0] + : !pto.ptr -> !pto.vmi.vreg<256xf32> + %h = pto.vmi.truncf %x {rounding = "H"} + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256x!pto.hif8> + pto.vmi.store %h, %ub_dst[%c0] + : !pto.vmi.vreg<256x!pto.hif8>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/lit/vpto/vpto_normalize_equivalent_vcvt.pto b/test/lit/vpto/vpto_normalize_equivalent_vcvt.pto new file mode 100644 index 0000000000..4b610b1ec9 --- /dev/null +++ b/test/lit/vpto/vpto_normalize_equivalent_vcvt.pto @@ -0,0 +1,92 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vpto-normalize-equivalent-vcvt -canonicalize -cse | FileCheck %s + +module { + func.func @e2b_load(%src: !pto.ptr, %off: index) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + %load = pto.vlds %src[%off] {dist = "E2B_B16"} + : !pto.ptr -> !pto.vreg<128xf16> + %even = pto.vcvt %load, %mask {part = "EVEN"} + : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> + %odd = pto.vcvt %load, %mask {part = "ODD"} + : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> + return %even, %odd : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + + func.func @brc_load(%src: !pto.ptr, %off: index) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + %load = pto.vlds %src[%off] {dist = "BRC_B16"} + : !pto.ptr -> !pto.vreg<128xf16> + %even = pto.vcvt %load, %mask {part = "EVEN"} + : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> + %odd = pto.vcvt %load, %mask {part = "ODD"} + : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> + return %even, %odd : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + + func.func @scalar_broadcast(%seed: f16) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + %broadcast = pto.vbr %seed : f16 -> !pto.vreg<128xf16> + %even = pto.vcvt %broadcast, %mask {part = "EVEN"} + : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> + %odd = pto.vcvt %broadcast, %mask {part = "ODD"} + : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> + return %even, %odd : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + + func.func @normal_load_is_not_changed(%src: !pto.ptr, %off: index) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %mask = 
pto.pset_b16 "PAT_ALL" : !pto.mask + %load = pto.vlds %src[%off] + : !pto.ptr -> !pto.vreg<128xf16> + %even = pto.vcvt %load, %mask {part = "EVEN"} + : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> + %odd = pto.vcvt %load, %mask {part = "ODD"} + : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> + return %even, %odd : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + + func.func @masked_broadcast_is_not_changed(%seed: f16) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %mask = pto.pset_b16 "PAT_VL1" : !pto.mask + %broadcast = pto.vbr %seed : f16 -> !pto.vreg<128xf16> + %even = pto.vcvt %broadcast, %mask {part = "EVEN"} + : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> + %odd = pto.vcvt %broadcast, %mask {part = "ODD"} + : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> + return %even, %odd : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @e2b_load +// CHECK: %[[CVT:.*]] = pto.vcvt {{.*}} {part = "EVEN"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-NOT: pto.vcvt +// CHECK: return %[[CVT]], %[[CVT]] + +// CHECK-LABEL: func.func @brc_load +// CHECK: %[[CVT:.*]] = pto.vcvt {{.*}} {part = "EVEN"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-NOT: pto.vcvt +// CHECK: return %[[CVT]], %[[CVT]] + +// CHECK-LABEL: func.func @scalar_broadcast +// CHECK: %[[CVT:.*]] = pto.vcvt {{.*}} {part = "EVEN"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-NOT: pto.vcvt +// CHECK: return %[[CVT]], %[[CVT]] + +// CHECK-LABEL: func.func @normal_load_is_not_changed +// CHECK: pto.vcvt {{.*}} {part = "EVEN"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vcvt {{.*}} {part = "ODD"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> + +// CHECK-LABEL: func.func @masked_broadcast_is_not_changed +// CHECK: pto.vcvt {{.*}} {part = "EVEN"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vcvt {{.*}} {part = "ODD"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> diff 
--git a/test/vpto/cases/vmi/broadcast-dense-group-users/compare.py b/test/vpto/cases/vmi/broadcast-dense-group-users/compare.py new file mode 100644 index 0000000000..9f34394fa1 --- /dev/null +++ b/test/vpto/cases/vmi/broadcast-dense-group-users/compare.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def check(name: str, golden_name: str) -> None: + golden = np.fromfile(golden_name, dtype=np.float32) + output = np.fromfile(name, dtype=np.float32) + if golden.shape == output.shape and np.allclose(golden, output, atol=1e-4, rtol=1e-4): + return + if golden.shape != output.shape: + print(f"[ERROR] compare failed {name}: shape golden={golden.shape} output={output.shape}") + sys.exit(2) + diff = np.nonzero(~np.isclose(golden, output, atol=1e-4, rtol=1e-4))[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed {name} idx={idx} " + f"golden={golden[idx] if idx >= 0 else 'n/a'} " + f"output={output[idx] if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + +def main() -> None: + check("v2.bin", "golden_v2.bin") + check("v3.bin", "golden_v3.bin") + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/broadcast-dense-group-users/golden.py b/test/vpto/cases/vmi/broadcast-dense-group-users/golden.py new file mode 100644 index 
0000000000..7df1eedef3 --- /dev/null +++ b/test/vpto/cases/vmi/broadcast-dense-group-users/golden.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 8 +COLS = 32 +SCALE = np.float32(0.5) +SENTINEL = np.float32(-777.0) + + +def generate(output_dir: Path) -> None: + base = np.linspace(-0.875, 0.625, COLS, dtype=np.float32) + src = np.empty((ROWS, COLS), dtype=np.float32) + for row in range(ROWS): + src[row, :] = base + np.float32(row) * np.float32(0.03125) + copy = np.full((ROWS, COLS), SENTINEL, dtype=np.float32) + sums = np.full(ROWS, SENTINEL, dtype=np.float32) + golden_copy = src + SCALE + golden_sum = np.sum(src * SCALE, axis=1, dtype=np.float32).astype(np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + src.reshape(-1).tofile(output_dir / "v1.bin") + copy.reshape(-1).tofile(output_dir / "v2.bin") + sums.tofile(output_dir / "v3.bin") + golden_copy.reshape(-1).astype(np.float32).tofile(output_dir / "golden_v2.bin") + golden_sum.astype(np.float32).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git 
a/test/vpto/cases/vmi/broadcast-dense-group-users/kernel.pto b/test/vpto/cases/vmi/broadcast-dense-group-users/kernel.pto new file mode 100644 index 0000000000..648a98b0a9 --- /dev/null +++ b/test/vpto/cases/vmi/broadcast-dense-group-users/kernel.pto @@ -0,0 +1,68 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_broadcast_dense_group_users_kernel(%src_gm: !pto.ptr, + %copy_gm: !pto.ptr, + %sum_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %scale = arith.constant 5.000000e-01 : f32 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_copy = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_sum = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %scale_vec = pto.vmi.broadcast %scale : f32 -> 
!pto.vmi.vreg<256xf32> + %x = pto.vmi.load %ub_src[%c0] + : !pto.ptr -> !pto.vmi.vreg<256xf32> + %copy = pto.vmi.addf %x, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + pto.vmi.store %copy, %ub_copy[%c0] + : !pto.vmi.vreg<256xf32>, !pto.ptr + + %mask = pto.vmi.create_group_mask %c32 {num_groups = 8, group_size = 32} + : index -> !pto.vmi.mask<256xpred> + %prod = pto.vmi.mulf %x, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %sum = pto.vmi.group_reduce_addf %prod, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<8xf32> + pto.vmi.group_store %sum, %ub_sum[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xf32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_copy, %copy_gm, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_sum, %sum_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/broadcast-dense-group-users/launch.cpp b/test/vpto/cases/vmi/broadcast-dense-group-users/launch.cpp new file mode 100644 index 0000000000..21e26d6cf5 --- /dev/null +++ b/test/vpto/cases/vmi/broadcast-dense-group-users/launch.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_broadcast_dense_group_users_kernel(__gm__ float *src, __gm__ float *copy, + __gm__ float *sum); + +void LaunchVmi_broadcast_dense_group_users_kernel(float *src, float *copy, + float *sum, void *stream) { + vmi_broadcast_dense_group_users_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ float *)copy, (__gm__ float *)sum); +} diff --git a/test/vpto/cases/vmi/broadcast-dense-group-users/main.cpp b/test/vpto/cases/vmi/broadcast-dense-group-users/main.cpp new file mode 100644 index 0000000000..b43a794cdb --- /dev/null +++ b/test/vpto/cases/vmi/broadcast-dense-group-users/main.cpp @@ -0,0 +1,97 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#include "acl/acl.h" +#include "test_common.h" +#include <cstdio> +#include <cstdlib> + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_broadcast_dense_group_users_kernel(float *src, float *copy, + float *sum, void *stream); + +int main() { + constexpr size_t kRows = 8; + constexpr size_t kCols = 32; + constexpr size_t kSrcElems = kRows * kCols; + constexpr size_t kSumElems = kRows; + size_t srcBytes = kSrcElems * sizeof(float); + size_t copyBytes = kSrcElems * sizeof(float); + size_t sumBytes = kSumElems * sizeof(float); + float *srcHost = nullptr; + float *copyHost = nullptr; + float *sumHost = nullptr; + float *srcDevice = nullptr; + float *copyDevice = nullptr; + float *sumDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&copyHost), copyBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&sumHost), sumBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&copyDevice, copyBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&sumDevice, sumBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", copyBytes, copyHost, copyBytes); + ReadFile("./v3.bin", sumBytes, sumHost, sumBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, +
ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(copyDevice, copyBytes, copyHost, copyBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(sumDevice, sumBytes, sumHost, sumBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_broadcast_dense_group_users_kernel(srcDevice, copyDevice, sumDevice, + stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(copyHost, copyBytes, copyDevice, copyBytes, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(sumHost, sumBytes, sumDevice, sumBytes, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", copyHost, copyBytes); + WriteFile("./v3.bin", sumHost, sumBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(copyDevice); + aclrtFree(sumDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(copyHost); + aclrtFreeHost(sumHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/broadcast-dense-group-users/ptoas.flags b/test/vpto/cases/vmi/broadcast-dense-group-users/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/broadcast-dense-group-users/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/dense-group-reduce-multi-consumer/compare.py b/test/vpto/cases/vmi/dense-group-reduce-multi-consumer/compare.py new file mode 100644 index 0000000000..837961af76 --- /dev/null +++ b/test/vpto/cases/vmi/dense-group-reduce-multi-consumer/compare.py @@ -0,0 +1,32 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. 
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + for name in ("v2", "v3"): + golden = np.fromfile(f"golden_{name}.bin", dtype=np.float32) + output = np.fromfile(f"{name}.bin", dtype=np.float32) + if golden.shape != output.shape or not np.allclose(golden, output, atol=1e-4, rtol=1e-4): + diff = np.nonzero(~np.isclose(golden, output, atol=1e-4, rtol=1e-4))[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed {name} idx={idx} " + f"golden={golden[idx] if idx >= 0 else 'n/a'} " + f"output={output[idx] if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/dense-group-reduce-multi-consumer/golden.py b/test/vpto/cases/vmi/dense-group-reduce-multi-consumer/golden.py new file mode 100644 index 0000000000..6e5edd801a --- /dev/null +++ b/test/vpto/cases/vmi/dense-group-reduce-multi-consumer/golden.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+ +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 8 +GROUP_SIZE = 32 +INPUT_ELEMS = ROWS * GROUP_SIZE +SENTINEL = np.float32(-777.0) + + +def generate(output_dir: Path) -> None: + src = np.empty(INPUT_ELEMS, dtype=np.float32) + sum_out = np.full(ROWS, SENTINEL, dtype=np.float32) + copy_out = np.full(INPUT_ELEMS, SENTINEL, dtype=np.float32) + golden_sum = np.empty(ROWS, dtype=np.float32) + + base_row = np.linspace(-0.875, 0.625, GROUP_SIZE, dtype=np.float32) + for row in range(ROWS): + begin = row * GROUP_SIZE + values = base_row + np.float32(row) * np.float32(0.0625) + src[begin : begin + GROUP_SIZE] = values + golden_sum[row] = np.sum(values, dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + sum_out.tofile(output_dir / "v2.bin") + copy_out.tofile(output_dir / "v3.bin") + golden_sum.tofile(output_dir / "golden_v2.bin") + src.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/dense-group-reduce-multi-consumer/kernel.pto b/test/vpto/cases/vmi/dense-group-reduce-multi-consumer/kernel.pto new file mode 100644 index 0000000000..48ce7738a5 --- /dev/null +++ b/test/vpto/cases/vmi/dense-group-reduce-multi-consumer/kernel.pto @@ -0,0 +1,57 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_dense_group_reduce_multi_consumer_kernel(%src_gm: !pto.ptr, + %sum_gm: !pto.ptr, + %copy_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_sum = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_copy = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %x = pto.vmi.load %ub_src[%c0] : !pto.ptr -> !pto.vmi.vreg<256xf32> + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<8xf32> + pto.vmi.group_store %sum, %ub_sum[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xf32>, !pto.ptr + pto.vmi.store %x, %ub_copy[%c0] + : !pto.vmi.vreg<256xf32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_sum, %sum_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + 
pto.mte_ub_gm %ub_copy, %copy_gm, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/dense-group-reduce-multi-consumer/launch.cpp b/test/vpto/cases/vmi/dense-group-reduce-multi-consumer/launch.cpp new file mode 100644 index 0000000000..1249378267 --- /dev/null +++ b/test/vpto/cases/vmi/dense-group-reduce-multi-consumer/launch.cpp @@ -0,0 +1,35 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_dense_group_reduce_multi_consumer_kernel(__gm__ float *src, + __gm__ float *sum, + __gm__ float *copy); + +void LaunchVmi_dense_group_reduce_multi_consumer_kernel(float *src, float *sum, + float *copy, + void *stream) { + vmi_dense_group_reduce_multi_consumer_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ float *)sum, (__gm__ float *)copy); +} diff --git a/test/vpto/cases/vmi/dense-group-reduce-multi-consumer/main.cpp b/test/vpto/cases/vmi/dense-group-reduce-multi-consumer/main.cpp new file mode 100644 index 0000000000..0482d8339d --- /dev/null +++ b/test/vpto/cases/vmi/dense-group-reduce-multi-consumer/main.cpp @@ -0,0 +1,94 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#include "acl/acl.h" +#include "test_common.h" +#include <cstdio> +#include <cstdlib> + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_dense_group_reduce_multi_consumer_kernel(float *src, float *sum, + float *copy, + void *stream); + +int main() { + constexpr size_t kRows = 8; + constexpr size_t kGroupSize = 32; + constexpr size_t kInputElems = kRows * kGroupSize; + constexpr size_t kSumElems = kRows; + constexpr size_t kCopyElems = kInputElems; + size_t srcBytes = kInputElems * sizeof(float); + size_t sumBytes = kSumElems * sizeof(float); + size_t copyBytes = kCopyElems * sizeof(float); + float *srcHost = nullptr; + float *srcDevice = nullptr; + float *sumHost = nullptr; + float *sumDevice = nullptr; + float *copyHost = nullptr; + float *copyDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&sumHost), sumBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&copyHost), copyBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&sumDevice, sumBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&copyDevice, copyBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", sumBytes, sumHost, sumBytes); + ReadFile("./v3.bin", copyBytes, copyHost, copyBytes); +
ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(sumDevice, sumBytes, sumHost, sumBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(copyDevice, copyBytes, copyHost, copyBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_dense_group_reduce_multi_consumer_kernel(srcDevice, sumDevice, + copyDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(sumHost, sumBytes, sumDevice, sumBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(copyHost, copyBytes, copyDevice, copyBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", sumHost, sumBytes); + WriteFile("./v3.bin", copyHost, copyBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(sumDevice); + aclrtFree(copyDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(sumHost); + aclrtFreeHost(copyHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/dense-group-reduce-multi-consumer/ptoas.flags b/test/vpto/cases/vmi/dense-group-reduce-multi-consumer/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/dense-group-reduce-multi-consumer/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/dequant-f16-to-f32-tail/compare.py b/test/vpto/cases/vmi/dequant-f16-to-f32-tail/compare.py new file mode 100644 index 0000000000..8de470b64d --- /dev/null +++ b/test/vpto/cases/vmi/dequant-f16-to-f32-tail/compare.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. 
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v2.bin", dtype=np.float32) + output = np.fromfile("v2.bin", dtype=np.float32) + if golden.shape != output.shape or not np.allclose(golden, output, atol=1e-5, rtol=1e-5): + diff = np.nonzero(~np.isclose(golden, output, atol=1e-5, rtol=1e-5))[0] + idx = int(diff[0]) if diff.size else -1 + print(f"[ERROR] compare failed idx={idx} golden={golden[idx] if idx >= 0 else 'n/a'} output={output[idx] if idx >= 0 else 'n/a'}") + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/dequant-f16-to-f32-tail/golden.py b/test/vpto/cases/vmi/dequant-f16-to-f32-tail/golden.py new file mode 100644 index 0000000000..8c3eb7acea --- /dev/null +++ b/test/vpto/cases/vmi/dequant-f16-to-f32-tail/golden.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+ +import argparse +from pathlib import Path + +import numpy as np + +ELEMS = 1024 +LOGICAL_ELEMS = 1000 +SEED = 23 +SCALE = np.float32(2.0) +SENTINEL = np.float32(-123.25) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + src = rng.uniform(-4.0, 4.0, size=ELEMS).astype(np.float16) + dst = np.full(ELEMS, SENTINEL, dtype=np.float32) + golden = np.full(ELEMS, SENTINEL, dtype=np.float32) + golden[:LOGICAL_ELEMS] = src[:LOGICAL_ELEMS].astype(np.float32) * SCALE + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + dst.tofile(output_dir / "v2.bin") + golden.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/dequant-f16-to-f32-tail/kernel.pto b/test/vpto/cases/vmi/dequant-f16-to-f32-tail/kernel.pto new file mode 100644 index 0000000000..81b8640c0a --- /dev/null +++ b/test/vpto/cases/vmi/dequant-f16-to-f32-tail/kernel.pto @@ -0,0 +1,60 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_dequant_f16_to_f32_tail_kernel(%src_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c1000 = arith.constant 1000 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %scale = arith.constant 2.000000e+00 : f32 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %dst_gm, %ub_dst, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c128 iter_args(%remaining = %c1000) -> (index) { + %mask = pto.vmi.create_mask %remaining : index -> !pto.vmi.mask<128xpred> + %packed = pto.vmi.load %ub_src[%offset] : !pto.ptr -> !pto.vmi.vreg<128xf16> + %wide = pto.vmi.extf %packed : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %scale_vec = pto.vmi.broadcast %scale : f32 -> !pto.vmi.vreg<128xf32> + %out = pto.vmi.mulf %wide, %scale_vec + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + pto.vmi.masked_store %out, %ub_dst[%offset], %mask + : !pto.vmi.vreg<128xf32>, !pto.ptr, !pto.vmi.mask<128xpred> + %next = arith.subi %remaining, %c128 : index + scf.yield %next : index + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c128_i64 + 
nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/dequant-f16-to-f32-tail/launch.cpp b/test/vpto/cases/vmi/dequant-f16-to-f32-tail/launch.cpp new file mode 100644 index 0000000000..3c329a34bb --- /dev/null +++ b/test/vpto/cases/vmi/dequant-f16-to-f32-tail/launch.cpp @@ -0,0 +1,40 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_dequant_f16_to_f32_tail_kernel(__gm__ half *src, __gm__ float *dst); + +void LaunchVmi_dequant_f16_to_f32_tail_kernel(uint16_t *src, float *dst, + void *stream) { + vmi_dequant_f16_to_f32_tail_kernel<<<1, nullptr, stream>>>( + (__gm__ half *)src, (__gm__ float *)dst); +} diff --git a/test/vpto/cases/vmi/dequant-f16-to-f32-tail/main.cpp b/test/vpto/cases/vmi/dequant-f16-to-f32-tail/main.cpp new file mode 100644 index 0000000000..7797fe7fb0 --- /dev/null +++ b/test/vpto/cases/vmi/dequant-f16-to-f32-tail/main.cpp @@ -0,0 +1,78 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_dequant_f16_to_f32_tail_kernel(uint16_t *src, float *dst, + void *stream); + +int main() { + constexpr size_t kElems = 1024; + size_t srcBytes = kElems * sizeof(uint16_t); + size_t dstBytes = kElems * sizeof(float); + uint16_t *srcHost = nullptr; + uint16_t *srcDevice = nullptr; + float *dstHost = nullptr; + float *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_dequant_f16_to_f32_tail_kernel(srcDevice, dstDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + 
aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/dequant-f16-to-f32-tail/ptoas.flags b/test/vpto/cases/vmi/dequant-f16-to-f32-tail/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/dequant-f16-to-f32-tail/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/dequant-f8-to-f32-tail/compare.py b/test/vpto/cases/vmi/dequant-f8-to-f32-tail/compare.py new file mode 100644 index 0000000000..8de470b64d --- /dev/null +++ b/test/vpto/cases/vmi/dequant-f8-to-f32-tail/compare.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+ +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v2.bin", dtype=np.float32) + output = np.fromfile("v2.bin", dtype=np.float32) + if golden.shape != output.shape or not np.allclose(golden, output, atol=1e-5, rtol=1e-5): + diff = np.nonzero(~np.isclose(golden, output, atol=1e-5, rtol=1e-5))[0] + idx = int(diff[0]) if diff.size else -1 + print(f"[ERROR] compare failed idx={idx} golden={golden[idx] if idx >= 0 else 'n/a'} output={output[idx] if idx >= 0 else 'n/a'}") + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/dequant-f8-to-f32-tail/golden.py b/test/vpto/cases/vmi/dequant-f8-to-f32-tail/golden.py new file mode 100644 index 0000000000..b53b4b2ba9 --- /dev/null +++ b/test/vpto/cases/vmi/dequant-f8-to-f32-tail/golden.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+ +import argparse +from pathlib import Path + +import numpy as np + +ELEMS = 1024 +LOGICAL_ELEMS = 1000 +SCALE = np.float32(2.0) +VALUES = np.array([0.0, 1.0, -1.0, 0.5, 2.0, -2.0, 4.0, -4.0], dtype=np.float32) +F8E4M3FN_BYTES = np.array([0x00, 0x38, 0xB8, 0x30, 0x40, 0xC0, 0x48, 0xC8], dtype=np.uint8) +SENTINEL = np.float32(-123.25) + + +def generate(output_dir: Path) -> None: + repeats = (ELEMS + len(VALUES) - 1) // len(VALUES) + src = np.tile(F8E4M3FN_BYTES, repeats)[:ELEMS].astype(np.uint8) + decoded = np.tile(VALUES, repeats)[:ELEMS].astype(np.float32) + dst = np.full(ELEMS, SENTINEL, dtype=np.float32) + golden = np.full(ELEMS, SENTINEL, dtype=np.float32) + golden[:LOGICAL_ELEMS] = decoded[:LOGICAL_ELEMS] * SCALE + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + dst.tofile(output_dir / "v2.bin") + golden.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/dequant-f8-to-f32-tail/kernel.pto b/test/vpto/cases/vmi/dequant-f8-to-f32-tail/kernel.pto new file mode 100644 index 0000000000..bddf6b0f06 --- /dev/null +++ b/test/vpto/cases/vmi/dequant-f8-to-f32-tail/kernel.pto @@ -0,0 +1,59 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_dequant_f8_to_f32_tail_kernel(%src_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + %c1000 = arith.constant 1000 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %scale = arith.constant 2.000000e+00 : f32 + + %ub_src_u8 = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_src_f8 = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src_u8, %c0_i64, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %dst_gm, %ub_dst, %c0_i64, %c4096_i64 + nburst(%c1_i64, %c4096_i64, %c4096_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c256 iter_args(%remaining = %c1000) -> (index) { + %mask = pto.vmi.create_mask %remaining : index -> !pto.vmi.mask<256xpred> + %packed = pto.vmi.load %ub_src_f8[%offset] : !pto.ptr -> !pto.vmi.vreg<256xf8E4M3FN> + %wide = pto.vmi.extf %packed : !pto.vmi.vreg<256xf8E4M3FN> -> !pto.vmi.vreg<256xf32> + %scale_vec = pto.vmi.broadcast %scale : f32 -> !pto.vmi.vreg<256xf32> + %out = pto.vmi.mulf %wide, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + pto.vmi.masked_store %out, %ub_dst[%offset], %mask + : !pto.vmi.vreg<256xf32>, !pto.ptr, !pto.vmi.mask<256xpred> + %next = arith.subi %remaining, %c256 : index + scf.yield %next : index + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + 
pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c4096_i64 + nburst(%c1_i64, %c4096_i64, %c4096_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/dequant-f8-to-f32-tail/launch.cpp b/test/vpto/cases/vmi/dequant-f8-to-f32-tail/launch.cpp new file mode 100644 index 0000000000..02688457e3 --- /dev/null +++ b/test/vpto/cases/vmi/dequant-f8-to-f32-tail/launch.cpp @@ -0,0 +1,40 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_dequant_f8_to_f32_tail_kernel(__gm__ uint8_t *src, __gm__ float *dst); + +void LaunchVmi_dequant_f8_to_f32_tail_kernel(uint8_t *src, float *dst, + void *stream) { + vmi_dequant_f8_to_f32_tail_kernel<<<1, nullptr, stream>>>( + (__gm__ uint8_t *)src, (__gm__ float *)dst); +} diff --git a/test/vpto/cases/vmi/dequant-f8-to-f32-tail/main.cpp b/test/vpto/cases/vmi/dequant-f8-to-f32-tail/main.cpp new file mode 100644 index 0000000000..ee62749258 --- /dev/null +++ b/test/vpto/cases/vmi/dequant-f8-to-f32-tail/main.cpp @@ -0,0 +1,78 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_dequant_f8_to_f32_tail_kernel(uint8_t *src, float *dst, + void *stream); + +int main() { + constexpr size_t kElems = 1024; + size_t srcBytes = kElems * sizeof(uint8_t); + size_t dstBytes = kElems * sizeof(float); + uint8_t *srcHost = nullptr; + uint8_t *srcDevice = nullptr; + float *dstHost = nullptr; + float *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_dequant_f8_to_f32_tail_kernel(srcDevice, dstDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + 
aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/dequant-f8-to-f32-tail/ptoas.flags b/test/vpto/cases/vmi/dequant-f8-to-f32-tail/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/dequant-f8-to-f32-tail/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/dhist-tail-mask-store/compare.py b/test/vpto/cases/vmi/dhist-tail-mask-store/compare.py new file mode 100644 index 0000000000..22aff69b5d --- /dev/null +++ b/test/vpto/cases/vmi/dhist-tail-mask-store/compare.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+ +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v3.bin", dtype=np.uint16) + output = np.fromfile("v3.bin", dtype=np.uint16) + if golden.shape == output.shape and np.array_equal(golden, output): + print("[INFO] compare passed") + return + + if golden.shape != output.shape: + print(f"[ERROR] compare failed v3.bin: shape golden={golden.shape} output={output.shape}") + sys.exit(2) + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed v3.bin idx={idx} " + f"golden={golden[idx] if idx >= 0 else 'n/a'} " + f"output={output[idx] if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/dhist-tail-mask-store/golden.py b/test/vpto/cases/vmi/dhist-tail-mask-store/golden.py new file mode 100644 index 0000000000..0c09bb49d7 --- /dev/null +++ b/test/vpto/cases/vmi/dhist-tail-mask-store/golden.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+ +import argparse +from pathlib import Path + +import numpy as np + +SOURCE_ELEMS = 512 +LOGICAL_LANES = 300 +BINS = 256 + + +def generate(output_dir: Path) -> None: + output_dir.mkdir(parents=True, exist_ok=True) + + src = (np.arange(SOURCE_ELEMS, dtype=np.uint16) % BINS).astype(np.uint8) + acc = (np.arange(BINS, dtype=np.uint16) % np.uint16(5)).astype(np.uint16) + dst = np.full(BINS, np.uint16(0xcccc), dtype=np.uint16) + + counts = np.bincount(src[:LOGICAL_LANES].astype(np.int64), minlength=BINS) + golden = (acc.astype(np.uint32) + counts.astype(np.uint32)).astype(np.uint16) + + src.tofile(output_dir / "v1.bin") + acc.tofile(output_dir / "v2.bin") + dst.tofile(output_dir / "v3.bin") + golden.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/dhist-tail-mask-store/kernel.pto b/test/vpto/cases/vmi/dhist-tail-mask-store/kernel.pto new file mode 100644 index 0000000000..4fb1fe531c --- /dev/null +++ b/test/vpto/cases/vmi/dhist-tail-mask-store/kernel.pto @@ -0,0 +1,56 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_dhist_tail_mask_store_kernel( + %src_gm: !pto.ptr, %acc_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c300 = arith.constant 300 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c512_i64 = arith.constant 512 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c2048_i64 = arith.constant 2048 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_acc = pto.castptr %c1024_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c2048_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %acc_gm, %ub_acc, %c0_i64, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %source = pto.vmi.load %ub_src[%c0] + : !pto.ptr -> !pto.vmi.vreg<512xui8> + %acc = pto.vmi.load %ub_acc[%c0] + : !pto.ptr -> !pto.vmi.vreg<256xui16> + %mask = pto.vmi.create_mask %c300 : index -> !pto.vmi.mask<512xpred> + %hist = pto.vmi.dhist %acc, %source, %mask + : !pto.vmi.vreg<256xui16>, !pto.vmi.vreg<512xui8>, + !pto.vmi.mask<512xpred> -> !pto.vmi.vreg<256xui16> + pto.vmi.store %hist, %ub_dst[%c0] + : !pto.vmi.vreg<256xui16>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/dhist-tail-mask-store/launch.cpp b/test/vpto/cases/vmi/dhist-tail-mask-store/launch.cpp new file mode 100644 index 0000000000..4031c8131e --- /dev/null +++ 
b/test/vpto/cases/vmi/dhist-tail-mask-store/launch.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_dhist_tail_mask_store_kernel(__gm__ uint8_t *src, __gm__ uint16_t *acc, + __gm__ uint16_t *dst); + +void LaunchVmi_dhist_tail_mask_store_kernel(uint8_t *src, uint16_t *acc, + uint16_t *dst, void *stream) { + vmi_dhist_tail_mask_store_kernel<<<1, nullptr, stream>>>( + (__gm__ uint8_t *)src, (__gm__ uint16_t *)acc, (__gm__ uint16_t *)dst); +} diff --git a/test/vpto/cases/vmi/dhist-tail-mask-store/main.cpp b/test/vpto/cases/vmi/dhist-tail-mask-store/main.cpp new file mode 100644 index 0000000000..aa1288ab26 --- /dev/null +++ b/test/vpto/cases/vmi/dhist-tail-mask-store/main.cpp @@ -0,0 +1,94 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. 
You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_dhist_tail_mask_store_kernel(uint8_t *src, uint16_t *acc, + uint16_t *dst, void *stream); + +int main() { + constexpr size_t kSourceElems = 512; + constexpr size_t kBins = 256; + size_t srcBytes = kSourceElems * sizeof(uint8_t); + size_t accBytes = kBins * sizeof(uint16_t); + size_t dstBytes = kBins * sizeof(uint16_t); + uint8_t *srcHost = nullptr; + uint16_t *accHost = nullptr; + uint16_t *dstHost = nullptr; + uint8_t *srcDevice = nullptr; + uint16_t *accDevice = nullptr; + uint16_t *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&accHost), accBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&accDevice, accBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + 
ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", accBytes, accHost, accBytes); + ReadFile("./v3.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(accDevice, accBytes, accHost, accBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchVmi_dhist_tail_mask_store_kernel(srcDevice, accDevice, dstDevice, + stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(accDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(accHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/dhist-tail-mask-store/ptoas.flags b/test/vpto/cases/vmi/dhist-tail-mask-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/dhist-tail-mask-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/compare.py b/test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/compare.py new file mode 100644 index 0000000000..9f34394fa1 --- /dev/null +++ b/test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/compare.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). 
+# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def check(name: str, golden_name: str) -> None: + golden = np.fromfile(golden_name, dtype=np.float32) + output = np.fromfile(name, dtype=np.float32) + if golden.shape == output.shape and np.allclose(golden, output, atol=1e-4, rtol=1e-4): + return + if golden.shape != output.shape: + print(f"[ERROR] compare failed {name}: shape golden={golden.shape} output={output.shape}") + sys.exit(2) + diff = np.nonzero(~np.isclose(golden, output, atol=1e-4, rtol=1e-4))[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed {name} idx={idx} " + f"golden={golden[idx] if idx >= 0 else 'n/a'} " + f"output={output[idx] if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + +def main() -> None: + check("v2.bin", "golden_v2.bin") + check("v3.bin", "golden_v3.bin") + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/golden.py b/test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/golden.py new file mode 100644 index 0000000000..df3f6a24dc --- /dev/null +++ b/test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/golden.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. 
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 8 +COLS = 32 +ACTIVE = 25 +SENTINEL = np.float32(-777.0) + + +def generate(output_dir: Path) -> None: + active_base = np.linspace(-0.875, 0.625, ACTIVE, dtype=np.float32) + inactive_base = np.linspace(19.0, 22.5, COLS - ACTIVE, dtype=np.float32) + src = np.empty((ROWS, COLS), dtype=np.float32) + for row in range(ROWS): + src[row, :ACTIVE] = active_base + np.float32(row) * np.float32(0.03125) + src[row, ACTIVE:] = inactive_base + np.float32(row) * np.float32(1.75) + + copy = np.full((ROWS, COLS), SENTINEL, dtype=np.float32) + sums = np.full(ROWS, SENTINEL, dtype=np.float32) + golden_copy = src.copy() + golden_copy[:, ACTIVE:] = np.float32(0.0) + golden_sum = np.sum(src[:, :ACTIVE], axis=1, dtype=np.float32).astype(np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + src.reshape(-1).tofile(output_dir / "v1.bin") + copy.reshape(-1).tofile(output_dir / "v2.bin") + sums.tofile(output_dir / "v3.bin") + golden_copy.reshape(-1).astype(np.float32).tofile(output_dir / "golden_v2.bin") + golden_sum.astype(np.float32).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/kernel.pto b/test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/kernel.pto new file mode 100644 index 0000000000..1c4918951e --- /dev/null +++ b/test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/kernel.pto @@ -0,0 +1,64 @@ 
+// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_dynamic_create_group_mask_s32_reduce_store_kernel( + %src_gm: !pto.ptr, %copy_gm: !pto.ptr, + %sum_gm: !pto.ptr, %active_cols_i32: i32) + attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %zero = arith.constant 0.000000e+00 : f32 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_copy = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_sum = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %active_cols = arith.index_cast %active_cols_i32 : i32 to index + %mask = pto.vmi.create_group_mask %active_cols + {num_groups = 8, group_size = 32} + : index -> !pto.vmi.mask<256xpred> + %zero_vec = pto.vmi.broadcast %zero : f32 -> !pto.vmi.vreg<256xf32> + %x = pto.vmi.masked_load %ub_src[%c0], %mask, %zero_vec + 
: !pto.ptr, !pto.vmi.mask<256xpred>, + !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + pto.vmi.store %x, %ub_copy[%c0] + : !pto.vmi.vreg<256xf32>, !pto.ptr + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<8xf32> + pto.vmi.group_store %sum, %ub_sum[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xf32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_copy, %copy_gm, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_sum, %sum_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/launch.cpp b/test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/launch.cpp new file mode 100644 index 0000000000..5865140b26 --- /dev/null +++ b/test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/launch.cpp @@ -0,0 +1,35 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_dynamic_create_group_mask_s32_reduce_store_kernel(__gm__ float *src, __gm__ float *copy, + __gm__ float *sum, int activeCols); + +void LaunchVmi_dynamic_create_group_mask_s32_reduce_store_kernel(float *src, float *copy, + float *sum, int activeCols, + void *stream) { + vmi_dynamic_create_group_mask_s32_reduce_store_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ float *)copy, (__gm__ float *)sum, + activeCols); +} diff --git a/test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/main.cpp b/test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/main.cpp new file mode 100644 index 0000000000..7bd86defb1 --- /dev/null +++ b/test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/main.cpp @@ -0,0 +1,99 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#include "acl/acl.h" +#include "test_common.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_dynamic_create_group_mask_s32_reduce_store_kernel(float *src, float *copy, + float *sum, int activeCols, + void *stream); + +int main() { + constexpr size_t kRows = 8; + constexpr size_t kCols = 32; + constexpr int kActiveCols = 25; + constexpr size_t kSrcElems = kRows * kCols; + constexpr size_t kSumElems = kRows; + size_t srcBytes = kSrcElems * sizeof(float); + size_t copyBytes = kSrcElems * sizeof(float); + size_t sumBytes = kSumElems * sizeof(float); + float *srcHost = nullptr; + float *copyHost = nullptr; + float *sumHost = nullptr; + float *srcDevice = nullptr; + float *copyDevice = nullptr; + float *sumDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&copyHost), copyBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&sumHost), sumBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&copyDevice, copyBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&sumDevice, sumBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", copyBytes, copyHost, copyBytes); + ReadFile("./v3.bin", sumBytes, sumHost, sumBytes); + 
ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(copyDevice, copyBytes, copyHost, copyBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(sumDevice, sumBytes, sumHost, sumBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_dynamic_create_group_mask_s32_reduce_store_kernel( + srcDevice, copyDevice, sumDevice, kActiveCols, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(copyHost, copyBytes, copyDevice, copyBytes, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(sumHost, sumBytes, sumDevice, sumBytes, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", copyHost, copyBytes); + WriteFile("./v3.bin", sumHost, sumBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(copyDevice); + aclrtFree(sumDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(copyHost); + aclrtFreeHost(sumHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/ptoas.flags b/test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/f32-to-f8-store-reduce/compare.py b/test/vpto/cases/vmi/f32-to-f8-store-reduce/compare.py new file mode 100644 index 0000000000..d00c9b8b26 --- /dev/null +++ b/test/vpto/cases/vmi/f32-to-f8-store-reduce/compare.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. 
You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def check_f32(name: str, atol: float, rtol: float) -> bool: + golden = np.fromfile(f"golden_{name}.bin", dtype=np.float32) + output = np.fromfile(f"{name}.bin", dtype=np.float32) + close = golden.shape == output.shape and np.allclose(golden, output, atol=atol, rtol=rtol) + if close: + return True + diff = np.nonzero(~np.isclose(golden, output, atol=atol, rtol=rtol))[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed {name} idx={idx} " + f"golden={golden[idx] if idx >= 0 else 'n/a'} " + f"output={output[idx] if idx >= 0 else 'n/a'}" + ) + return False + + +def check_u8(name: str) -> bool: + golden = np.fromfile(f"golden_{name}.bin", dtype=np.uint8) + output = np.fromfile(f"{name}.bin", dtype=np.uint8) + if golden.shape == output.shape and np.array_equal(golden, output): + return True + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else -1 + print(f"[ERROR] compare failed {name} idx={idx} golden=0x{int(golden[idx]):02x} output=0x{int(output[idx]):02x}") + return False + + +def main() -> None: + if not check_f32("v2", 1e-4, 1e-4) or not check_u8("v3"): + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/f32-to-f8-store-reduce/golden.py b/test/vpto/cases/vmi/f32-to-f8-store-reduce/golden.py new file mode 100644 index 0000000000..9034fe8d42 --- /dev/null +++ b/test/vpto/cases/vmi/f32-to-f8-store-reduce/golden.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 8 +GROUP_SIZE = 32 +VALUES = np.array([0.0, 1.0, -1.0, 0.5, 2.0, -2.0, 4.0, -4.0], dtype=np.float32) +F8E4M3FN_BYTES = np.array([0x00, 0x38, 0xB8, 0x30, 0x40, 0xC0, 0x48, 0xC8], dtype=np.uint8) +SENTINEL_F32 = np.float32(-777.0) +SENTINEL_U8 = np.uint8(0xA5) + + +def generate(output_dir: Path) -> None: + src = np.empty((ROWS, GROUP_SIZE), dtype=np.float32) + golden_out8 = np.empty((ROWS, GROUP_SIZE), dtype=np.uint8) + for row in range(ROWS): + value_idx = row % len(VALUES) + if row == 0: + src[row, :] = np.tile(VALUES, GROUP_SIZE // len(VALUES)) + golden_out8[row, :] = np.tile(F8E4M3FN_BYTES, GROUP_SIZE // len(F8E4M3FN_BYTES)) + else: + src[row, :] = VALUES[value_idx] + golden_out8[row, :] = F8E4M3FN_BYTES[value_idx] + + golden_sum = np.sum(src, axis=1, dtype=np.float32) + sum_out = np.full(ROWS, SENTINEL_F32, dtype=np.float32) + out8 = np.full(ROWS * GROUP_SIZE, SENTINEL_U8, dtype=np.uint8) + + output_dir.mkdir(parents=True, exist_ok=True) + src.reshape(-1).tofile(output_dir / "v1.bin") + sum_out.tofile(output_dir / "v2.bin") + out8.tofile(output_dir / "v3.bin") + golden_sum.astype(np.float32).tofile(output_dir / "golden_v2.bin") + golden_out8.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, 
default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/f32-to-f8-store-reduce/kernel.pto b/test/vpto/cases/vmi/f32-to-f8-store-reduce/kernel.pto new file mode 100644 index 0000000000..1a0f7f0d42 --- /dev/null +++ b/test/vpto/cases/vmi/f32-to-f8-store-reduce/kernel.pto @@ -0,0 +1,62 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_f32_to_f8_store_reduce_kernel(%src_gm: !pto.ptr, + %sum_gm: !pto.ptr, + %out8_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c256_i64 = arith.constant 256 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_sum = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out8_u8 = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_out8_f8 = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %x32 = pto.vmi.load %ub_src[%c0] + : !pto.ptr -> !pto.vmi.vreg<256xf32> + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %sum = pto.vmi.group_reduce_addf %x32, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<8xf32> + pto.vmi.group_store %sum, %ub_sum[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xf32>, !pto.ptr + %x8 = pto.vmi.truncf %x32 + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E4M3FN> + pto.vmi.store %x8, %ub_out8_f8[%c0] + : !pto.vmi.vreg<256xf8E4M3FN>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_sum, %sum_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_out8_u8, %out8_gm, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, 
i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/f32-to-f8-store-reduce/launch.cpp b/test/vpto/cases/vmi/f32-to-f8-store-reduce/launch.cpp new file mode 100644 index 0000000000..eef7fac9d0 --- /dev/null +++ b/test/vpto/cases/vmi/f32-to-f8-store-reduce/launch.cpp @@ -0,0 +1,41 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_f32_to_f8_store_reduce_kernel(__gm__ float *src, __gm__ float *sum, + __gm__ uint8_t *out8); + +void LaunchVmi_f32_to_f8_store_reduce_kernel(float *src, float *sum, + uint8_t *out8, void *stream) { + vmi_f32_to_f8_store_reduce_kernel<<<1, nullptr, stream>>>( + (__gm__ float 
*)src, (__gm__ float *)sum, (__gm__ uint8_t *)out8); +} diff --git a/test/vpto/cases/vmi/f32-to-f8-store-reduce/main.cpp b/test/vpto/cases/vmi/f32-to-f8-store-reduce/main.cpp new file mode 100644 index 0000000000..1e3e7e8a86 --- /dev/null +++ b/test/vpto/cases/vmi/f32-to-f8-store-reduce/main.cpp @@ -0,0 +1,94 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_f32_to_f8_store_reduce_kernel(float *src, float *sum, + uint8_t *out8, void *stream); + +int main() { + constexpr size_t kRows = 8; + constexpr size_t kGroupSize = 32; + constexpr size_t kSrcElems = kRows * kGroupSize; + constexpr size_t kSumElems = kRows; + constexpr size_t kOut8Elems = kSrcElems; + size_t srcBytes = kSrcElems * sizeof(float); + size_t sumBytes = kSumElems * sizeof(float); + size_t out8Bytes = kOut8Elems * sizeof(uint8_t); + float *srcHost = nullptr; + float *sumHost = nullptr; + uint8_t *out8Host = nullptr; + float *srcDevice = nullptr; + float *sumDevice = nullptr; + uint8_t *out8Device = nullptr; + int rc = 0; + bool 
aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&sumHost), sumBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&out8Host), out8Bytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&sumDevice, sumBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&out8Device, out8Bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", sumBytes, sumHost, sumBytes); + ReadFile("./v3.bin", out8Bytes, out8Host, out8Bytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(sumDevice, sumBytes, sumHost, sumBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(out8Device, out8Bytes, out8Host, out8Bytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_f32_to_f8_store_reduce_kernel(srcDevice, sumDevice, out8Device, + stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(sumHost, sumBytes, sumDevice, sumBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(out8Host, out8Bytes, out8Device, out8Bytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", sumHost, sumBytes); + WriteFile("./v3.bin", out8Host, out8Bytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(sumDevice); + aclrtFree(out8Device); + aclrtFreeHost(srcHost); + aclrtFreeHost(sumHost); + aclrtFreeHost(out8Host); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git 
a/test/vpto/cases/vmi/f32-to-f8-store-reduce/ptoas.flags b/test/vpto/cases/vmi/f32-to-f8-store-reduce/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/f32-to-f8-store-reduce/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/f8-compute-f8/compare.py b/test/vpto/cases/vmi/f8-compute-f8/compare.py new file mode 100644 index 0000000000..68c53a335e --- /dev/null +++ b/test/vpto/cases/vmi/f8-compute-f8/compare.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+ +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v2.bin", dtype=np.uint8) + output = np.fromfile("v2.bin", dtype=np.uint8) + if golden.shape != output.shape or not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else -1 + print(f"[ERROR] compare failed idx={idx} golden=0x{int(golden[idx]):02x} output=0x{int(output[idx]):02x}") + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/f8-compute-f8/golden.py b/test/vpto/cases/vmi/f8-compute-f8/golden.py new file mode 100644 index 0000000000..e150b09545 --- /dev/null +++ b/test/vpto/cases/vmi/f8-compute-f8/golden.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+ +import argparse +from pathlib import Path + +import numpy as np + +ELEMS = 256 +F8E4M3FN_BYTES = np.array([0x00, 0x38, 0xB8, 0x30, 0x40, 0xC0, 0x48, 0xC8], dtype=np.uint8) +F8E4M3FN_TIMES2 = np.array([0x00, 0x40, 0xC0, 0x38, 0x48, 0xC8, 0x50, 0xD0], dtype=np.uint8) + + +def generate(output_dir: Path) -> None: + repeats = (ELEMS + len(F8E4M3FN_BYTES) - 1) // len(F8E4M3FN_BYTES) + src = np.tile(F8E4M3FN_BYTES, repeats)[:ELEMS].astype(np.uint8) + dst = np.full(ELEMS, 0xA5, dtype=np.uint8) + golden = np.tile(F8E4M3FN_TIMES2, repeats)[:ELEMS].astype(np.uint8) + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + dst.tofile(output_dir / "v2.bin") + golden.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/f8-compute-f8/kernel.pto b/test/vpto/cases/vmi/f8-compute-f8/kernel.pto new file mode 100644 index 0000000000..568cf5fbde --- /dev/null +++ b/test/vpto/cases/vmi/f8-compute-f8/kernel.pto @@ -0,0 +1,55 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_f8_compute_f8_kernel(%src_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c256_i64 = arith.constant 256 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %scale = arith.constant 2.000000e+00 : f32 + + %ub_src_u8 = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_src_f8 = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_dst_u8 = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_dst_f8 = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src_u8, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %x8 = pto.vmi.load %ub_src_f8[%c0] + : !pto.ptr -> !pto.vmi.vreg<256xf8E4M3FN> + %x32 = pto.vmi.extf %x8 + : !pto.vmi.vreg<256xf8E4M3FN> -> !pto.vmi.vreg<256xf32> + %scale_vec = pto.vmi.broadcast %scale + : f32 -> !pto.vmi.vreg<256xf32> + %y32 = pto.vmi.mulf %x32, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %y8 = pto.vmi.truncf %y32 + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E4M3FN> + pto.vmi.store %y8, %ub_dst_f8[%c0] + : !pto.vmi.vreg<256xf8E4M3FN>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst_u8, %dst_gm, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/f8-compute-f8/launch.cpp b/test/vpto/cases/vmi/f8-compute-f8/launch.cpp new file mode 100644 index 0000000000..63b5269670 --- /dev/null +++ b/test/vpto/cases/vmi/f8-compute-f8/launch.cpp @@ -0,0 +1,40 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+// CANN Open Software License Agreement Version 2.0 (the "License").
+// Please refer to the License for details. You may not use this file except in compliance with the License.
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+// See LICENSE in the root of the software repository for the full text of the License.
+
+#ifndef __VEC_SCOPE__
+#define __VEC_SCOPE__
+#endif
+#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201)
+typedef struct { unsigned char v; } hifloat8_t;
+typedef struct { unsigned char v; } float8_e4m3_t;
+typedef struct { unsigned char v; } float8_e5m2_t;
+typedef struct { unsigned char v; } float8_e8m0_t;
+typedef struct { unsigned char v; } float4_e1m2x2_t;
+typedef struct { unsigned char v; } float4_e2m1x2_t;
+#endif
+#include <stdint.h> // fixed-width integer types (uint8_t, uint16_t) used below
+#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP)
+struct MrgSortExecutedNumList {
+  uint16_t mrgSortList0;
+  uint16_t mrgSortList1;
+  uint16_t mrgSortList2;
+  uint16_t mrgSortList3;
+};
+#endif
+#ifndef __CPU_SIM
+#include "acl/acl.h"
+#endif
+
+extern "C" __global__ [aicore] void
+vmi_f8_compute_f8_kernel(__gm__ uint8_t *src, __gm__ uint8_t *dst);
+// Host-side launcher: enqueues the kernel on `stream` with the <<<1, nullptr, stream>>> configuration.
+void LaunchVmi_f8_compute_f8_kernel(uint8_t *src, uint8_t *dst,
+                                    void *stream) {
+  vmi_f8_compute_f8_kernel<<<1, nullptr, stream>>>(
+      (__gm__ uint8_t *)src, (__gm__ uint8_t *)dst);
+}
diff --git a/test/vpto/cases/vmi/f8-compute-f8/main.cpp b/test/vpto/cases/vmi/f8-compute-f8/main.cpp
new file mode 100644
index 0000000000..fffc2d6e65
--- /dev/null
+++ b/test/vpto/cases/vmi/f8-compute-f8/main.cpp
@@ -0,0 +1,76 @@
+// Copyright (c) 2026 Huawei Technologies Co., Ltd.
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+// CANN Open Software License Agreement Version 2.0 (the "License").
+// Please refer to the License for details. You may not use this file except in compliance with the License.
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+// See LICENSE in the root of the software repository for the full text of the License.
+
+#include "acl/acl.h"
+#include "test_common.h"
+#include <cstdint>
+#include <cstdio>
+#include <cstdlib>
+
+using namespace PtoTestCommon;
+// ACL_CHECK(expr): on failure, log the call, set rc = 1 and jump to the enclosing function's `cleanup` label.
+#define ACL_CHECK(expr) \
+  do { \
+    const aclError _ret = (expr); \
+    if (_ret != ACL_SUCCESS) { \
+      std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \
+                   (int)_ret, __FILE__, __LINE__); \
+      rc = 1; \
+      goto cleanup; \
+    } \
+  } while (0)
+
+void LaunchVmi_f8_compute_f8_kernel(uint8_t *src, uint8_t *dst, void *stream);
+// Runs the case once: loads v1.bin / v2.bin, launches the kernel and writes the device result back to v2.bin.
+int main() {
+  constexpr size_t kElems = 256;
+  size_t bytes = kElems * sizeof(uint8_t);
+  uint8_t *srcHost = nullptr;
+  uint8_t *dstHost = nullptr;
+  uint8_t *srcDevice = nullptr;
+  uint8_t *dstDevice = nullptr;
+  int rc = 0; // process exit code; ACL_CHECK sets it to 1 on failure
+  bool aclInited = false; // set once aclInit succeeded, so cleanup only finalizes an initialized runtime
+  bool deviceSet = false; // set once aclrtSetDevice succeeded, so cleanup only resets a selected device
+  int deviceId = 0;
+  aclrtStream stream = nullptr;
+
+  ACL_CHECK(aclInit(nullptr));
+  aclInited = true;
+  if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) // optional device override from the environment
+    deviceId = std::atoi(envDevice);
+  ACL_CHECK(aclrtSetDevice(deviceId));
+  deviceSet = true;
+  ACL_CHECK(aclrtCreateStream(&stream));
+  ACL_CHECK(aclrtMallocHost((void **)(&srcHost), bytes));
+  ACL_CHECK(aclrtMallocHost((void **)(&dstHost), bytes));
+  ACL_CHECK(aclrtMalloc((void **)&srcDevice, bytes, ACL_MEM_MALLOC_HUGE_FIRST));
+  ACL_CHECK(aclrtMalloc((void **)&dstDevice, bytes, ACL_MEM_MALLOC_HUGE_FIRST));
+
+  ReadFile("./v1.bin", bytes, srcHost, bytes);
+  ReadFile("./v2.bin", bytes, dstHost, bytes);
+  ACL_CHECK(aclrtMemcpy(srcDevice, bytes, srcHost, bytes, ACL_MEMCPY_HOST_TO_DEVICE));
+  ACL_CHECK(aclrtMemcpy(dstDevice, bytes, dstHost, bytes, ACL_MEMCPY_HOST_TO_DEVICE));
+  LaunchVmi_f8_compute_f8_kernel(srcDevice, dstDevice, stream);
+  ACL_CHECK(aclrtSynchronizeStream(stream));
+  ACL_CHECK(aclrtMemcpy(dstHost, bytes, dstDevice, bytes, ACL_MEMCPY_DEVICE_TO_HOST));
+  WriteFile("./v2.bin", dstHost, bytes);
+
+cleanup: // reached on success and from every failed ACL_CHECK
+  aclrtFree(srcDevice);
+  aclrtFree(dstDevice);
+  aclrtFreeHost(srcHost);
+  aclrtFreeHost(dstHost);
+  if (stream)
+    aclrtDestroyStream(stream);
+  if (deviceSet)
+    aclrtResetDevice(deviceId);
+  if (aclInited)
+    aclFinalize();
+  return rc;
+}
diff --git a/test/vpto/cases/vmi/f8-compute-f8/ptoas.flags b/test/vpto/cases/vmi/f8-compute-f8/ptoas.flags
new file mode 100644
index 0000000000..a79aede1ca
--- /dev/null
+++ b/test/vpto/cases/vmi/f8-compute-f8/ptoas.flags
@@ -0,0 +1 @@
+--pto-arch a5 --pto-backend=vpto --enable-vmi
diff --git a/test/vpto/cases/vmi/group-broadcast-multi-consumer/compare.py b/test/vpto/cases/vmi/group-broadcast-multi-consumer/compare.py
new file mode 100644
index 0000000000..da96a2ff71
--- /dev/null
+++ b/test/vpto/cases/vmi/group-broadcast-multi-consumer/compare.py
@@ -0,0 +1,44 @@
+#!/usr/bin/env python3
+# Copyright (c) 2026 Huawei Technologies Co., Ltd.
+# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+# CANN Open Software License Agreement Version 2.0 (the "License").
+# Please refer to the License for details. You may not use this file except in compliance with the License.
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+# See LICENSE in the root of the software repository for the full text of the License.
+
+import sys
+
+import numpy as np
+
+
+def main() -> None:
+    """Check v2.bin (f32, atol/rtol 1e-4) and v3.bin (f16, exact) against their golden files; exit 2 on mismatch."""
+    golden_sum = np.fromfile("golden_v2.bin", dtype=np.float32)
+    output_sum = np.fromfile("v2.bin", dtype=np.float32)
+    # The first differing index is only searched when the lengths agree; a length mismatch is reported as idx=-1.
+    if golden_sum.shape != output_sum.shape or not np.allclose(golden_sum, output_sum, atol=1e-4, rtol=1e-4):
+        diff = np.nonzero(~np.isclose(golden_sum, output_sum, atol=1e-4, rtol=1e-4))[0] if golden_sum.shape == output_sum.shape else ()
+        idx = int(diff[0]) if len(diff) else -1
+        print(
+            f"[ERROR] compare failed v2 idx={idx} "
+            f"golden={golden_sum[idx] if idx >= 0 else 'n/a'} output={output_sum[idx] if idx >= 0 else 'n/a'}"
+        )
+        sys.exit(2)
+
+    golden_dense = np.fromfile("golden_v3.bin", dtype=np.float16)
+    output_dense = np.fromfile("v3.bin", dtype=np.float16)
+    if golden_dense.shape != output_dense.shape or not np.array_equal(golden_dense, output_dense):
+        diff = np.nonzero(golden_dense.view(np.uint16) != output_dense.view(np.uint16))[0] if golden_dense.shape == output_dense.shape else ()
+        idx = int(diff[0]) if len(diff) else -1
+        print(
+            f"[ERROR] compare failed v3 idx={idx} "
+            f"golden={golden_dense[idx] if idx >= 0 else 'n/a'} output={output_dense[idx] if idx >= 0 else 'n/a'}"
+        )
+        sys.exit(2)
+
+    print("[INFO] compare passed")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/test/vpto/cases/vmi/group-broadcast-multi-consumer/golden.py b/test/vpto/cases/vmi/group-broadcast-multi-consumer/golden.py
new file mode 100644
index 0000000000..a238aaf082
--- /dev/null
+++ b/test/vpto/cases/vmi/group-broadcast-multi-consumer/golden.py
@@ -0,0 +1,54 @@
+#!/usr/bin/env python3
+# Copyright (c) 2026 Huawei Technologies Co., Ltd.
+# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+# CANN Open Software License Agreement Version 2.0 (the "License").
+# Please refer to the License for details. You may not use this file except in compliance with the License.
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+# See LICENSE in the root of the software repository for the full text of the License.
+
+import argparse
+from pathlib import Path
+
+import numpy as np
+
+ROWS = 8
+GROUP_SIZE = 16
+ELEMS = ROWS * GROUP_SIZE
+SEED = 29
+SENTINEL = np.float16(-17.5)
+SUM_SENTINEL = np.float32(-911.0)
+
+
+def generate(output_dir: Path, seed: int) -> None:
+    """Write v1.bin (random f32 input), sentinel-filled v2.bin / v3.bin and the golden_v2.bin / golden_v3.bin references."""
+    src = np.random.default_rng(seed).uniform(-2.0, 2.0, size=ELEMS).astype(np.float32)
+    golden_sum = np.empty(ROWS, dtype=np.float32)
+    golden_dense = np.full(ELEMS, SENTINEL, dtype=np.float16)
+    for row, values in enumerate(src.reshape(ROWS, GROUP_SIZE)):
+        group = slice(row * GROUP_SIZE, (row + 1) * GROUP_SIZE)
+        row_sum = np.sum(values, dtype=np.float32)
+        golden_sum[row] = np.sum(values * row_sum, dtype=np.float32)  # sum of the row scaled by its own sum
+        golden_dense[group] = row_sum.astype(np.float16)  # row sum replicated across the row, rounded to f16
+
+    output_dir.mkdir(parents=True, exist_ok=True)
+    for file_name, payload in (
+        ("v1.bin", src),
+        ("v2.bin", np.full(ROWS, SUM_SENTINEL, dtype=np.float32)),
+        ("v3.bin", np.full(ELEMS, SENTINEL, dtype=np.float16)),
+        ("golden_v2.bin", golden_sum),
+        ("golden_v3.bin", golden_dense),
+    ):
+        payload.tofile(output_dir / file_name)
+
+
+def main() -> None:
+    """Parse --output-dir / --seed and write the test vectors."""
+    cli = argparse.ArgumentParser()
+    cli.add_argument("--output-dir", type=Path, default=Path("."))
+    cli.add_argument("--seed", type=int, default=SEED)
+    generate(**vars(cli.parse_args()))
+
+
+if __name__ == "__main__":
+    main()
diff --git a/test/vpto/cases/vmi/group-broadcast-multi-consumer/kernel.pto b/test/vpto/cases/vmi/group-broadcast-multi-consumer/kernel.pto
new file mode 100644
index 0000000000..f81b4dfd24
--- /dev/null
+++ b/test/vpto/cases/vmi/group-broadcast-multi-consumer/kernel.pto
@@ -0,0 +1,69 @@
+// Copyright (c) 2026 Huawei Technologies Co., Ltd.
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+// CANN Open Software License Agreement Version 2.0 (the "License").
+// Please refer to the License for details. You may not use this file except in compliance with the License.
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+// See LICENSE in the root of the software repository for the full text of the License.
+
+module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} {
+  func.func @vmi_group_broadcast_multi_consumer_kernel(%src_gm: !pto.ptr,
+                                                       %sum_gm: !pto.ptr,
+                                                       %dense_gm: !pto.ptr) attributes {pto.kernel} {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c128 = arith.constant 128 : index
+    %c0_i64 = arith.constant 0 : i64
+    %c1_i64 = arith.constant 1 : i64
+    %c32_i64 = arith.constant 32 : i64
+    %c256_i64 = arith.constant 256 : i64
+    %c512_i64 = arith.constant 512 : i64
+    %c4096_i64 = arith.constant 4096 : i64
+    %c8192_i64 = arith.constant 8192 : i64
+    // UB buffers cast from addresses 0 (input), 4096 (per-group sums) and 8192 (dense f16 output).
+    %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr
+    %ub_sum = pto.castptr %c4096_i64 : i64 -> !pto.ptr
+    %ub_dense = pto.castptr %c8192_i64 : i64 -> !pto.ptr
+    // Bring the input from %src_gm into the UB source buffer.
+    pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c512_i64
+      nburst(%c1_i64, %c512_i64, %c512_i64)
+      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
+    // MTE2 -> V flag pair; presumably orders the vector work after the inbound copy.
+    pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]
+    pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]
+    // Vector body: %sum32 (8 group sums) feeds two group_broadcast consumers: a mulf + second reduction, and a truncf to f16.
+    pto.vecscope {
+      %mask = pto.vmi.create_mask %c128 : index -> !pto.vmi.mask<128xpred>
+      %x = pto.vmi.load %ub_src[%c0] : !pto.ptr -> !pto.vmi.vreg<128xf32>
+      %sum32 = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc}
+        : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred>
+        -> !pto.vmi.vreg<8xf32>
+      %b_for_mul = pto.vmi.group_broadcast %sum32 {num_groups = 8}
+        : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<128xf32>
+      %y = pto.vmi.mulf %x, %b_for_mul
+        : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32>
+        -> !pto.vmi.vreg<128xf32>
+      %ysum = pto.vmi.group_reduce_addf %y, %mask {num_groups = 8, reassoc}
+        : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred>
+        -> !pto.vmi.vreg<8xf32>
+      pto.vmi.group_store %ysum, %ub_sum[%c0], %c1 {num_groups = 8}
+        : !pto.vmi.vreg<8xf32>, !pto.ptr
+      %b_for_cast = pto.vmi.group_broadcast %sum32 {num_groups = 8}
+        : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<128xf32>
+      %h = pto.vmi.truncf %b_for_cast
+        : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16>
+      pto.vmi.store %h, %ub_dense[%c0] : !pto.vmi.vreg<128xf16>, !pto.ptr
+    }
+    // V -> MTE3 flag pair, then copy both UB results to %sum_gm and %dense_gm.
+    pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
+    pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
+    pto.mte_ub_gm %ub_sum, %sum_gm, %c32_i64
+      nburst(%c1_i64, %c32_i64, %c32_i64)
+      : !pto.ptr, !pto.ptr, i64, i64, i64, i64
+    pto.mte_ub_gm %ub_dense, %dense_gm, %c256_i64
+      nburst(%c1_i64, %c256_i64, %c256_i64)
+      : !pto.ptr, !pto.ptr, i64, i64, i64, i64
+    pto.barrier #pto.pipe
+    return
+  }
+}
diff --git a/test/vpto/cases/vmi/group-broadcast-multi-consumer/launch.cpp b/test/vpto/cases/vmi/group-broadcast-multi-consumer/launch.cpp
new file mode 100644
index 0000000000..2a562a57e3
--- /dev/null
+++ b/test/vpto/cases/vmi/group-broadcast-multi-consumer/launch.cpp
@@ -0,0 +1,42 @@
+// Copyright (c) 2026 Huawei Technologies Co., Ltd.
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+// CANN Open Software License Agreement Version 2.0 (the "License").
+// Please refer to the License for details. You may not use this file except in compliance with the License.
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+// See LICENSE in the root of the software repository for the full text of the License.
+
+#ifndef __VEC_SCOPE__
+#define __VEC_SCOPE__
+#endif
+#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201)
+typedef struct { unsigned char v; } hifloat8_t;
+typedef struct { unsigned char v; } float8_e4m3_t;
+typedef struct { unsigned char v; } float8_e5m2_t;
+typedef struct { unsigned char v; } float8_e8m0_t;
+typedef struct { unsigned char v; } float4_e1m2x2_t;
+typedef struct { unsigned char v; } float4_e2m1x2_t;
+#endif
+#include <stdint.h> // fixed-width integer types (uint16_t) used below
+#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP)
+struct MrgSortExecutedNumList {
+  uint16_t mrgSortList0;
+  uint16_t mrgSortList1;
+  uint16_t mrgSortList2;
+  uint16_t mrgSortList3;
+};
+#endif
+#ifndef __CPU_SIM
+#include "acl/acl.h"
+#endif
+
+extern "C" __global__ [aicore] void
+vmi_group_broadcast_multi_consumer_kernel(__gm__ float *src, __gm__ float *sum,
+                                          __gm__ half *dense);
+// Host-side launcher: `dense` arrives as raw 16-bit words and is passed to the kernel as `__gm__ half *`.
+void LaunchVmi_group_broadcast_multi_consumer_kernel(float *src, float *sum,
+                                                     uint16_t *dense,
+                                                     void *stream) {
+  vmi_group_broadcast_multi_consumer_kernel<<<1, nullptr, stream>>>(
+      (__gm__ float *)src, (__gm__ float *)sum, (__gm__ half *)dense);
+}
diff --git a/test/vpto/cases/vmi/group-broadcast-multi-consumer/main.cpp b/test/vpto/cases/vmi/group-broadcast-multi-consumer/main.cpp
new file mode 100644
index 0000000000..dc39a0c47d
--- /dev/null
+++ b/test/vpto/cases/vmi/group-broadcast-multi-consumer/main.cpp
@@ -0,0 +1,92 @@
+// Copyright (c) 2026 Huawei Technologies Co., Ltd.
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+// CANN Open Software License Agreement Version 2.0 (the "License").
+// Please refer to the License for details. You may not use this file except in compliance with the License.
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+// See LICENSE in the root of the software repository for the full text of the License.
+
+#include "acl/acl.h"
+#include "test_common.h"
+#include <cstdint>
+#include <cstdio>
+#include <cstdlib>
+
+using namespace PtoTestCommon;
+// ACL_CHECK(expr): on failure, log the call, set rc = 1 and jump to the enclosing function's `cleanup` label.
+#define ACL_CHECK(expr) \
+  do { \
+    const aclError _ret = (expr); \
+    if (_ret != ACL_SUCCESS) { \
+      std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \
+                   (int)_ret, __FILE__, __LINE__); \
+      rc = 1; \
+      goto cleanup; \
+    } \
+  } while (0)
+
+void LaunchVmi_group_broadcast_multi_consumer_kernel(float *src, float *sum,
+                                                     uint16_t *dense,
+                                                     void *stream);
+// Runs the case once: loads v1/v2/v3.bin, launches the kernel and writes the device results back to v2.bin and v3.bin.
+int main() {
+  constexpr size_t kElems = 128;
+  constexpr size_t kRows = 8;
+  size_t srcBytes = kElems * sizeof(float);
+  size_t sumBytes = kRows * sizeof(float);
+  size_t denseBytes = kElems * sizeof(uint16_t); // dense output is handled as raw 16-bit words on the host
+  float *srcHost = nullptr;
+  float *srcDevice = nullptr;
+  float *sumHost = nullptr;
+  float *sumDevice = nullptr;
+  uint16_t *denseHost = nullptr;
+  uint16_t *denseDevice = nullptr;
+  int rc = 0; // process exit code; ACL_CHECK sets it to 1 on failure
+  bool aclInited = false; // set once aclInit succeeded, so cleanup only finalizes an initialized runtime
+  bool deviceSet = false; // set once aclrtSetDevice succeeded, so cleanup only resets a selected device
+  int deviceId = 0;
+  aclrtStream stream = nullptr;
+
+  ACL_CHECK(aclInit(nullptr));
+  aclInited = true;
+  if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) // optional device override from the environment
+    deviceId = std::atoi(envDevice);
+  ACL_CHECK(aclrtSetDevice(deviceId));
+  deviceSet = true;
+  ACL_CHECK(aclrtCreateStream(&stream));
+  ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes));
+  ACL_CHECK(aclrtMallocHost((void **)(&sumHost), sumBytes));
+  ACL_CHECK(aclrtMallocHost((void **)(&denseHost), denseBytes));
+  ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST));
+  ACL_CHECK(aclrtMalloc((void **)&sumDevice, sumBytes, ACL_MEM_MALLOC_HUGE_FIRST));
+  ACL_CHECK(aclrtMalloc((void **)&denseDevice, denseBytes, ACL_MEM_MALLOC_HUGE_FIRST));
+
+  ReadFile("./v1.bin", srcBytes, srcHost, srcBytes);
+  ReadFile("./v2.bin", sumBytes, sumHost, sumBytes);
+  ReadFile("./v3.bin", denseBytes, denseHost, denseBytes);
+  ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE));
+  ACL_CHECK(aclrtMemcpy(sumDevice, sumBytes, sumHost, sumBytes, ACL_MEMCPY_HOST_TO_DEVICE));
+  ACL_CHECK(aclrtMemcpy(denseDevice, denseBytes, denseHost, denseBytes, ACL_MEMCPY_HOST_TO_DEVICE));
+  LaunchVmi_group_broadcast_multi_consumer_kernel(srcDevice, sumDevice,
+                                                  denseDevice, stream);
+  ACL_CHECK(aclrtSynchronizeStream(stream));
+  ACL_CHECK(aclrtMemcpy(sumHost, sumBytes, sumDevice, sumBytes, ACL_MEMCPY_DEVICE_TO_HOST));
+  ACL_CHECK(aclrtMemcpy(denseHost, denseBytes, denseDevice, denseBytes, ACL_MEMCPY_DEVICE_TO_HOST));
+  WriteFile("./v2.bin", sumHost, sumBytes);
+  WriteFile("./v3.bin", denseHost, denseBytes);
+
+cleanup: // reached on success and from every failed ACL_CHECK
+  aclrtFree(srcDevice);
+  aclrtFree(sumDevice);
+  aclrtFree(denseDevice);
+  aclrtFreeHost(srcHost);
+  aclrtFreeHost(sumHost);
+  aclrtFreeHost(denseHost);
+  if (stream)
+    aclrtDestroyStream(stream);
+  if (deviceSet)
+    aclrtResetDevice(deviceId);
+  if (aclInited)
+    aclFinalize();
+  return rc;
+}
diff --git a/test/vpto/cases/vmi/group-broadcast-multi-consumer/ptoas.flags b/test/vpto/cases/vmi/group-broadcast-multi-consumer/ptoas.flags
new file mode 100644
index 0000000000..a79aede1ca
--- /dev/null
+++ b/test/vpto/cases/vmi/group-broadcast-multi-consumer/ptoas.flags
@@ -0,0 +1 @@
+--pto-arch a5 --pto-backend=vpto --enable-vmi
diff --git a/test/vpto/cases/vmi/group-load-s16-stride-store/compare.py b/test/vpto/cases/vmi/group-load-s16-stride-store/compare.py
new file mode 100644
index 0000000000..28299087e5
--- /dev/null
+++ b/test/vpto/cases/vmi/group-load-s16-stride-store/compare.py
@@ -0,0 +1,27 @@
+#!/usr/bin/env python3
+# Copyright (c) 2026 Huawei Technologies Co., Ltd.
+# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+# CANN Open Software License Agreement Version 2.0 (the "License").
+# Please refer to the License for details. You may not use this file except in compliance with the License.
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+# See LICENSE in the root of the software repository for the full text of the License.
+
+import sys
+
+import numpy as np
+
+
+def main() -> None:
+    """Exit 2 unless v2.bin matches golden_v2.bin (float32, atol/rtol 1e-4); a length mismatch reports idx=-1."""
+    golden, output = (np.fromfile(name, dtype=np.float32) for name in ("golden_v2.bin", "v2.bin"))
+    if golden.shape != output.shape or not np.allclose(golden, output, atol=1e-4, rtol=1e-4):
+        diff = np.nonzero(~np.isclose(golden, output, atol=1e-4, rtol=1e-4))[0] if golden.shape == output.shape else ()
+        idx = int(diff[0]) if len(diff) else -1
+        print(f"[ERROR] compare failed idx={idx} golden={golden[idx] if idx >= 0 else 'n/a'} output={output[idx] if idx >= 0 else 'n/a'}")
+        sys.exit(2)
+    print("[INFO] compare passed")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/test/vpto/cases/vmi/group-load-s16-stride-store/golden.py b/test/vpto/cases/vmi/group-load-s16-stride-store/golden.py
new file mode 100644
index 0000000000..5c25033808
--- /dev/null
+++ b/test/vpto/cases/vmi/group-load-s16-stride-store/golden.py
@@ -0,0 +1,48 @@
+#!/usr/bin/env python3
+# Copyright (c) 2026 Huawei Technologies Co., Ltd.
+# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+# CANN Open Software License Agreement Version 2.0 (the "License").
+# Please refer to the License for details. You may not use this file except in compliance with the License.
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+# See LICENSE in the root of the software repository for the full text of the License.
+
+import argparse
+from pathlib import Path
+
+import numpy as np
+
+ROWS = 8
+GROUP_SIZE = 16
+ROW_STRIDE = 24
+INPUT_ELEMS = ROWS * ROW_STRIDE
+SENTINEL = np.float32(-777.0)
+
+
+def generate(output_dir: Path) -> None:
+    """Write v1.bin (strided input rows), v2.bin (sentinel-filled output) and golden_v2.bin (per-row sums)."""
+    src = np.full(INPUT_ELEMS, np.float32(-9.0), dtype=np.float32)  # gaps between rows keep the -9.0 fill
+    dst = np.full(ROWS, SENTINEL, dtype=np.float32)
+    golden = np.empty(ROWS, dtype=np.float32)
+
+    ramp = np.linspace(-0.5, 0.25, GROUP_SIZE, dtype=np.float32)
+    for row, offset in enumerate(range(0, INPUT_ELEMS, ROW_STRIDE)):
+        values = ramp + np.float32(row) * np.float32(0.125)
+        src[offset : offset + GROUP_SIZE] = values
+        golden[row] = np.sum(values, dtype=np.float32)
+
+    output_dir.mkdir(parents=True, exist_ok=True)
+    outputs = (("v1.bin", src), ("v2.bin", dst), ("golden_v2.bin", golden))
+    for file_name, payload in outputs:
+        payload.tofile(output_dir / file_name)
+
+
+def main() -> None:
+    """Parse --output-dir (default: current directory) and write the vectors there."""
+    cli = argparse.ArgumentParser()
+    cli.add_argument("--output-dir", type=Path, default=Path("."))
+    generate(cli.parse_args().output_dir)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/test/vpto/cases/vmi/group-load-s16-stride-store/kernel.pto b/test/vpto/cases/vmi/group-load-s16-stride-store/kernel.pto
new file mode 100644
index 0000000000..6de55bf7fb
--- /dev/null
+++ b/test/vpto/cases/vmi/group-load-s16-stride-store/kernel.pto
@@ -0,0 +1,51 @@
+// Copyright (c) 2026 Huawei Technologies Co., Ltd.
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+// CANN Open Software License Agreement Version 2.0 (the "License").
+// Please refer to the License for details. You may not use this file except in compliance with the License.
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+// See LICENSE in the root of the software repository for the full text of the License.
+
+module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} {
+  func.func @vmi_group_load_s16_stride_store_kernel(%src_gm: !pto.ptr,
+                                                    %dst_gm: !pto.ptr) attributes {pto.kernel} {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c24 = arith.constant 24 : index
+    %c128 = arith.constant 128 : index
+    %c0_i64 = arith.constant 0 : i64
+    %c1_i64 = arith.constant 1 : i64
+    %c32_i64 = arith.constant 32 : i64
+    %c768_i64 = arith.constant 768 : i64
+    %c4096_i64 = arith.constant 4096 : i64
+    // UB buffers cast from addresses 0 (input) and 4096 (output).
+    %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr
+    %ub_dst = pto.castptr %c4096_i64 : i64 -> !pto.ptr
+    // Bring the input from %src_gm into the UB source buffer.
+    pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c768_i64
+      nburst(%c1_i64, %c768_i64, %c768_i64)
+      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
+    // MTE2 -> V flag pair; presumably orders the vector work after the inbound copy.
+    pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]
+    pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]
+    // Vector body: group_load 8 groups (%c24 equals ROW_STRIDE in this case's golden.py), reduce each, group_store the sums.
+    pto.vecscope {
+      %mask = pto.vmi.create_mask %c128 : index -> !pto.vmi.mask<128xpred>
+      %x = pto.vmi.group_load %ub_src[%c0], %c24 {num_groups = 8}
+        : !pto.ptr -> !pto.vmi.vreg<128xf32>
+      %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc}
+        : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred>
+        -> !pto.vmi.vreg<8xf32>
+      pto.vmi.group_store %sum, %ub_dst[%c0], %c1 {num_groups = 8}
+        : !pto.vmi.vreg<8xf32>, !pto.ptr
+    }
+    // V -> MTE3 flag pair, then copy the result from the UB destination buffer to %dst_gm.
+    pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
+    pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
+    pto.mte_ub_gm %ub_dst, %dst_gm, %c32_i64
+      nburst(%c1_i64, %c32_i64, %c32_i64)
+      : !pto.ptr, !pto.ptr, i64, i64, i64, i64
+    pto.barrier #pto.pipe
+    return
+  }
+}
diff --git a/test/vpto/cases/vmi/group-load-s16-stride-store/launch.cpp
b/test/vpto/cases/vmi/group-load-s16-stride-store/launch.cpp
new file mode 100644
index 0000000000..ef8fa0d082
--- /dev/null
+++ b/test/vpto/cases/vmi/group-load-s16-stride-store/launch.cpp
@@ -0,0 +1,32 @@
+// Copyright (c) 2026 Huawei Technologies Co., Ltd.
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+// CANN Open Software License Agreement Version 2.0 (the "License").
+// Please refer to the License for details. You may not use this file except in compliance with the License.
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+// See LICENSE in the root of the software repository for the full text of the License.
+
+#ifndef __VEC_SCOPE__
+#define __VEC_SCOPE__
+#endif
+#include <stdint.h> // fixed-width integer types (uint16_t) used below
+#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP)
+struct MrgSortExecutedNumList {
+  uint16_t mrgSortList0;
+  uint16_t mrgSortList1;
+  uint16_t mrgSortList2;
+  uint16_t mrgSortList3;
+};
+#endif
+#ifndef __CPU_SIM
+#include "acl/acl.h"
+#endif
+
+extern "C" __global__ [aicore] void
+vmi_group_load_s16_stride_store_kernel(__gm__ float *src, __gm__ float *dst);
+// Host-side launcher: enqueues the kernel on `stream` with the <<<1, nullptr, stream>>> configuration.
+void LaunchVmi_group_load_s16_stride_store_kernel(float *src, float *dst,
+                                                  void *stream) {
+  vmi_group_load_s16_stride_store_kernel<<<1, nullptr, stream>>>(
+      (__gm__ float *)src, (__gm__ float *)dst);
+}
diff --git a/test/vpto/cases/vmi/group-load-s16-stride-store/main.cpp b/test/vpto/cases/vmi/group-load-s16-stride-store/main.cpp
new file mode 100644
index 0000000000..414e34200e
--- /dev/null
+++ b/test/vpto/cases/vmi/group-load-s16-stride-store/main.cpp
@@ -0,0 +1,80 @@
+// Copyright (c) 2026 Huawei Technologies Co., Ltd.
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+// CANN Open Software License Agreement Version 2.0 (the "License").
+// Please refer to the License for details. You may not use this file except in compliance with the License.
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+// See LICENSE in the root of the software repository for the full text of the License.
+
+#include "acl/acl.h"
+#include "test_common.h"
+#include <cstdio>
+#include <cstdlib>
+
+using namespace PtoTestCommon;
+// ACL_CHECK(expr): on failure, log the call, set rc = 1 and jump to the enclosing function's `cleanup` label.
+#define ACL_CHECK(expr) \
+  do { \
+    const aclError _ret = (expr); \
+    if (_ret != ACL_SUCCESS) { \
+      std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \
+                   (int)_ret, __FILE__, __LINE__); \
+      rc = 1; \
+      goto cleanup; \
+    } \
+  } while (0)
+
+void LaunchVmi_group_load_s16_stride_store_kernel(float *src, float *dst,
+                                                  void *stream);
+// Runs the case once: loads v1.bin / v2.bin, launches the kernel and writes the device result back to v2.bin.
+int main() {
+  constexpr size_t kRows = 8;
+  constexpr size_t kRowStride = 24;
+  constexpr size_t kInputElems = kRows * kRowStride;
+  constexpr size_t kOutputElems = kRows;
+  size_t srcBytes = kInputElems * sizeof(float);
+  size_t dstBytes = kOutputElems * sizeof(float);
+  float *srcHost = nullptr;
+  float *srcDevice = nullptr;
+  float *dstHost = nullptr;
+  float *dstDevice = nullptr;
+  int rc = 0; // process exit code; ACL_CHECK sets it to 1 on failure
+  bool aclInited = false; // set once aclInit succeeded, so cleanup only finalizes an initialized runtime
+  bool deviceSet = false; // set once aclrtSetDevice succeeded, so cleanup only resets a selected device
+  int deviceId = 0;
+  aclrtStream stream = nullptr;
+
+  ACL_CHECK(aclInit(nullptr));
+  aclInited = true;
+  if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) // optional device override from the environment
+    deviceId = std::atoi(envDevice);
+  ACL_CHECK(aclrtSetDevice(deviceId));
+  deviceSet = true;
+  ACL_CHECK(aclrtCreateStream(&stream));
+  ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes));
+  ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes));
+  ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST));
+  ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST));
+
+  ReadFile("./v1.bin", srcBytes, srcHost, srcBytes);
+  ReadFile("./v2.bin", dstBytes, dstHost, dstBytes);
+  ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE));
+  ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE));
+  LaunchVmi_group_load_s16_stride_store_kernel(srcDevice, dstDevice, stream);
+  ACL_CHECK(aclrtSynchronizeStream(stream));
+  ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST));
+  WriteFile("./v2.bin", dstHost, dstBytes);
+
+cleanup: // reached on success and from every failed ACL_CHECK
+  aclrtFree(srcDevice);
+  aclrtFree(dstDevice);
+  aclrtFreeHost(srcHost);
+  aclrtFreeHost(dstHost);
+  if (stream)
+    aclrtDestroyStream(stream);
+  if (deviceSet)
+    aclrtResetDevice(deviceId);
+  if (aclInited)
+    aclFinalize();
+  return rc;
+}
diff --git a/test/vpto/cases/vmi/group-load-s16-stride-store/ptoas.flags b/test/vpto/cases/vmi/group-load-s16-stride-store/ptoas.flags
new file mode 100644
index 0000000000..a79aede1ca
--- /dev/null
+++ b/test/vpto/cases/vmi/group-load-s16-stride-store/ptoas.flags
@@ -0,0 +1 @@
+--pto-arch a5 --pto-backend=vpto --enable-vmi
diff --git a/test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce/compare.py b/test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce/compare.py
new file mode 100644
index 0000000000..28299087e5
--- /dev/null
+++ b/test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce/compare.py
@@ -0,0 +1,27 @@
+#!/usr/bin/env python3
+# Copyright (c) 2026 Huawei Technologies Co., Ltd.
+# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+# CANN Open Software License Agreement Version 2.0 (the "License").
+# Please refer to the License for details. You may not use this file except in compliance with the License.
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+# See LICENSE in the root of the software repository for the full text of the License.
+
+import sys
+
+import numpy as np
+
+
+def main() -> None:
+    """Exit 2 unless v2.bin matches golden_v2.bin (float32, atol/rtol 1e-4); a length mismatch reports idx=-1."""
+    golden, output = (np.fromfile(name, dtype=np.float32) for name in ("golden_v2.bin", "v2.bin"))
+    if golden.shape != output.shape or not np.allclose(golden, output, atol=1e-4, rtol=1e-4):
+        diff = np.nonzero(~np.isclose(golden, output, atol=1e-4, rtol=1e-4))[0] if golden.shape == output.shape else ()
+        idx = int(diff[0]) if len(diff) else -1
+        print(f"[ERROR] compare failed idx={idx} golden={golden[idx] if idx >= 0 else 'n/a'} output={output[idx] if idx >= 0 else 'n/a'}")
+        sys.exit(2)
+    print("[INFO] compare passed")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce/golden.py b/test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce/golden.py
new file mode 100644
index 0000000000..8cb473640d
--- /dev/null
+++ b/test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce/golden.py
@@ -0,0 +1,49 @@
+#!/usr/bin/env python3
+# Copyright (c) 2026 Huawei Technologies Co., Ltd.
+# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+# CANN Open Software License Agreement Version 2.0 (the "License").
+# Please refer to the License for details. You may not use this file except in compliance with the License.
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+# See LICENSE in the root of the software repository for the full text of the License.
import argparse
from pathlib import Path

import numpy as np

# Layout of the generated input: ROWS groups of GROUP_SIZE float32 values,
# with consecutive groups starting ROW_STRIDE elements apart.
ROWS = 8
GROUP_SIZE = 32
ROW_STRIDE = 40
INPUT_ELEMS = ROWS * ROW_STRIDE
# Value pre-loaded into the destination buffer before the kernel runs.
SENTINEL = np.float32(-777.0)


def generate(output_dir: Path) -> None:
    """Write ``v1.bin``, ``v2.bin`` and ``golden_v2.bin`` into ``output_dir``.

    ``v1.bin`` holds the zero-initialised strided input, ``v2.bin`` the
    sentinel-filled destination, and ``golden_v2.bin`` one float32 per row:
    the sum of the row's values multiplied by the row's own sum. The directory
    is created if it does not exist.
    """
    pattern = np.linspace(-1.0, 1.0, GROUP_SIZE, dtype=np.float32)
    row_step = np.float32(0.125)

    src = np.zeros(INPUT_ELEMS, dtype=np.float32)
    # Row-major view over the same storage; only the first GROUP_SIZE columns
    # of each row are written, the rest stays zero.
    strided_rows = src.reshape(ROWS, ROW_STRIDE)
    golden = np.empty(ROWS, dtype=np.float32)

    for index in range(ROWS):
        group = pattern + np.float32(index) * row_step
        strided_rows[index, :GROUP_SIZE] = group
        group_sum = np.sum(group, dtype=np.float32)
        golden[index] = np.sum(group * group_sum, dtype=np.float32)

    dst = np.full(ROWS, SENTINEL, dtype=np.float32)

    output_dir.mkdir(parents=True, exist_ok=True)
    for file_name, payload in (
        ("v1.bin", src),
        ("v2.bin", dst),
        ("golden_v2.bin", golden),
    ):
        payload.tofile(output_dir / file_name)


def main() -> None:
    """Parse ``--output-dir`` (default: current directory) and generate the data."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--output-dir", type=Path, default=Path("."))
    generate(cli.parse_args().output_dir)


if __name__ == "__main__":
    main()
+// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_group_load_s32_stride_broadcast_reduce_kernel(%src_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c40 = arith.constant 40 : index + %c256 = arith.constant 256 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c1280_i64 = arith.constant 1280 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c1280_i64 + nburst(%c1_i64, %c1280_i64, %c1280_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %x = pto.vmi.group_load %ub_src[%c0], %c40 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<256xf32> + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<8xf32> + %broadcast = pto.vmi.group_broadcast %sum {num_groups = 8} + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<256xf32> + %scaled = pto.vmi.mulf %x, %broadcast + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %scaled_sum = pto.vmi.group_reduce_addf %scaled, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<8xf32> + pto.vmi.group_store %scaled_sum, %ub_dst[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xf32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c32_i64 + nburst(%c1_i64, 
%c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce/launch.cpp b/test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce/launch.cpp new file mode 100644 index 0000000000..d9218a9389 --- /dev/null +++ b/test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce/launch.cpp @@ -0,0 +1,34 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_group_load_s32_stride_broadcast_reduce_kernel(__gm__ float *src, + __gm__ float *dst); + +void LaunchVmi_group_load_s32_stride_broadcast_reduce_kernel(float *src, + float *dst, + void *stream) { + vmi_group_load_s32_stride_broadcast_reduce_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ float *)dst); +} diff --git a/test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce/main.cpp b/test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce/main.cpp new file mode 100644 index 0000000000..b994c2192f --- /dev/null +++ b/test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce/main.cpp @@ -0,0 +1,82 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#include "acl/acl.h" +#include "test_common.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_load_s32_stride_broadcast_reduce_kernel(float *src, + float *dst, + void *stream); + +int main() { + constexpr size_t kRows = 8; + constexpr size_t kRowStride = 40; + constexpr size_t kInputElems = kRows * kRowStride; + constexpr size_t kOutputElems = kRows; + size_t srcBytes = kInputElems * sizeof(float); + size_t dstBytes = kOutputElems * sizeof(float); + float *srcHost = nullptr; + float *srcDevice = nullptr; + float *dstHost = nullptr; + float *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_load_s32_stride_broadcast_reduce_kernel(srcDevice, dstDevice, + stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, 
dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce/ptoas.flags b/test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/group-load-s32-stride-store/compare.py b/test/vpto/cases/vmi/group-load-s32-stride-store/compare.py new file mode 100644 index 0000000000..28299087e5 --- /dev/null +++ b/test/vpto/cases/vmi/group-load-s32-stride-store/compare.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
import sys

import numpy as np


def main() -> None:
    """Compare the produced ``v2.bin`` against ``golden_v2.bin``.

    Both files are read from the current working directory as flat float32
    arrays. Prints ``[INFO] compare passed`` when the arrays have the same
    length and every element agrees within ``atol=1e-4`` / ``rtol=1e-4``.
    Otherwise prints an ``[ERROR]`` line and exits with status 2.
    """
    golden = np.fromfile("golden_v2.bin", dtype=np.float32)
    output = np.fromfile("v2.bin", dtype=np.float32)

    # Report a length mismatch on its own: np.isclose cannot broadcast arrays
    # of different lengths and would raise ValueError instead of failing the
    # comparison cleanly.
    if golden.shape != output.shape:
        print(f"[ERROR] compare failed: shape golden={golden.shape} output={output.shape}")
        sys.exit(2)

    # Evaluate the tolerance mask once and reuse it for both the pass/fail
    # decision and the location of the first offending element.
    close = np.isclose(golden, output, atol=1e-4, rtol=1e-4)
    if not close.all():
        idx = int(np.flatnonzero(~close)[0])
        print(f"[ERROR] compare failed idx={idx} golden={golden[idx]} output={output[idx]}")
        sys.exit(2)
    print("[INFO] compare passed")


if __name__ == "__main__":
    main()
import argparse
from pathlib import Path

import numpy as np

# Layout of the generated input: ROWS groups of GROUP_SIZE float32 values,
# with consecutive groups starting ROW_STRIDE elements apart.
ROWS = 8
GROUP_SIZE = 32
ROW_STRIDE = 40
INPUT_ELEMS = ROWS * ROW_STRIDE
# Value pre-loaded into the destination buffer before the kernel runs.
SENTINEL = np.float32(-777.0)


def generate(output_dir: Path) -> None:
    """Write ``v1.bin``, ``v2.bin`` and ``golden_v2.bin`` into ``output_dir``.

    ``v1.bin`` holds the strided input, whose gaps between groups keep the
    fill value -9.0. ``v2.bin`` is the sentinel-filled destination and
    ``golden_v2.bin`` holds the float32 sum of each row's GROUP_SIZE values.
    The directory is created if it does not exist.
    """
    pattern = np.linspace(-0.75, 0.5, GROUP_SIZE, dtype=np.float32)
    row_step = np.float32(0.0625)

    src = np.full(INPUT_ELEMS, np.float32(-9.0), dtype=np.float32)
    # Row-major view over the same storage; columns past GROUP_SIZE keep -9.0.
    strided_rows = src.reshape(ROWS, ROW_STRIDE)
    golden = np.empty(ROWS, dtype=np.float32)

    for index in range(ROWS):
        group = pattern + np.float32(index) * row_step
        strided_rows[index, :GROUP_SIZE] = group
        golden[index] = np.sum(group, dtype=np.float32)

    dst = np.full(ROWS, SENTINEL, dtype=np.float32)

    output_dir.mkdir(parents=True, exist_ok=True)
    for file_name, payload in (
        ("v1.bin", src),
        ("v2.bin", dst),
        ("golden_v2.bin", golden),
    ):
        payload.tofile(output_dir / file_name)


def main() -> None:
    """Parse ``--output-dir`` (default: current directory) and generate the data."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--output-dir", type=Path, default=Path("."))
    generate(cli.parse_args().output_dir)


if __name__ == "__main__":
    main()
+ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_group_load_s32_stride_store_kernel(%src_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c40 = arith.constant 40 : index + %c256 = arith.constant 256 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c1280_i64 = arith.constant 1280 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c1280_i64 + nburst(%c1_i64, %c1280_i64, %c1280_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %x = pto.vmi.group_load %ub_src[%c0], %c40 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<256xf32> + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<8xf32> + pto.vmi.group_store %sum, %ub_dst[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xf32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/group-load-s32-stride-store/launch.cpp b/test/vpto/cases/vmi/group-load-s32-stride-store/launch.cpp new file mode 100644 index 0000000000..9443a9cfb3 --- /dev/null +++ b/test/vpto/cases/vmi/group-load-s32-stride-store/launch.cpp @@ -0,0 +1,32 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_group_load_s32_stride_store_kernel(__gm__ float *src, __gm__ float *dst); + +void LaunchVmi_group_load_s32_stride_store_kernel(float *src, float *dst, + void *stream) { + vmi_group_load_s32_stride_store_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ float *)dst); +} diff --git a/test/vpto/cases/vmi/group-load-s32-stride-store/main.cpp b/test/vpto/cases/vmi/group-load-s32-stride-store/main.cpp new file mode 100644 index 0000000000..b67ef78981 --- /dev/null +++ b/test/vpto/cases/vmi/group-load-s32-stride-store/main.cpp @@ -0,0 +1,80 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_load_s32_stride_store_kernel(float *src, float *dst, + void *stream); + +int main() { + constexpr size_t kRows = 8; + constexpr size_t kRowStride = 40; + constexpr size_t kInputElems = kRows * kRowStride; + constexpr size_t kOutputElems = kRows; + size_t srcBytes = kInputElems * sizeof(float); + size_t dstBytes = kOutputElems * sizeof(float); + float *srcHost = nullptr; + float *srcDevice = nullptr; + float *dstHost = nullptr; + float *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, 
ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_load_s32_stride_store_kernel(srcDevice, dstDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/group-load-s32-stride-store/ptoas.flags b/test/vpto/cases/vmi/group-load-s32-stride-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/group-load-s32-stride-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/group-reduce-basic-store/compare.py b/test/vpto/cases/vmi/group-reduce-basic-store/compare.py new file mode 100644 index 0000000000..dc3a89703c --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-basic-store/compare.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
import sys

import numpy as np


def check(output_name: str, golden_name: str) -> None:
    """Compare one output file against its golden reference.

    Both files are read as flat float32 arrays. Returns normally when they
    have the same length and all elements agree within ``atol=1e-4`` /
    ``rtol=1e-4``. Otherwise prints an ``[ERROR]`` line naming
    ``output_name`` and exits with status 2.
    """
    expected = np.fromfile(golden_name, dtype=np.float32)
    actual = np.fromfile(output_name, dtype=np.float32)

    # A length mismatch is reported before any element-wise comparison.
    if expected.shape != actual.shape:
        print(f"[ERROR] compare failed {output_name}: shape golden={expected.shape} output={actual.shape}")
        sys.exit(2)

    mismatched = np.flatnonzero(~np.isclose(expected, actual, atol=1e-4, rtol=1e-4))
    if mismatched.size == 0:
        return

    # Only the first out-of-tolerance element is reported.
    first = int(mismatched[0])
    print(
        f"[ERROR] compare failed {output_name} idx={first} "
        f"golden={expected[first]} "
        f"output={actual[first]}"
    )
    sys.exit(2)


def main() -> None:
    """Check the three kernel outputs in order, then report success."""
    cases = (
        ("v4.bin", "golden_v4.bin"),
        ("v5.bin", "golden_v5.bin"),
        ("v6.bin", "golden_v6.bin"),
    )
    for produced, reference in cases:
        check(produced, reference)
    print("[INFO] compare passed")


if __name__ == "__main__":
    main()
import argparse
from pathlib import Path

import numpy as np

# Number of rows in every generated matrix, and therefore of reduced outputs.
ROWS = 8
# Value pre-loaded into each destination buffer before the kernel runs.
SENTINEL = np.float32(-777.0)


def fill_matrix(cols: int, base_start: float, row_step: float) -> np.ndarray:
    """Build a ``(ROWS, cols)`` float32 matrix.

    Row ``r`` is ``linspace(base_start, base_start + 1.0, cols)`` shifted by
    ``r * row_step``, with all arithmetic carried out in float32.
    """
    ramp = np.linspace(base_start, base_start + 1.0, cols, dtype=np.float32)
    step = np.float32(row_step)
    return np.stack([ramp + np.float32(index) * step for index in range(ROWS)])


def write_case(output_dir: Path, matrix: np.ndarray, src_name: str, dst_name: str, golden_name: str) -> None:
    """Write the input, sentinel destination and golden row sums for one case.

    ``src_name`` receives ``matrix`` flattened in row-major order,
    ``dst_name`` receives ROWS sentinel values, and ``golden_name`` receives
    the float32 sum of each row of ``matrix``.
    """
    row_sums = matrix.sum(axis=1, dtype=np.float32).astype(np.float32)
    placeholder = np.full(ROWS, SENTINEL, dtype=np.float32)

    matrix.ravel().tofile(output_dir / src_name)
    placeholder.tofile(output_dir / dst_name)
    row_sums.tofile(output_dir / golden_name)


def generate(output_dir: Path) -> None:
    """Create ``output_dir`` if needed and emit the 8-, 16- and 32-column cases."""
    output_dir.mkdir(parents=True, exist_ok=True)
    cases = (
        ((8, -0.5, 0.03125), ("v1.bin", "v4.bin", "golden_v4.bin")),
        ((16, -0.75, 0.046875), ("v2.bin", "v5.bin", "golden_v5.bin")),
        ((32, -0.875, 0.0625), ("v3.bin", "v6.bin", "golden_v6.bin")),
    )
    for matrix_args, file_names in cases:
        write_case(output_dir, fill_matrix(*matrix_args), *file_names)


def main() -> None:
    """Parse ``--output-dir`` (default: current directory) and generate the data."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--output-dir", type=Path, default=Path("."))
    generate(cli.parse_args().output_dir)


if __name__ == "__main__":
    main()
You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_group_reduce_basic_store_kernel(%src8_gm: !pto.ptr, + %src16_gm: !pto.ptr, + %src32_gm: !pto.ptr, + %dst8_gm: !pto.ptr, + %dst16_gm: !pto.ptr, + %dst32_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c12288_i64 = arith.constant 12288 : i64 + + %ub_src8 = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_src16 = pto.castptr %c1024_i64 : i64 -> !pto.ptr + %ub_src32 = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_dst8 = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_dst16 = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_dst32 = pto.castptr %c12288_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src8_gm, %ub_src8, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %src16_gm, %ub_src16, %c0_i64, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %src32_gm, %ub_src32, %c0_i64, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + 
pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask8 = pto.vmi.create_mask %c64 : index -> !pto.vmi.mask<64xpred> + %x8 = pto.vmi.load %ub_src8[%c0] : !pto.ptr -> !pto.vmi.vreg<64xf32> + %sum8 = pto.vmi.group_reduce_addf %x8, %mask8 {num_groups = 8, reassoc} + : !pto.vmi.vreg<64xf32>, !pto.vmi.mask<64xpred> + -> !pto.vmi.vreg<8xf32> + pto.vmi.group_store %sum8, %ub_dst8[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xf32>, !pto.ptr + + %mask16 = pto.vmi.create_mask %c128 : index -> !pto.vmi.mask<128xpred> + %x16 = pto.vmi.load %ub_src16[%c0] : !pto.ptr -> !pto.vmi.vreg<128xf32> + %sum16 = pto.vmi.group_reduce_addf %x16, %mask16 {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<8xf32> + pto.vmi.group_store %sum16, %ub_dst16[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xf32>, !pto.ptr + + %mask32 = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %x32 = pto.vmi.load %ub_src32[%c0] : !pto.ptr -> !pto.vmi.vreg<256xf32> + %sum32 = pto.vmi.group_reduce_addf %x32, %mask32 {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<8xf32> + pto.vmi.group_store %sum32, %ub_dst32[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xf32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst8, %dst8_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_dst16, %dst16_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_dst32, %dst32_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/group-reduce-basic-store/launch.cpp b/test/vpto/cases/vmi/group-reduce-basic-store/launch.cpp new file mode 
100644 index 0000000000..a7304f9a15 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-basic-store/launch.cpp @@ -0,0 +1,40 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_group_reduce_basic_store_kernel(__gm__ float *src8, + __gm__ float *src16, + __gm__ float *src32, + __gm__ float *dst8, + __gm__ float *dst16, + __gm__ float *dst32); + +void LaunchVmi_group_reduce_basic_store_kernel(float *src8, float *src16, + float *src32, float *dst8, + float *dst16, float *dst32, + void *stream) { + vmi_group_reduce_basic_store_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src8, (__gm__ float *)src16, (__gm__ float *)src32, + (__gm__ float *)dst8, (__gm__ float *)dst16, (__gm__ float *)dst32); +} diff --git a/test/vpto/cases/vmi/group-reduce-basic-store/main.cpp b/test/vpto/cases/vmi/group-reduce-basic-store/main.cpp new file mode 100644 index 0000000000..4ddb71365b --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-basic-store/main.cpp @@ -0,0 +1,123 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_reduce_basic_store_kernel(float *src8, float *src16, + float *src32, float *dst8, + float *dst16, float *dst32, + void *stream); + +int main() { + constexpr size_t kRows = 8; + constexpr size_t kSrc8Elems = kRows * 8; + constexpr size_t kSrc16Elems = kRows * 16; + constexpr size_t kSrc32Elems = kRows * 32; + constexpr size_t kOutputElems = kRows; + size_t src8Bytes = kSrc8Elems * sizeof(float); + size_t src16Bytes = kSrc16Elems * sizeof(float); + size_t src32Bytes = kSrc32Elems * sizeof(float); + size_t dstBytes = kOutputElems * sizeof(float); + float *src8Host = nullptr; + float *src16Host = nullptr; + float *src32Host = nullptr; + float *dst8Host = nullptr; + float *dst16Host = nullptr; + float *dst32Host = nullptr; + float *src8Device = nullptr; + float *src16Device = nullptr; + float *src32Device = nullptr; + float *dst8Device = nullptr; + float *dst16Device = nullptr; + float *dst32Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream 
stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&src8Host), src8Bytes)); + ACL_CHECK(aclrtMallocHost((void **)(&src16Host), src16Bytes)); + ACL_CHECK(aclrtMallocHost((void **)(&src32Host), src32Bytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dst8Host), dstBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dst16Host), dstBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dst32Host), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&src8Device, src8Bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&src16Device, src16Bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&src32Device, src32Bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dst8Device, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dst16Device, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dst32Device, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", src8Bytes, src8Host, src8Bytes); + ReadFile("./v2.bin", src16Bytes, src16Host, src16Bytes); + ReadFile("./v3.bin", src32Bytes, src32Host, src32Bytes); + ReadFile("./v4.bin", dstBytes, dst8Host, dstBytes); + ReadFile("./v5.bin", dstBytes, dst16Host, dstBytes); + ReadFile("./v6.bin", dstBytes, dst32Host, dstBytes); + ACL_CHECK(aclrtMemcpy(src8Device, src8Bytes, src8Host, src8Bytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(src16Device, src16Bytes, src16Host, src16Bytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(src32Device, src32Bytes, src32Host, src32Bytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dst8Device, dstBytes, dst8Host, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dst16Device, dstBytes, dst16Host, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + 
ACL_CHECK(aclrtMemcpy(dst32Device, dstBytes, dst32Host, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_reduce_basic_store_kernel( + src8Device, src16Device, src32Device, dst8Device, dst16Device, + dst32Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dst8Host, dstBytes, dst8Device, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(dst16Host, dstBytes, dst16Device, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(dst32Host, dstBytes, dst32Device, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v4.bin", dst8Host, dstBytes); + WriteFile("./v5.bin", dst16Host, dstBytes); + WriteFile("./v6.bin", dst32Host, dstBytes); + +cleanup: + aclrtFree(src8Device); + aclrtFree(src16Device); + aclrtFree(src32Device); + aclrtFree(dst8Device); + aclrtFree(dst16Device); + aclrtFree(dst32Device); + aclrtFreeHost(src8Host); + aclrtFreeHost(src16Host); + aclrtFreeHost(src32Host); + aclrtFreeHost(dst8Host); + aclrtFreeHost(dst16Host); + aclrtFreeHost(dst32Host); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/group-reduce-basic-store/ptoas.flags b/test/vpto/cases/vmi/group-reduce-basic-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-basic-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/group-reduce-f16-addf-store/compare.py b/test/vpto/cases/vmi/group-reduce-f16-addf-store/compare.py new file mode 100644 index 0000000000..fbba5d605b --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-f16-addf-store/compare.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+# This program is free software, you can redistribute it and/or modify it under +# the terms and conditions of CANN Open Software License Agreement Version 2.0 +# (the "License"). Please refer to the License for details. You may not use +# this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +# AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +# FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +# for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v2.bin", dtype=np.float16) + output = np.fromfile("v2.bin", dtype=np.float16) + if golden.shape == output.shape and np.array_equal(golden, output): + print("[INFO] compare passed") + return + if golden.shape != output.shape: + print(f"[ERROR] shape mismatch golden={golden.shape} output={output.shape}") + sys.exit(2) + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed idx={idx} " + f"golden={golden[idx] if idx >= 0 else 'n/a'} " + f"output={output[idx] if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-f16-addf-store/golden.py b/test/vpto/cases/vmi/group-reduce-f16-addf-store/golden.py new file mode 100644 index 0000000000..beed48b5da --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-f16-addf-store/golden.py @@ -0,0 +1,43 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under +# the terms and conditions of CANN Open Software License Agreement Version 2.0 +# (the "License"). Please refer to the License for details. You may not use +# this file except in compliance with the License. 
THIS SOFTWARE IS PROVIDED ON +# AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +# FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +# for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 8 +COLS = 16 + + +def generate(output_dir: Path) -> None: + src = np.empty((ROWS, COLS), dtype=np.float16) + base = np.array([-3, -2, -1, 0, 1, 2, 3, 4], dtype=np.float16) + for row in range(ROWS): + src[row, :] = np.tile(np.roll(base, row), 2) + dst = np.full(ROWS, np.float16(-17), dtype=np.float16) + golden = np.sum(src, axis=1, dtype=np.float16).astype(np.float16) + + output_dir.mkdir(parents=True, exist_ok=True) + src.reshape(-1).tofile(output_dir / "v1.bin") + dst.tofile(output_dir / "v2.bin") + golden.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-f16-addf-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-f16-addf-store/kernel.pto new file mode 100644 index 0000000000..f586b23278 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-f16-addf-store/kernel.pto @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. 
THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_group_reduce_f16_addf_store_kernel(%src_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c16_i64 = arith.constant 16 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c512_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.vmi.create_mask %c128 : index -> !pto.vmi.mask<128xpred> + %x = pto.vmi.load %ub_src[%c0] : !pto.ptr -> !pto.vmi.vreg<128xf16> + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf16>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<8xf16> + pto.vmi.group_store %sum, %ub_dst[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xf16>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c16_i64 + nburst(%c1_i64, %c16_i64, %c16_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/group-reduce-f16-addf-store/launch.cpp b/test/vpto/cases/vmi/group-reduce-f16-addf-store/launch.cpp new file mode 100644 index 
0000000000..8cfb1e58b5 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-f16-addf-store/launch.cpp @@ -0,0 +1,34 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_group_reduce_f16_addf_store_kernel(__gm__ half *src, __gm__ half *dst); + +void LaunchVmi_group_reduce_f16_addf_store_kernel(uint16_t *src, uint16_t *dst, + void *stream) { + vmi_group_reduce_f16_addf_store_kernel<<<1, nullptr, stream>>>( + (__gm__ half *)src, (__gm__ half *)dst); +} diff --git a/test/vpto/cases/vmi/group-reduce-f16-addf-store/main.cpp b/test/vpto/cases/vmi/group-reduce-f16-addf-store/main.cpp new file mode 100644 index 0000000000..7a92e1a331 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-f16-addf-store/main.cpp @@ -0,0 +1,86 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. 
You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_reduce_f16_addf_store_kernel(uint16_t *src, uint16_t *dst, + void *stream); + +int main() { + constexpr size_t kInputElems = 128; + constexpr size_t kOutputElems = 8; + size_t srcBytes = kInputElems * sizeof(uint16_t); + size_t dstBytes = kOutputElems * sizeof(uint16_t); + uint16_t *srcHost = nullptr; + uint16_t *dstHost = nullptr; + uint16_t *srcDevice = nullptr; + uint16_t *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, + ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, + ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, 
srcBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_reduce_f16_addf_store_kernel(srcDevice, dstDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/group-reduce-f16-addf-store/ptoas.flags b/test/vpto/cases/vmi/group-reduce-f16-addf-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-f16-addf-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/compare.py b/test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/compare.py new file mode 100644 index 0000000000..5030420250 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/compare.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+ +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v3.bin", dtype=np.float32) + output = np.fromfile("v3.bin", dtype=np.float32) + if golden.shape != output.shape or not np.allclose(golden, output, atol=1e-5, rtol=1e-5): + diff = np.nonzero(~np.isclose(golden, output, atol=1e-5, rtol=1e-5))[0] + idx = int(diff[0]) if diff.size else -1 + print(f"[ERROR] compare failed idx={idx} golden={golden[idx] if idx >= 0 else 'n/a'} output={output[idx] if idx >= 0 else 'n/a'}") + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/golden.py b/test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/golden.py new file mode 100644 index 0000000000..69fbe13344 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/golden.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+ +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 2 +ROW_ELEMS = 256 +ROW_STRIDE = 320 +TOTAL_ELEMS = ROWS * ROW_STRIDE +F16_VALUES = np.array([0.125, 0.25], dtype=np.float16) +VALUES = np.array([0.0, 1.0, -1.0, 0.5, 2.0, -2.0, 4.0, -4.0], dtype=np.float32) +F8E4M3FN_BYTES = np.array([0x00, 0x38, 0xB8, 0x30, 0x40, 0xC0, 0x48, 0xC8], dtype=np.uint8) +SENTINEL = np.float32(-123.25) + + +def generate(output_dir: Path) -> None: + repeats = (ROW_ELEMS + len(VALUES) - 1) // len(VALUES) + row_f8 = np.tile(F8E4M3FN_BYTES, repeats)[:ROW_ELEMS].astype(np.uint8) + row_decoded_f8 = np.tile(VALUES, repeats)[:ROW_ELEMS].astype(np.float32) + + src_f16 = np.zeros(TOTAL_ELEMS, dtype=np.float16) + src_f8 = np.zeros(TOTAL_ELEMS, dtype=np.uint8) + dst = np.full(TOTAL_ELEMS, SENTINEL, dtype=np.float32) + golden = np.full(TOTAL_ELEMS, SENTINEL, dtype=np.float32) + + for row in range(ROWS): + begin = row * ROW_STRIDE + end = begin + ROW_ELEMS + src_f16[begin:end] = F16_VALUES[row] + src_f8[begin:end] = np.roll(row_f8, row) + decoded_f8 = np.roll(row_decoded_f8, row) + reduction = np.sum(src_f16[begin:end].astype(np.float32), dtype=np.float32) + golden[begin:end] = decoded_f8 * reduction + + output_dir.mkdir(parents=True, exist_ok=True) + src_f16.tofile(output_dir / "v1.bin") + src_f8.tofile(output_dir / "v2.bin") + dst.tofile(output_dir / "v3.bin") + golden.astype(np.float32, copy=False).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/kernel.pto new file mode 100644 index 0000000000..34c6a1bf1e --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/kernel.pto @@ -0,0 +1,71 @@ +// Copyright (c) 2026 Huawei 
Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_group_reduce_f16_f8_mul_store_kernel(%src_f16_gm: !pto.ptr, + %src_f8_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c320 = arith.constant 320 : index + %c512 = arith.constant 512 : index + %c0_i64 = arith.constant 0 : i64 + %c2_i64 = arith.constant 2 : i64 + %c256_i64 = arith.constant 256 : i64 + %c320_i64 = arith.constant 320 : i64 + %c512_i64 = arith.constant 512 : i64 + %c640_i64 = arith.constant 640 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c1280_i64 = arith.constant 1280 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_f16 = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_f8_u8 = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_f8 = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_f16_gm, %ub_f16, %c0_i64, %c512_i64 + nburst(%c2_i64, %c640_i64, %c640_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %src_f8_gm, %ub_f8_u8, %c0_i64, %c256_i64 + nburst(%c2_i64, %c320_i64, %c320_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = 
pto.vmi.create_mask %c512 : index -> !pto.vmi.mask<512xpred> + %src_f16 = pto.vmi.group_load %ub_f16[%c0], %c320 {num_groups = 2} + : !pto.ptr -> !pto.vmi.vreg<512xf16> + %src_f16_f32 = pto.vmi.extf %src_f16 + : !pto.vmi.vreg<512xf16> -> !pto.vmi.vreg<512xf32> + %sum = pto.vmi.group_reduce_addf %src_f16_f32, %mask {num_groups = 2, reassoc} + : !pto.vmi.vreg<512xf32>, !pto.vmi.mask<512xpred> + -> !pto.vmi.vreg<2xf32> + %src_f8 = pto.vmi.group_load %ub_f8[%c0], %c320 {num_groups = 2} + : !pto.ptr -> !pto.vmi.vreg<512xf8E4M3FN> + %src_f8_f32 = pto.vmi.extf %src_f8 + : !pto.vmi.vreg<512xf8E4M3FN> -> !pto.vmi.vreg<512xf32> + %sum_vec = pto.vmi.group_broadcast %sum {num_groups = 2} + : !pto.vmi.vreg<2xf32> -> !pto.vmi.vreg<512xf32> + %out = pto.vmi.mulf %sum_vec, %src_f8_f32 + : !pto.vmi.vreg<512xf32>, !pto.vmi.vreg<512xf32> -> !pto.vmi.vreg<512xf32> + pto.vmi.group_store %out, %ub_dst[%c0], %c320 {num_groups = 2} + : !pto.vmi.vreg<512xf32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c1024_i64 + nburst(%c2_i64, %c1280_i64, %c1280_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/launch.cpp b/test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/launch.cpp new file mode 100644 index 0000000000..03bf4d7e8f --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/launch.cpp @@ -0,0 +1,43 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_group_reduce_f16_f8_mul_store_kernel(__gm__ half *src_f16, + __gm__ uint8_t *src_f8, + __gm__ float *dst); + +void LaunchVmi_group_reduce_f16_f8_mul_store_kernel(uint16_t *src_f16, + uint8_t *src_f8, + float *dst, void *stream) { + vmi_group_reduce_f16_f8_mul_store_kernel<<<1, nullptr, stream>>>( + (__gm__ half *)src_f16, (__gm__ uint8_t *)src_f8, (__gm__ float *)dst); +} diff --git a/test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/main.cpp b/test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/main.cpp new file mode 100644 index 0000000000..e5769e3978 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/main.cpp @@ -0,0 +1,91 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. 
You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_reduce_f16_f8_mul_store_kernel(uint16_t *src_f16, + uint8_t *src_f8, + float *dst, void *stream); + +int main() { + constexpr size_t kRows = 2; + constexpr size_t kRowStride = 320; + constexpr size_t kElems = kRows * kRowStride; + size_t srcF16Bytes = kElems * sizeof(uint16_t); + size_t srcF8Bytes = kElems * sizeof(uint8_t); + size_t dstBytes = kElems * sizeof(float); + uint16_t *srcF16Host = nullptr; + uint16_t *srcF16Device = nullptr; + uint8_t *srcF8Host = nullptr; + uint8_t *srcF8Device = nullptr; + float *dstHost = nullptr; + float *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcF16Host), srcF16Bytes)); + ACL_CHECK(aclrtMallocHost((void **)(&srcF8Host), srcF8Bytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcF16Device, srcF16Bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + 
ACL_CHECK(aclrtMalloc((void **)&srcF8Device, srcF8Bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcF16Bytes, srcF16Host, srcF16Bytes); + ReadFile("./v2.bin", srcF8Bytes, srcF8Host, srcF8Bytes); + ReadFile("./v3.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcF16Device, srcF16Bytes, srcF16Host, srcF16Bytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(srcF8Device, srcF8Bytes, srcF8Host, srcF8Bytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_reduce_f16_f8_mul_store_kernel(srcF16Device, srcF8Device, + dstDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcF16Device); + aclrtFree(srcF8Device); + aclrtFree(dstDevice); + aclrtFreeHost(srcF16Host); + aclrtFreeHost(srcF8Host); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/ptoas.flags b/test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/compare.py b/test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/compare.py new file mode 100644 index 0000000000..612b15c3f6 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/compare.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+# This program is free software, you can redistribute it and/or modify it under +# the terms and conditions of CANN Open Software License Agreement Version 2.0 +# (the "License"). Please refer to the License for details. You may not use +# this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +# AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +# FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +# for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v2.bin", dtype=np.int32) + output = np.fromfile("v2.bin", dtype=np.int32) + if golden.shape == output.shape and np.array_equal(golden, output): + print("[INFO] compare passed") + return + if golden.shape != output.shape: + print(f"[ERROR] shape mismatch golden={golden.shape} output={output.shape}") + sys.exit(2) + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed idx={idx} " + f"golden={golden[idx] if idx >= 0 else 'n/a'} " + f"output={output[idx] if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/golden.py b/test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/golden.py new file mode 100644 index 0000000000..00097384f0 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/golden.py @@ -0,0 +1,43 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under +# the terms and conditions of CANN Open Software License Agreement Version 2.0 +# (the "License"). Please refer to the License for details. You may not use +# this file except in compliance with the License. 
THIS SOFTWARE IS PROVIDED ON +# AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +# FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +# for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 8 +COLS = 16 + + +def generate(output_dir: Path) -> None: + src = np.empty((ROWS, COLS), dtype=np.int16) + base = np.array([-5, -3, -1, 0, 2, 4, 6, 8], dtype=np.int16) + for row in range(ROWS): + src[row, :] = np.tile(np.roll(base, row), 2) + dst = np.full(ROWS, -777, dtype=np.int32) + golden = np.sum(src.astype(np.int32), axis=1, dtype=np.int32).astype(np.int32) + + output_dir.mkdir(parents=True, exist_ok=True) + src.reshape(-1).tofile(output_dir / "v1.bin") + dst.tofile(output_dir / "v2.bin") + golden.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/kernel.pto new file mode 100644 index 0000000000..e4a7d10eeb --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/kernel.pto @@ -0,0 +1,52 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. 
THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_group_reduce_i16_extsi_i32_addi_store_kernel(%src_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c512_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.vmi.create_mask %c128 : index -> !pto.vmi.mask<128xpred> + %x16 = pto.vmi.load %ub_src[%c0] : !pto.ptr -> !pto.vmi.vreg<128xi16> + %x32 = pto.vmi.extsi %x16 : !pto.vmi.vreg<128xi16> -> !pto.vmi.vreg<128xi32> + %sum = pto.vmi.group_reduce_addi %x32, %mask {num_groups = 8} + : !pto.vmi.vreg<128xi32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<8xi32> + pto.vmi.group_store %sum, %ub_dst[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xi32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/launch.cpp 
b/test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/launch.cpp new file mode 100644 index 0000000000..255de845bd --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/launch.cpp @@ -0,0 +1,36 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_group_reduce_i16_extsi_i32_addi_store_kernel(__gm__ int16_t *src, + __gm__ int32_t *dst); + +void LaunchVmi_group_reduce_i16_extsi_i32_addi_store_kernel(int16_t *src, + int32_t *dst, + void *stream) { + vmi_group_reduce_i16_extsi_i32_addi_store_kernel<<<1, nullptr, stream>>>( + (__gm__ int16_t *)src, (__gm__ int32_t *)dst); +} diff --git a/test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/main.cpp b/test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/main.cpp new file mode 100644 index 0000000000..277a78662f --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/main.cpp @@ -0,0 +1,88 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_reduce_i16_extsi_i32_addi_store_kernel(int16_t *src, + int32_t *dst, + void *stream); + +int main() { + constexpr size_t kInputElems = 128; + constexpr size_t kOutputElems = 8; + size_t srcBytes = kInputElems * sizeof(int16_t); + size_t dstBytes = kOutputElems * sizeof(int32_t); + int16_t *srcHost = nullptr; + int32_t *dstHost = nullptr; + int16_t *srcDevice = nullptr; + int32_t *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, + ACL_MEM_MALLOC_HUGE_FIRST)); + 
ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, + ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_reduce_i16_extsi_i32_addi_store_kernel(srcDevice, dstDevice, + stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/ptoas.flags b/test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/group-reduce-i32-addi-store/compare.py b/test/vpto/cases/vmi/group-reduce-i32-addi-store/compare.py new file mode 100644 index 0000000000..612b15c3f6 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-i32-addi-store/compare.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under +# the terms and conditions of CANN Open Software License Agreement Version 2.0 +# (the "License"). Please refer to the License for details. You may not use +# this file except in compliance with the License. 
THIS SOFTWARE IS PROVIDED ON +# AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +# FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +# for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v2.bin", dtype=np.int32) + output = np.fromfile("v2.bin", dtype=np.int32) + if golden.shape == output.shape and np.array_equal(golden, output): + print("[INFO] compare passed") + return + if golden.shape != output.shape: + print(f"[ERROR] shape mismatch golden={golden.shape} output={output.shape}") + sys.exit(2) + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed idx={idx} " + f"golden={golden[idx] if idx >= 0 else 'n/a'} " + f"output={output[idx] if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-i32-addi-store/golden.py b/test/vpto/cases/vmi/group-reduce-i32-addi-store/golden.py new file mode 100644 index 0000000000..4153e74342 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-i32-addi-store/golden.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under +# the terms and conditions of CANN Open Software License Agreement Version 2.0 +# (the "License"). Please refer to the License for details. You may not use +# this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +# AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +# FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +# for the full text of the License. 
+ +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 8 +COLS = 8 + + +def generate(output_dir: Path) -> None: + src = np.empty((ROWS, COLS), dtype=np.int32) + for row in range(ROWS): + src[row, :] = np.arange(COLS, dtype=np.int32) + row * 3 - 5 + dst = np.full(ROWS, -777, dtype=np.int32) + golden = np.sum(src, axis=1, dtype=np.int32).astype(np.int32) + + output_dir.mkdir(parents=True, exist_ok=True) + src.reshape(-1).tofile(output_dir / "v1.bin") + dst.tofile(output_dir / "v2.bin") + golden.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-i32-addi-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-i32-addi-store/kernel.pto new file mode 100644 index 0000000000..d311a4a932 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-i32-addi-store/kernel.pto @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. 
+ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_group_reduce_i32_addi_store_kernel(%src_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c512_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.vmi.create_mask %c64 : index -> !pto.vmi.mask<64xpred> + %x = pto.vmi.load %ub_src[%c0] : !pto.ptr -> !pto.vmi.vreg<64xi32> + %sum = pto.vmi.group_reduce_addi %x, %mask {num_groups = 8} + : !pto.vmi.vreg<64xi32>, !pto.vmi.mask<64xpred> + -> !pto.vmi.vreg<8xi32> + pto.vmi.group_store %sum, %ub_dst[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xi32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/group-reduce-i32-addi-store/launch.cpp b/test/vpto/cases/vmi/group-reduce-i32-addi-store/launch.cpp new file mode 100644 index 0000000000..5783bfd5a8 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-i32-addi-store/launch.cpp @@ -0,0 +1,35 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_group_reduce_i32_addi_store_kernel(__gm__ int32_t *src, + __gm__ int32_t *dst); + +void LaunchVmi_group_reduce_i32_addi_store_kernel(int32_t *src, int32_t *dst, + void *stream) { + vmi_group_reduce_i32_addi_store_kernel<<<1, nullptr, stream>>>( + (__gm__ int32_t *)src, (__gm__ int32_t *)dst); +} diff --git a/test/vpto/cases/vmi/group-reduce-i32-addi-store/main.cpp b/test/vpto/cases/vmi/group-reduce-i32-addi-store/main.cpp new file mode 100644 index 0000000000..385f3ae909 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-i32-addi-store/main.cpp @@ -0,0 +1,86 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. 
THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_reduce_i32_addi_store_kernel(int32_t *src, int32_t *dst, + void *stream); + +int main() { + constexpr size_t kInputElems = 64; + constexpr size_t kOutputElems = 8; + size_t srcBytes = kInputElems * sizeof(int32_t); + size_t dstBytes = kOutputElems * sizeof(int32_t); + int32_t *srcHost = nullptr; + int32_t *dstHost = nullptr; + int32_t *srcDevice = nullptr; + int32_t *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, + ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, + ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, 
dstBytes, dstHost, dstBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_reduce_i32_addi_store_kernel(srcDevice, dstDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/group-reduce-i32-addi-store/ptoas.flags b/test/vpto/cases/vmi/group-reduce-i32-addi-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-i32-addi-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/group-reduce-i32-maxi-store/compare.py b/test/vpto/cases/vmi/group-reduce-i32-maxi-store/compare.py new file mode 100644 index 0000000000..612b15c3f6 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-i32-maxi-store/compare.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under +# the terms and conditions of CANN Open Software License Agreement Version 2.0 +# (the "License"). Please refer to the License for details. You may not use +# this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +# AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +# FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +# for the full text of the License. 
+ +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v2.bin", dtype=np.int32) + output = np.fromfile("v2.bin", dtype=np.int32) + if golden.shape == output.shape and np.array_equal(golden, output): + print("[INFO] compare passed") + return + if golden.shape != output.shape: + print(f"[ERROR] shape mismatch golden={golden.shape} output={output.shape}") + sys.exit(2) + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed idx={idx} " + f"golden={golden[idx] if idx >= 0 else 'n/a'} " + f"output={output[idx] if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-i32-maxi-store/golden.py b/test/vpto/cases/vmi/group-reduce-i32-maxi-store/golden.py new file mode 100644 index 0000000000..1aa24c830e --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-i32-maxi-store/golden.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under +# the terms and conditions of CANN Open Software License Agreement Version 2.0 +# (the "License"). Please refer to the License for details. You may not use +# this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +# AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +# FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +# for the full text of the License. 
+ +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 8 +COLS = 8 + + +def generate(output_dir: Path) -> None: + src = np.empty((ROWS, COLS), dtype=np.int32) + for row in range(ROWS): + base = np.arange(COLS, dtype=np.int32) * ((row % 3) + 1) + src[row, :] = base - row * 5 - 9 + src[row, COLS // 2] = row * 11 - 17 + src[row, COLS - 1] = 23 - row * 7 + dst = np.full(ROWS, -777, dtype=np.int32) + golden = np.max(src, axis=1).astype(np.int32) + + output_dir.mkdir(parents=True, exist_ok=True) + src.reshape(-1).tofile(output_dir / "v1.bin") + dst.tofile(output_dir / "v2.bin") + golden.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-i32-maxi-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-i32-maxi-store/kernel.pto new file mode 100644 index 0000000000..1c594cf5c0 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-i32-maxi-store/kernel.pto @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. 
+ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_group_reduce_i32_maxi_store_kernel(%src_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c512_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.vmi.create_mask %c64 : index -> !pto.vmi.mask<64xpred> + %x = pto.vmi.load %ub_src[%c0] : !pto.ptr -> !pto.vmi.vreg<64xi32> + %sum = pto.vmi.group_reduce_maxi %x, %mask {num_groups = 8} + : !pto.vmi.vreg<64xi32>, !pto.vmi.mask<64xpred> + -> !pto.vmi.vreg<8xi32> + pto.vmi.group_store %sum, %ub_dst[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xi32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/group-reduce-i32-maxi-store/launch.cpp b/test/vpto/cases/vmi/group-reduce-i32-maxi-store/launch.cpp new file mode 100644 index 0000000000..7a7c0bacb9 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-i32-maxi-store/launch.cpp @@ -0,0 +1,35 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_group_reduce_i32_maxi_store_kernel(__gm__ int32_t *src, + __gm__ int32_t *dst); + +void LaunchVmi_group_reduce_i32_maxi_store_kernel(int32_t *src, int32_t *dst, + void *stream) { + vmi_group_reduce_i32_maxi_store_kernel<<<1, nullptr, stream>>>( + (__gm__ int32_t *)src, (__gm__ int32_t *)dst); +} diff --git a/test/vpto/cases/vmi/group-reduce-i32-maxi-store/main.cpp b/test/vpto/cases/vmi/group-reduce-i32-maxi-store/main.cpp new file mode 100644 index 0000000000..0aa2835503 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-i32-maxi-store/main.cpp @@ -0,0 +1,86 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. 
THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_reduce_i32_maxi_store_kernel(int32_t *src, int32_t *dst, + void *stream); + +int main() { + constexpr size_t kInputElems = 64; + constexpr size_t kOutputElems = 8; + size_t srcBytes = kInputElems * sizeof(int32_t); + size_t dstBytes = kOutputElems * sizeof(int32_t); + int32_t *srcHost = nullptr; + int32_t *dstHost = nullptr; + int32_t *srcDevice = nullptr; + int32_t *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, + ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, + ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, 
dstBytes, dstHost, dstBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_reduce_i32_maxi_store_kernel(srcDevice, dstDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/group-reduce-i32-maxi-store/ptoas.flags b/test/vpto/cases/vmi/group-reduce-i32-maxi-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-i32-maxi-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/compare.py b/test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/compare.py new file mode 100644 index 0000000000..612b15c3f6 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/compare.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under +# the terms and conditions of CANN Open Software License Agreement Version 2.0 +# (the "License"). Please refer to the License for details. You may not use +# this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +# AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +# FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +# for the full text of the License. 
+ +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v2.bin", dtype=np.int32) + output = np.fromfile("v2.bin", dtype=np.int32) + if golden.shape == output.shape and np.array_equal(golden, output): + print("[INFO] compare passed") + return + if golden.shape != output.shape: + print(f"[ERROR] shape mismatch golden={golden.shape} output={output.shape}") + sys.exit(2) + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed idx={idx} " + f"golden={golden[idx] if idx >= 0 else 'n/a'} " + f"output={output[idx] if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/golden.py b/test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/golden.py new file mode 100644 index 0000000000..76d46fff4c --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/golden.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under +# the terms and conditions of CANN Open Software License Agreement Version 2.0 +# (the "License"). Please refer to the License for details. You may not use +# this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +# AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +# FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +# for the full text of the License. 
+ +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 8 +COLS = 32 + + +def generate(output_dir: Path) -> None: + src = np.empty((ROWS, COLS), dtype=np.int8) + for row in range(ROWS): + src[row, :] = ((np.arange(COLS, dtype=np.int16) * 3 + row * 5) % 41 - 20).astype( + np.int8 + ) + dst = np.full(ROWS, -777, dtype=np.int32) + golden = np.sum(src.astype(np.int32), axis=1, dtype=np.int32).astype(np.int32) + + output_dir.mkdir(parents=True, exist_ok=True) + src.reshape(-1).tofile(output_dir / "v1.bin") + dst.tofile(output_dir / "v2.bin") + golden.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/kernel.pto new file mode 100644 index 0000000000..04a1afda13 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/kernel.pto @@ -0,0 +1,52 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. 
+ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_group_reduce_i8_extsi_i32_addi_store_kernel(%src_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c512_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %x8 = pto.vmi.load %ub_src[%c0] : !pto.ptr -> !pto.vmi.vreg<256xi8> + %x32 = pto.vmi.extsi %x8 : !pto.vmi.vreg<256xi8> -> !pto.vmi.vreg<256xi32> + %sum = pto.vmi.group_reduce_addi %x32, %mask {num_groups = 8} + : !pto.vmi.vreg<256xi32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<8xi32> + pto.vmi.group_store %sum, %ub_dst[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xi32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/launch.cpp b/test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/launch.cpp new file mode 100644 index 0000000000..1e046a8eb5 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/launch.cpp @@ -0,0 +1,36 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_group_reduce_i8_extsi_i32_addi_store_kernel(__gm__ int8_t *src, + __gm__ int32_t *dst); + +void LaunchVmi_group_reduce_i8_extsi_i32_addi_store_kernel(int8_t *src, + int32_t *dst, + void *stream) { + vmi_group_reduce_i8_extsi_i32_addi_store_kernel<<<1, nullptr, stream>>>( + (__gm__ int8_t *)src, (__gm__ int32_t *)dst); +} diff --git a/test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/main.cpp b/test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/main.cpp new file mode 100644 index 0000000000..cef9801b4d --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/main.cpp @@ -0,0 +1,88 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. 
THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_reduce_i8_extsi_i32_addi_store_kernel(int8_t *src, + int32_t *dst, + void *stream); + +int main() { + constexpr size_t kInputElems = 256; + constexpr size_t kOutputElems = 8; + size_t srcBytes = kInputElems * sizeof(int8_t); + size_t dstBytes = kOutputElems * sizeof(int32_t); + int8_t *srcHost = nullptr; + int32_t *dstHost = nullptr; + int8_t *srcDevice = nullptr; + int32_t *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, + ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, + ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, 
dstBytes, dstHost, dstBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_reduce_i8_extsi_i32_addi_store_kernel(srcDevice, dstDevice, + stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/ptoas.flags b/test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/group-reduce-s16-broadcast-reduce-store/compare.py b/test/vpto/cases/vmi/group-reduce-s16-broadcast-reduce-store/compare.py new file mode 100644 index 0000000000..28299087e5 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s16-broadcast-reduce-store/compare.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+ +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v2.bin", dtype=np.float32) + output = np.fromfile("v2.bin", dtype=np.float32) + if golden.shape != output.shape or not np.allclose(golden, output, atol=1e-4, rtol=1e-4): + diff = np.nonzero(~np.isclose(golden, output, atol=1e-4, rtol=1e-4))[0] if golden.shape == output.shape else np.empty(0, dtype=np.intp) # shape mismatch: isclose would raise, so report idx=-1 instead + idx = int(diff[0]) if diff.size else -1 + print(f"[ERROR] compare failed idx={idx} golden={golden[idx] if idx >= 0 else 'n/a'} output={output[idx] if idx >= 0 else 'n/a'}") + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-s16-broadcast-reduce-store/golden.py b/test/vpto/cases/vmi/group-reduce-s16-broadcast-reduce-store/golden.py new file mode 100644 index 0000000000..05510a7bd9 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s16-broadcast-reduce-store/golden.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License.
+ +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 8 +GROUP_SIZE = 16 +INPUT_ELEMS = ROWS * GROUP_SIZE +SENTINEL = np.float32(-777.0) + + +def generate(output_dir: Path) -> None: + src = np.empty(INPUT_ELEMS, dtype=np.float32) + dst = np.full(ROWS, SENTINEL, dtype=np.float32) + golden = np.empty(ROWS, dtype=np.float32) + + base_row = np.linspace(-0.625, 0.875, GROUP_SIZE, dtype=np.float32) + for row in range(ROWS): + begin = row * GROUP_SIZE + values = base_row + np.float32(row) * np.float32(0.125) + src[begin : begin + GROUP_SIZE] = values + reduction = np.sum(values, dtype=np.float32) + golden[row] = np.sum(values * reduction, dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + dst.tofile(output_dir / "v2.bin") + golden.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-s16-broadcast-reduce-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-s16-broadcast-reduce-store/kernel.pto new file mode 100644 index 0000000000..3f0243d8e1 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s16-broadcast-reduce-store/kernel.pto @@ -0,0 +1,57 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_group_reduce_s16_broadcast_reduce_store_kernel(%src_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c512_i64 = arith.constant 512 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.vmi.create_mask %c128 : index -> !pto.vmi.mask<128xpred> + %x = pto.vmi.load %ub_src[%c0] : !pto.ptr -> !pto.vmi.vreg<128xf32> + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<8xf32> + %b = pto.vmi.group_broadcast %sum {num_groups = 8} + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<128xf32> + %y = pto.vmi.mulf %x, %b + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + %ysum = pto.vmi.group_reduce_addf %y, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<8xf32> + pto.vmi.group_store %ysum, %ub_dst[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xf32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git 
a/test/vpto/cases/vmi/group-reduce-s16-broadcast-reduce-store/launch.cpp b/test/vpto/cases/vmi/group-reduce-s16-broadcast-reduce-store/launch.cpp new file mode 100644 index 0000000000..f180d41359 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s16-broadcast-reduce-store/launch.cpp @@ -0,0 +1,34 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_group_reduce_s16_broadcast_reduce_store_kernel(__gm__ float *src, + __gm__ float *dst); + +void LaunchVmi_group_reduce_s16_broadcast_reduce_store_kernel(float *src, + float *dst, + void *stream) { + vmi_group_reduce_s16_broadcast_reduce_store_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ float *)dst); +} diff --git a/test/vpto/cases/vmi/group-reduce-s16-broadcast-reduce-store/main.cpp b/test/vpto/cases/vmi/group-reduce-s16-broadcast-reduce-store/main.cpp new file mode 100644 index 0000000000..f3b88b52fa --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s16-broadcast-reduce-store/main.cpp @@ -0,0 +1,82 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_reduce_s16_broadcast_reduce_store_kernel(float *src, + float *dst, + void *stream); + +int main() { + constexpr size_t kRows = 8; + constexpr size_t kGroupSize = 16; + constexpr size_t kInputElems = kRows * kGroupSize; + constexpr size_t kOutputElems = kRows; + size_t srcBytes = kInputElems * sizeof(float); + size_t dstBytes = kOutputElems * sizeof(float); + float *srcHost = nullptr; + float *srcDevice = nullptr; + float *dstHost = nullptr; + float *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, 
ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_reduce_s16_broadcast_reduce_store_kernel(srcDevice, dstDevice, + stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/group-reduce-s16-broadcast-reduce-store/ptoas.flags b/test/vpto/cases/vmi/group-reduce-s16-broadcast-reduce-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s16-broadcast-reduce-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/group-reduce-s16-group-mask-broadcast-reduce-store/compare.py b/test/vpto/cases/vmi/group-reduce-s16-group-mask-broadcast-reduce-store/compare.py new file mode 100644 index 0000000000..17b5e600cc --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s16-group-mask-broadcast-reduce-store/compare.py @@ -0,0 +1,30 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. 
You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v2.bin", dtype=np.float32) + output = np.fromfile("v2.bin", dtype=np.float32) + close = np.isclose(golden, output, atol=1e-4, rtol=1e-4) if golden.shape == output.shape else np.zeros(0, dtype=bool) # shape mismatch: isclose would raise before the check below + if golden.shape != output.shape or not np.all(close): + diff = np.nonzero(~close)[0] + idx = int(diff[0]) if diff.size else -1 + g = golden[idx] if idx >= 0 and idx < golden.size else "n/a" + o = output[idx] if idx >= 0 and idx < output.size else "n/a" + print(f"[ERROR] compare failed idx={idx} golden={g} output={o}") + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-s16-group-mask-broadcast-reduce-store/golden.py b/test/vpto/cases/vmi/group-reduce-s16-group-mask-broadcast-reduce-store/golden.py new file mode 100644 index 0000000000..f8e59f415f --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s16-group-mask-broadcast-reduce-store/golden.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 8 +GROUP_SIZE = 16 +ACTIVE = 12 +SENTINEL = np.float32(-777.0) + + +def generate(output_dir: Path) -> None: + src = np.empty((ROWS, GROUP_SIZE), dtype=np.float32) + active_base = np.linspace(-0.5, 0.75, ACTIVE, dtype=np.float32) + inactive_base = np.linspace(21.0, 24.0, GROUP_SIZE - ACTIVE, dtype=np.float32) + golden = np.empty(ROWS, dtype=np.float32) + for row in range(ROWS): + src[row, :ACTIVE] = active_base + np.float32(row) * np.float32(0.046875) + src[row, ACTIVE:] = inactive_base + np.float32(row) * np.float32(1.5) + reduction = np.sum(src[row, :ACTIVE], dtype=np.float32) + golden[row] = np.sum(src[row, :ACTIVE] * reduction, dtype=np.float32) + + dst = np.full(ROWS, SENTINEL, dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + src.reshape(-1).tofile(output_dir / "v1.bin") + dst.tofile(output_dir / "v2.bin") + golden.astype(np.float32).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-s16-group-mask-broadcast-reduce-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-s16-group-mask-broadcast-reduce-store/kernel.pto new file mode 100644 index 0000000000..eb6ebedee1 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s16-group-mask-broadcast-reduce-store/kernel.pto @@ -0,0 +1,63 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_group_reduce_s16_group_mask_broadcast_reduce_store_kernel( + %src_gm: !pto.ptr, %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c12 = arith.constant 12 : index + %c16 = arith.constant 16 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c512_i64 = arith.constant 512 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %dst_gm, %ub_dst, %c0_i64, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %x = pto.vmi.group_load %ub_src[%c0], %c16 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<128xf32> + %mask = pto.vmi.create_group_mask %c12 {num_groups = 8, group_size = 16} + : index -> !pto.vmi.mask<128xpred> + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<8xf32> + %b = pto.vmi.group_broadcast %sum {num_groups = 8} + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<128xf32> + %y = pto.vmi.mulf %x, %b + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + %ysum = pto.vmi.group_reduce_addf %y, %mask {num_groups = 8, 
reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<8xf32> + pto.vmi.group_store %ysum, %ub_dst[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xf32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/group-reduce-s16-group-mask-broadcast-reduce-store/launch.cpp b/test/vpto/cases/vmi/group-reduce-s16-group-mask-broadcast-reduce-store/launch.cpp new file mode 100644 index 0000000000..bd5cc88024 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s16-group-mask-broadcast-reduce-store/launch.cpp @@ -0,0 +1,34 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_group_reduce_s16_group_mask_broadcast_reduce_store_kernel( + __gm__ float *src, __gm__ float *dst); + +void LaunchVmi_group_reduce_s16_group_mask_broadcast_reduce_store_kernel( + float *src, float *dst, void *stream) { + vmi_group_reduce_s16_group_mask_broadcast_reduce_store_kernel<<<1, nullptr, + stream>>>( + (__gm__ float *)src, (__gm__ float *)dst); +} diff --git a/test/vpto/cases/vmi/group-reduce-s16-group-mask-broadcast-reduce-store/main.cpp b/test/vpto/cases/vmi/group-reduce-s16-group-mask-broadcast-reduce-store/main.cpp new file mode 100644 index 0000000000..b87811e20c --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s16-group-mask-broadcast-reduce-store/main.cpp @@ -0,0 +1,81 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#include "acl/acl.h" +#include "test_common.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_reduce_s16_group_mask_broadcast_reduce_store_kernel( + float *src, float *dst, void *stream); + +int main() { + constexpr size_t kRows = 8; + constexpr size_t kGroupSize = 16; + constexpr size_t kInputElems = kRows * kGroupSize; + constexpr size_t kOutputElems = kRows; + size_t srcBytes = kInputElems * sizeof(float); + size_t dstBytes = kOutputElems * sizeof(float); + float *srcHost = nullptr; + float *dstHost = nullptr; + float *srcDevice = nullptr; + float *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_reduce_s16_group_mask_broadcast_reduce_store_kernel( + srcDevice, dstDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + 
ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/group-reduce-s16-group-mask-broadcast-reduce-store/ptoas.flags b/test/vpto/cases/vmi/group-reduce-s16-group-mask-broadcast-reduce-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s16-group-mask-broadcast-reduce-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/group-reduce-s16-group-mask-tail-store/compare.py b/test/vpto/cases/vmi/group-reduce-s16-group-mask-tail-store/compare.py new file mode 100644 index 0000000000..17b5e600cc --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s16-group-mask-tail-store/compare.py @@ -0,0 +1,30 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
# File: test/vpto/cases/vmi/group-reduce-s16-group-mask-tail-store/compare.py

import sys

import numpy as np


def main() -> None:
    """Compare ``v2.bin`` against ``golden_v2.bin`` as float32 arrays.

    Both files are read from the current working directory. The check passes
    when the arrays have the same shape and every element pair satisfies
    ``np.isclose`` with ``atol=1e-4`` and ``rtol=1e-4``. On failure one
    ``[ERROR]`` line is printed and the process exits with status 2.
    """
    golden = np.fromfile("golden_v2.bin", dtype=np.float32)
    output = np.fromfile("v2.bin", dtype=np.float32)

    # Check shapes before any element-wise call: np.isclose raises ValueError
    # on arrays that cannot be broadcast together, which would replace the
    # intended error report and exit code with a traceback.
    if golden.shape != output.shape:
        print(
            "[ERROR] compare failed shape mismatch "
            f"golden={golden.shape} output={output.shape}"
        )
        sys.exit(2)

    close = np.isclose(golden, output, atol=1e-4, rtol=1e-4)
    if not np.all(close):
        # Report the first element outside tolerance.
        idx = int(np.nonzero(~close)[0][0])
        print(f"[ERROR] compare failed idx={idx} golden={golden[idx]} output={output[idx]}")
        sys.exit(2)
    print("[INFO] compare passed")


if __name__ == "__main__":
    main()
# File: test/vpto/cases/vmi/group-reduce-s16-group-mask-tail-store/golden.py

import argparse
from pathlib import Path

import numpy as np

ROWS = 8
GROUP_SIZE = 16
ACTIVE = 12
SENTINEL = np.float32(-777.0)


def generate(output_dir: Path) -> None:
    """Write the input, initial output and golden files for this case.

    ``v1.bin`` holds ROWS x GROUP_SIZE float32 values, ``v2.bin`` holds ROWS
    sentinel values, and ``golden_v2.bin`` holds, per row, the float32 sum of
    the first ACTIVE elements of that row.
    """
    # Column vector of row indices so each row gets its own offset.
    row_ids = np.arange(ROWS, dtype=np.float32)[:, np.newaxis]
    head = np.linspace(-0.75, 0.375, ACTIVE, dtype=np.float32)
    tail = np.linspace(25.0, 28.0, GROUP_SIZE - ACTIVE, dtype=np.float32)

    src = np.empty((ROWS, GROUP_SIZE), dtype=np.float32)
    src[:, :ACTIVE] = head + row_ids * np.float32(0.0625)
    # Elements past ACTIVE are large on purpose: they must not reach the sum.
    src[:, ACTIVE:] = tail + row_ids * np.float32(2.0)

    golden = np.sum(src[:, :ACTIVE], axis=1, dtype=np.float32)
    initial_dst = np.full(ROWS, SENTINEL, dtype=np.float32)

    output_dir.mkdir(parents=True, exist_ok=True)
    for name, data in (
        ("v1.bin", src.reshape(-1)),
        ("v2.bin", initial_dst),
        ("golden_v2.bin", golden),
    ):
        data.tofile(output_dir / name)


def main() -> None:
    """Parse ``--output-dir`` (default: current directory) and generate."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--output-dir", type=Path, default=Path("."))
    generate(cli.parse_args().output_dir)


if __name__ == "__main__":
    main()
// See LICENSE in the root of the software repository for the full text of the License.

// File: test/vpto/cases/vmi/group-reduce-s16-group-mask-tail-store/kernel.pto
//
// Kernel for the masked group-reduce case. The companion golden.py expects,
// for each of 8 rows of 16 f32 values, the sum of the first 12 values.
// NOTE(review): several attribute/type parameters (e.g. on `!pto.ptr`,
// `#pto.kernel_kind`, `#pto.pipe`) appear to be missing in this extract;
// confirm against the repository copy.
module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} {
  func.func @vmi_group_reduce_s16_group_mask_tail_store_kernel(
      %src_gm: !pto.ptr, %dst_gm: !pto.ptr) attributes {pto.kernel} {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c12 = arith.constant 12 : index
    %c16 = arith.constant 16 : index
    %c0_i64 = arith.constant 0 : i64
    %c1_i64 = arith.constant 1 : i64
    %c32_i64 = arith.constant 32 : i64
    %c512_i64 = arith.constant 512 : i64
    %c4096_i64 = arith.constant 4096 : i64

    // Local buffers are addressed by constants 0 (source) and 4096 (destination).
    %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr
    %ub_dst = pto.castptr %c4096_i64 : i64 -> !pto.ptr

    // 512 and 32 match the host driver's source (128 x f32) and destination
    // (8 x f32) byte sizes; presumably these stage both buffers from GM.
    pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c512_i64
      nburst(%c1_i64, %c512_i64, %c512_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
    pto.mte_gm_ub %dst_gm, %ub_dst, %c0_i64, %c32_i64
      nburst(%c1_i64, %c32_i64, %c32_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64

    // Flag pair between PIPE_MTE2 and PIPE_V before the vector scope.
    pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]
    pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]

    pto.vecscope {
      // 8 groups loaded with operand 16 (presumably the per-group stride in elements).
      %x = pto.vmi.group_load %ub_src[%c0], %c16 {num_groups = 8}
        : !pto.ptr -> !pto.vmi.vreg<128xf32>
      // Mask built from 12 with group_size 16; golden.py sums the first 12 per row.
      %mask = pto.vmi.create_group_mask %c12 {num_groups = 8, group_size = 16}
        : index -> !pto.vmi.mask<128xpred>
      %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc}
        : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred>
        -> !pto.vmi.vreg<8xf32>
      // One result per group written with operand 1 (presumably the store stride).
      pto.vmi.group_store %sum, %ub_dst[%c0], %c1 {num_groups = 8}
        : !pto.vmi.vreg<8xf32>, !pto.ptr
    }

    // Flag pair between PIPE_V and PIPE_MTE3 before the copy back to GM.
    pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
    pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
    pto.mte_ub_gm %ub_dst, %dst_gm, %c32_i64
      nburst(%c1_i64, %c32_i64, %c32_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64
    pto.barrier #pto.pipe
    return
  }
}
// Copyright (c) 2026 Huawei Technologies Co., Ltd.
// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
// CANN Open Software License Agreement Version 2.0 (the "License").
// Please refer to the License for details. You may not use this file except in compliance with the License.
// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
// See LICENSE in the root of the software repository for the full text of the License.

// File: test/vpto/cases/vmi/group-reduce-s16-group-mask-tail-store/launch.cpp
// Host-callable launcher for the kernel declared below.

// Make sure __VEC_SCOPE__ is defined (as empty) before the include below.
#ifndef __VEC_SCOPE__
#define __VEC_SCOPE__
#endif
// NOTE(review): the header name of the next include is missing in this
// extract (angle-bracket text appears to have been stripped); restore it from
// the repository copy.
#include
// Provides MrgSortExecutedNumList when not compiling for the AI core and
// TMRGSORT_HPP has not been defined.
#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP)
struct MrgSortExecutedNumList {
  uint16_t mrgSortList0;
  uint16_t mrgSortList1;
  uint16_t mrgSortList2;
  uint16_t mrgSortList3;
};
#endif
// The ACL header is skipped when __CPU_SIM is defined.
#ifndef __CPU_SIM
#include "acl/acl.h"
#endif

// Device entry point; the name matches the func.func in this case's kernel.pto.
extern "C" __global__ [aicore] void
vmi_group_reduce_s16_group_mask_tail_store_kernel(__gm__ float *src,
                                                  __gm__ float *dst);

// Launches the kernel on `stream` with launch configuration <<<1, nullptr,
// stream>>> (presumably a single block), passing `src` and `dst` as __gm__
// float pointers.
void LaunchVmi_group_reduce_s16_group_mask_tail_store_kernel(float *src,
                                                             float *dst,
                                                             void *stream) {
  vmi_group_reduce_s16_group_mask_tail_store_kernel<<<1, nullptr, stream>>>(
      (__gm__ float *)src, (__gm__ float *)dst);
}
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_reduce_s16_group_mask_tail_store_kernel(float *src, + float *dst, + void *stream); + +int main() { + constexpr size_t kRows = 8; + constexpr size_t kGroupSize = 16; + constexpr size_t kInputElems = kRows * kGroupSize; + constexpr size_t kOutputElems = kRows; + size_t srcBytes = kInputElems * sizeof(float); + size_t dstBytes = kOutputElems * sizeof(float); + float *srcHost = nullptr; + float *dstHost = nullptr; + float *srcDevice = nullptr; + float *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, 
ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_reduce_s16_group_mask_tail_store_kernel(srcDevice, dstDevice, + stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/group-reduce-s16-group-mask-tail-store/ptoas.flags b/test/vpto/cases/vmi/group-reduce-s16-group-mask-tail-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s16-group-mask-tail-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/group-reduce-s16-stride-group-mask-tail-store/compare.py b/test/vpto/cases/vmi/group-reduce-s16-stride-group-mask-tail-store/compare.py new file mode 100644 index 0000000000..17b5e600cc --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s16-stride-group-mask-tail-store/compare.py @@ -0,0 +1,30 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. 
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.

# File: test/vpto/cases/vmi/group-reduce-s16-stride-group-mask-tail-store/compare.py

import sys

import numpy as np


def main() -> None:
    """Compare ``v2.bin`` against ``golden_v2.bin`` as float32 arrays.

    Both files are read from the current working directory. The check passes
    when the arrays have the same shape and every element pair satisfies
    ``np.isclose`` with ``atol=1e-4`` and ``rtol=1e-4``. On failure one
    ``[ERROR]`` line is printed and the process exits with status 2.
    """
    golden = np.fromfile("golden_v2.bin", dtype=np.float32)
    output = np.fromfile("v2.bin", dtype=np.float32)

    # Check shapes before any element-wise call: np.isclose raises ValueError
    # on arrays that cannot be broadcast together, which would replace the
    # intended error report and exit code with a traceback.
    if golden.shape != output.shape:
        print(
            "[ERROR] compare failed shape mismatch "
            f"golden={golden.shape} output={output.shape}"
        )
        sys.exit(2)

    close = np.isclose(golden, output, atol=1e-4, rtol=1e-4)
    if not np.all(close):
        # Report the first element outside tolerance.
        idx = int(np.nonzero(~close)[0][0])
        print(f"[ERROR] compare failed idx={idx} golden={golden[idx]} output={output[idx]}")
        sys.exit(2)
    print("[INFO] compare passed")


if __name__ == "__main__":
    main()
# File: test/vpto/cases/vmi/group-reduce-s16-stride-group-mask-tail-store/golden.py

import argparse
from pathlib import Path

import numpy as np

ROWS = 8
GROUP_SIZE = 16
ACTIVE = 12
ROW_STRIDE = 24
SENTINEL = np.float32(-777.0)


def generate(output_dir: Path) -> None:
    """Write the input, initial output and golden files for this case.

    ``v1.bin`` holds ROWS rows spaced ROW_STRIDE float32 elements apart; only
    the first GROUP_SIZE elements of each row are overwritten, the rest keep
    the fill value 99.0. ``v2.bin`` holds ROWS sentinel values and
    ``golden_v2.bin`` holds, per row, the float32 sum of its first ACTIVE
    elements.
    """
    src = np.full(ROWS * ROW_STRIDE, np.float32(99.0), dtype=np.float32)
    # 2-D view onto the same buffer: writes below land in `src`.
    grid = src.reshape(ROWS, ROW_STRIDE)
    row_ids = np.arange(ROWS, dtype=np.float32)[:, np.newaxis]
    head = np.linspace(-0.625, 0.5, ACTIVE, dtype=np.float32)
    tail = np.linspace(31.0, 35.0, GROUP_SIZE - ACTIVE, dtype=np.float32)

    grid[:, :ACTIVE] = head + row_ids * np.float32(0.03125)
    grid[:, ACTIVE:GROUP_SIZE] = tail + row_ids

    golden = np.array(
        [np.sum(row[:ACTIVE], dtype=np.float32) for row in grid],
        dtype=np.float32,
    )
    initial_dst = np.full(ROWS, SENTINEL, dtype=np.float32)

    output_dir.mkdir(parents=True, exist_ok=True)
    for name, data in (
        ("v1.bin", src),
        ("v2.bin", initial_dst),
        ("golden_v2.bin", golden),
    ):
        data.tofile(output_dir / name)


def main() -> None:
    """Parse ``--output-dir`` (default: current directory) and generate."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--output-dir", type=Path, default=Path("."))
    generate(cli.parse_args().output_dir)


if __name__ == "__main__":
    main()
// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
// See LICENSE in the root of the software repository for the full text of the License.

// File: test/vpto/cases/vmi/group-reduce-s16-stride-group-mask-tail-store/kernel.pto
//
// Strided variant of the masked group-reduce case. The companion golden.py
// places 8 rows 24 f32 elements apart and expects, per row, the sum of the
// first 12 values.
// NOTE(review): several attribute/type parameters (e.g. on `!pto.ptr`,
// `#pto.kernel_kind`, `#pto.pipe`) appear to be missing in this extract;
// confirm against the repository copy.
module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} {
  func.func @vmi_group_reduce_s16_stride_group_mask_tail_store_kernel(
      %src_gm: !pto.ptr, %dst_gm: !pto.ptr) attributes {pto.kernel} {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c12 = arith.constant 12 : index
    %c24 = arith.constant 24 : index
    %c0_i64 = arith.constant 0 : i64
    %c1_i64 = arith.constant 1 : i64
    %c32_i64 = arith.constant 32 : i64
    %c768_i64 = arith.constant 768 : i64
    %c4096_i64 = arith.constant 4096 : i64

    // Local buffers are addressed by constants 0 (source) and 4096 (destination).
    %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr
    %ub_dst = pto.castptr %c4096_i64 : i64 -> !pto.ptr

    // 768 and 32 match the host driver's source (8 x 24 x f32) and destination
    // (8 x f32) byte sizes; presumably these stage both buffers from GM.
    pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c768_i64
      nburst(%c1_i64, %c768_i64, %c768_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
    pto.mte_gm_ub %dst_gm, %ub_dst, %c0_i64, %c32_i64
      nburst(%c1_i64, %c32_i64, %c32_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64

    // Flag pair between PIPE_MTE2 and PIPE_V before the vector scope.
    pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]
    pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]

    pto.vecscope {
      // Operand 24 equals golden.py's ROW_STRIDE (presumably the group stride in elements).
      %x = pto.vmi.group_load %ub_src[%c0], %c24 {num_groups = 8}
        : !pto.ptr -> !pto.vmi.vreg<128xf32>
      // Mask built from 12 with group_size 16; golden.py sums the first 12 per row.
      %mask = pto.vmi.create_group_mask %c12 {num_groups = 8, group_size = 16}
        : index -> !pto.vmi.mask<128xpred>
      %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc}
        : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred>
        -> !pto.vmi.vreg<8xf32>
      // One result per group written with operand 1 (presumably the store stride).
      pto.vmi.group_store %sum, %ub_dst[%c0], %c1 {num_groups = 8}
        : !pto.vmi.vreg<8xf32>, !pto.ptr
    }

    // Flag pair between PIPE_V and PIPE_MTE3 before the copy back to GM.
    pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
    pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
    pto.mte_ub_gm %ub_dst, %dst_gm, %c32_i64
      nburst(%c1_i64, %c32_i64, %c32_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64
    pto.barrier #pto.pipe
    return
  }
}
// File: test/vpto/cases/vmi/group-reduce-s16-stride-group-mask-tail-store/launch.cpp
// Host-callable launcher for the kernel declared below.

// Make sure __VEC_SCOPE__ is defined (as empty) before the include below.
#ifndef __VEC_SCOPE__
#define __VEC_SCOPE__
#endif
// NOTE(review): the header name of the next include is missing in this
// extract (angle-bracket text appears to have been stripped); restore it from
// the repository copy.
#include
// Provides MrgSortExecutedNumList when not compiling for the AI core and
// TMRGSORT_HPP has not been defined.
#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP)
struct MrgSortExecutedNumList {
  uint16_t mrgSortList0;
  uint16_t mrgSortList1;
  uint16_t mrgSortList2;
  uint16_t mrgSortList3;
};
#endif
// The ACL header is skipped when __CPU_SIM is defined.
#ifndef __CPU_SIM
#include "acl/acl.h"
#endif

// Device entry point; the name matches the func.func in this case's kernel.pto.
extern "C" __global__ [aicore] void
vmi_group_reduce_s16_stride_group_mask_tail_store_kernel(__gm__ float *src,
                                                         __gm__ float *dst);

// Launches the kernel on `stream` with launch configuration <<<1, nullptr,
// stream>>> (presumably a single block), passing `src` and `dst` as __gm__
// float pointers.
void LaunchVmi_group_reduce_s16_stride_group_mask_tail_store_kernel(
    float *src, float *dst, void *stream) {
  vmi_group_reduce_s16_stride_group_mask_tail_store_kernel<<<1, nullptr,
                                                             stream>>>(
      (__gm__ float *)src, (__gm__ float *)dst);
}
+ +#include "acl/acl.h" +#include "test_common.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_reduce_s16_stride_group_mask_tail_store_kernel( + float *src, float *dst, void *stream); + +int main() { + constexpr size_t kRows = 8; + constexpr size_t kRowStride = 24; + constexpr size_t kInputElems = kRows * kRowStride; + constexpr size_t kOutputElems = kRows; + size_t srcBytes = kInputElems * sizeof(float); + size_t dstBytes = kOutputElems * sizeof(float); + float *srcHost = nullptr; + float *dstHost = nullptr; + float *srcDevice = nullptr; + float *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_reduce_s16_stride_group_mask_tail_store_kernel( + srcDevice, dstDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, 
dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/group-reduce-s16-stride-group-mask-tail-store/ptoas.flags b/test/vpto/cases/vmi/group-reduce-s16-stride-group-mask-tail-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s16-stride-group-mask-tail-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/group-reduce-s16-truncf-broadcast-store/compare.py b/test/vpto/cases/vmi/group-reduce-s16-truncf-broadcast-store/compare.py new file mode 100644 index 0000000000..39f37ccd7c --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s16-truncf-broadcast-store/compare.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
# File: test/vpto/cases/vmi/group-reduce-s16-truncf-broadcast-store/compare.py

import sys

import numpy as np


def main() -> None:
    """Compare ``v2.bin`` against ``golden_v2.bin`` exactly, as float16 arrays.

    Both files are read from the current working directory. The check passes
    when the arrays have the same shape and every element compares equal
    (the same criterion as ``np.array_equal``, so NaN never matches). On
    failure one ``[ERROR]`` line is printed and the process exits with
    status 2.
    """
    golden = np.fromfile("golden_v2.bin", dtype=np.float16)
    output = np.fromfile("v2.bin", dtype=np.float16)

    # Handle a size mismatch up front: element-wise comparison of arrays with
    # different shapes either raises or yields a meaningless index, depending
    # on the NumPy version.
    if golden.shape != output.shape:
        print(
            "[ERROR] compare failed shape mismatch "
            f"golden={golden.shape} output={output.shape}"
        )
        sys.exit(2)

    # Locate mismatches with the same value comparison that decides pass or
    # fail, so a failure always reports a real index.
    mismatch = golden != output
    if np.any(mismatch):
        idx = int(np.nonzero(mismatch)[0][0])
        print(f"[ERROR] compare failed idx={idx} golden={golden[idx]} output={output[idx]}")
        sys.exit(2)
    print("[INFO] compare passed")


if __name__ == "__main__":
    main()
# File: test/vpto/cases/vmi/group-reduce-s16-truncf-broadcast-store/golden.py

import argparse
from pathlib import Path

import numpy as np

ROWS = 8
GROUP_SIZE = 16
ELEMS = ROWS * GROUP_SIZE
SEED = 29
SENTINEL = np.float16(-17.5)


def generate(output_dir: Path, seed: int) -> None:
    """Write the input, initial output and golden files for this case.

    ``v1.bin`` holds ELEMS float32 values drawn uniformly from [-2, 2) with
    the given seed. ``v2.bin`` holds ELEMS float16 sentinel values.
    ``golden_v2.bin`` holds, for every element, the float32 sum of its row of
    GROUP_SIZE inputs converted to float16.
    """
    src = np.random.default_rng(seed).uniform(-2.0, 2.0, size=ELEMS).astype(np.float32)

    # One float32 sum per row, rounded to float16, then repeated across the row.
    row_sums = np.array(
        [np.sum(row, dtype=np.float32) for row in src.reshape(ROWS, GROUP_SIZE)],
        dtype=np.float32,
    ).astype(np.float16)
    golden = np.repeat(row_sums, GROUP_SIZE)
    initial_dst = np.full(ELEMS, SENTINEL, dtype=np.float16)

    output_dir.mkdir(parents=True, exist_ok=True)
    for name, data in (
        ("v1.bin", src),
        ("v2.bin", initial_dst),
        ("golden_v2.bin", golden),
    ):
        data.tofile(output_dir / name)


def main() -> None:
    """Parse ``--output-dir`` and ``--seed`` (defaults: ``.``, SEED) and generate."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--output-dir", type=Path, default=Path("."))
    cli.add_argument("--seed", type=int, default=SEED)
    opts = cli.parse_args()
    generate(opts.output_dir, opts.seed)


if __name__ == "__main__":
    main()
// See LICENSE in the root of the software repository for the full text of the License.

// File: test/vpto/cases/vmi/group-reduce-s16-truncf-broadcast-store/kernel.pto
//
// Group-reduce + truncf + broadcast case. The companion golden.py expects
// every element of each 16-element row to hold that row's f32 sum converted
// to f16.
// NOTE(review): several attribute/type parameters (e.g. on `!pto.ptr`,
// `#pto.kernel_kind`, `#pto.pipe`) appear to be missing in this extract;
// confirm against the repository copy.
module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} {
  func.func @vmi_group_reduce_s16_truncf_broadcast_store_kernel(%src_gm: !pto.ptr,
      %dst_gm: !pto.ptr) attributes {pto.kernel} {
    %c0 = arith.constant 0 : index
    %c128 = arith.constant 128 : index
    %c0_i64 = arith.constant 0 : i64
    %c1_i64 = arith.constant 1 : i64
    %c256_i64 = arith.constant 256 : i64
    %c512_i64 = arith.constant 512 : i64
    %c4096_i64 = arith.constant 4096 : i64

    // Local buffers are addressed by constants 0 (source) and 4096 (destination).
    %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr
    %ub_dst = pto.castptr %c4096_i64 : i64 -> !pto.ptr

    // 512 and 256 match the host driver's source (128 x f32) and destination
    // (128 x 16-bit) byte sizes; presumably these stage both buffers from GM.
    pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c512_i64
      nburst(%c1_i64, %c512_i64, %c512_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
    pto.mte_gm_ub %dst_gm, %ub_dst, %c0_i64, %c256_i64
      nburst(%c1_i64, %c256_i64, %c256_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64

    // Flag pair between PIPE_MTE2 and PIPE_V before the vector scope.
    pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]
    pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]

    pto.vecscope {
      // Mask created from 128, the full element count of the 128-lane vector.
      %mask = pto.vmi.create_mask %c128 : index -> !pto.vmi.mask<128xpred>
      %x = pto.vmi.load %ub_src[%c0] : !pto.ptr -> !pto.vmi.vreg<128xf32>
      // 128 f32 lanes reduced in 8 groups to 8 f32 sums.
      %sum32 = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc}
        : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred>
        -> !pto.vmi.vreg<8xf32>
      // Sums narrowed from f32 to f16.
      %sum16 = pto.vmi.truncf %sum32
        : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<8xf16>
      // 8 f16 values expanded back to 128 lanes, 8 groups.
      %rows = pto.vmi.group_broadcast %sum16 {num_groups = 8}
        : !pto.vmi.vreg<8xf16> -> !pto.vmi.vreg<128xf16>
      pto.vmi.store %rows, %ub_dst[%c0] : !pto.vmi.vreg<128xf16>, !pto.ptr
    }

    // Flag pair between PIPE_V and PIPE_MTE3 before the copy back to GM.
    pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
    pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
    pto.mte_ub_gm %ub_dst, %dst_gm, %c256_i64
      nburst(%c1_i64, %c256_i64, %c256_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64
    pto.barrier #pto.pipe
    return
  }
}
// Copyright (c) 2026 Huawei Technologies Co., Ltd.
// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
// CANN Open Software License Agreement Version 2.0 (the "License").
// Please refer to the License for details. You may not use this file except in compliance with the License.
// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
// See LICENSE in the root of the software repository for the full text of the License.

// File: test/vpto/cases/vmi/group-reduce-s16-truncf-broadcast-store/launch.cpp
// Host-callable launcher for the kernel declared below.

// Make sure __VEC_SCOPE__ is defined (as empty) before the include below.
#ifndef __VEC_SCOPE__
#define __VEC_SCOPE__
#endif
// One-byte placeholder structs for these type names, defined only when
// compiling for the AI core with __NPU_ARCH__ == 2201.
#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201)
typedef struct { unsigned char v; } hifloat8_t;
typedef struct { unsigned char v; } float8_e4m3_t;
typedef struct { unsigned char v; } float8_e5m2_t;
typedef struct { unsigned char v; } float8_e8m0_t;
typedef struct { unsigned char v; } float4_e1m2x2_t;
typedef struct { unsigned char v; } float4_e2m1x2_t;
#endif
// NOTE(review): the header name of the next include is missing in this
// extract (angle-bracket text appears to have been stripped); restore it from
// the repository copy.
#include
// Provides MrgSortExecutedNumList when not compiling for the AI core and
// TMRGSORT_HPP has not been defined.
#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP)
struct MrgSortExecutedNumList {
  uint16_t mrgSortList0;
  uint16_t mrgSortList1;
  uint16_t mrgSortList2;
  uint16_t mrgSortList3;
};
#endif
// The ACL header is skipped when __CPU_SIM is defined.
#ifndef __CPU_SIM
#include "acl/acl.h"
#endif

// Device entry point; the name matches the func.func in this case's kernel.pto.
extern "C" __global__ [aicore] void
vmi_group_reduce_s16_truncf_broadcast_store_kernel(__gm__ float *src,
                                                   __gm__ half *dst);

// Launches the kernel on `stream` with launch configuration <<<1, nullptr,
// stream>>> (presumably a single block). The host passes the destination as
// uint16_t storage; it is reinterpreted as __gm__ half for the kernel.
void LaunchVmi_group_reduce_s16_truncf_broadcast_store_kernel(float *src,
                                                              uint16_t *dst,
                                                              void *stream) {
  vmi_group_reduce_s16_truncf_broadcast_store_kernel<<<1, nullptr, stream>>>(
      (__gm__ float *)src, (__gm__ half *)dst);
}
a/test/vpto/cases/vmi/group-reduce-s16-truncf-broadcast-store/main.cpp b/test/vpto/cases/vmi/group-reduce-s16-truncf-broadcast-store/main.cpp new file mode 100644 index 0000000000..13fe482440 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s16-truncf-broadcast-store/main.cpp @@ -0,0 +1,80 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_reduce_s16_truncf_broadcast_store_kernel(float *src, + uint16_t *dst, + void *stream); + +int main() { + constexpr size_t kElems = 128; + size_t srcBytes = kElems * sizeof(float); + size_t dstBytes = kElems * sizeof(uint16_t); + float *srcHost = nullptr; + float *srcDevice = nullptr; + uint16_t *dstHost = nullptr; + uint16_t *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + 
deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_reduce_s16_truncf_broadcast_store_kernel(srcDevice, + dstDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/group-reduce-s16-truncf-broadcast-store/ptoas.flags b/test/vpto/cases/vmi/group-reduce-s16-truncf-broadcast-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s16-truncf-broadcast-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/group-reduce-s32-add-bias-store/compare.py b/test/vpto/cases/vmi/group-reduce-s32-add-bias-store/compare.py new file mode 100644 index 0000000000..28299087e5 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s32-add-bias-store/compare.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v2.bin", dtype=np.float32) + output = np.fromfile("v2.bin", dtype=np.float32) + if golden.shape != output.shape or not np.allclose(golden, output, atol=1e-4, rtol=1e-4): + diff = np.nonzero(~np.isclose(golden, output, atol=1e-4, rtol=1e-4))[0] + idx = int(diff[0]) if diff.size else -1 + print(f"[ERROR] compare failed idx={idx} golden={golden[idx] if idx >= 0 else 'n/a'} output={output[idx] if idx >= 0 else 'n/a'}") + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-s32-add-bias-store/golden.py b/test/vpto/cases/vmi/group-reduce-s32-add-bias-store/golden.py new file mode 100644 index 0000000000..1614628a0b --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s32-add-bias-store/golden.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. 
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 8 +GROUP_SIZE = 32 +INPUT_ELEMS = ROWS * GROUP_SIZE +BIAS = np.float32(0.25) +SENTINEL = np.float32(-777.0) + + +def generate(output_dir: Path) -> None: + src = np.empty(INPUT_ELEMS, dtype=np.float32) + dst = np.full(ROWS, SENTINEL, dtype=np.float32) + golden = np.empty(ROWS, dtype=np.float32) + + base_row = np.linspace(-0.5, 0.75, GROUP_SIZE, dtype=np.float32) + for row in range(ROWS): + begin = row * GROUP_SIZE + values = base_row + np.float32(row) * np.float32(0.03125) + src[begin : begin + GROUP_SIZE] = values + golden[row] = np.sum(values + BIAS, dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + dst.tofile(output_dir / "v2.bin") + golden.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-s32-add-bias-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-s32-add-bias-store/kernel.pto new file mode 100644 index 0000000000..f4c7b4f18a --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s32-add-bias-store/kernel.pto @@ -0,0 +1,54 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_group_reduce_s32_add_bias_store_kernel(%src_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %bias = arith.constant 2.500000e-01 : f32 + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %x = pto.vmi.load %ub_src[%c0] : !pto.ptr -> !pto.vmi.vreg<256xf32> + %biasv = pto.vmi.broadcast %bias : f32 -> !pto.vmi.vreg<256xf32> + %biased = pto.vmi.addf %x, %biasv + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %sum = pto.vmi.group_reduce_addf %biased, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<8xf32> + pto.vmi.group_store %sum, %ub_dst[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xf32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, 
!pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/group-reduce-s32-add-bias-store/launch.cpp b/test/vpto/cases/vmi/group-reduce-s32-add-bias-store/launch.cpp new file mode 100644 index 0000000000..b5526b9b23 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s32-add-bias-store/launch.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_group_reduce_s32_add_bias_store_kernel(__gm__ float *src, + __gm__ float *dst); + +void LaunchVmi_group_reduce_s32_add_bias_store_kernel(float *src, float *dst, + void *stream) { + vmi_group_reduce_s32_add_bias_store_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ float *)dst); +} diff --git a/test/vpto/cases/vmi/group-reduce-s32-add-bias-store/main.cpp b/test/vpto/cases/vmi/group-reduce-s32-add-bias-store/main.cpp new file mode 100644 index 0000000000..5c85668ceb --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s32-add-bias-store/main.cpp @@ -0,0 +1,81 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#include "acl/acl.h" +#include "test_common.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_reduce_s32_add_bias_store_kernel(float *src, float *dst, + void *stream); + +int main() { + constexpr size_t kRows = 8; + constexpr size_t kGroupSize = 32; + constexpr size_t kInputElems = kRows * kGroupSize; + constexpr size_t kOutputElems = kRows; + size_t srcBytes = kInputElems * sizeof(float); + size_t dstBytes = kOutputElems * sizeof(float); + float *srcHost = nullptr; + float *srcDevice = nullptr; + float *dstHost = nullptr; + float *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_reduce_s32_add_bias_store_kernel(srcDevice, dstDevice, + stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, 
ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/group-reduce-s32-add-bias-store/ptoas.flags b/test/vpto/cases/vmi/group-reduce-s32-add-bias-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s32-add-bias-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store/compare.py b/test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store/compare.py new file mode 100644 index 0000000000..28299087e5 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store/compare.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+ +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v2.bin", dtype=np.float32) + output = np.fromfile("v2.bin", dtype=np.float32) + if golden.shape != output.shape or not np.allclose(golden, output, atol=1e-4, rtol=1e-4): + diff = np.nonzero(~np.isclose(golden, output, atol=1e-4, rtol=1e-4))[0] + idx = int(diff[0]) if diff.size else -1 + print(f"[ERROR] compare failed idx={idx} golden={golden[idx] if idx >= 0 else 'n/a'} output={output[idx] if idx >= 0 else 'n/a'}") + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store/golden.py b/test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store/golden.py new file mode 100644 index 0000000000..aef1ece1b4 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store/golden.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+ +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 8 +GROUP_SIZE = 32 +INPUT_ELEMS = ROWS * GROUP_SIZE +SENTINEL = np.float32(-777.0) + + +def generate(output_dir: Path) -> None: + src = np.empty(INPUT_ELEMS, dtype=np.float32) + dst = np.full(ROWS, SENTINEL, dtype=np.float32) + golden = np.empty(ROWS, dtype=np.float32) + + base_row = np.linspace(-0.875, 0.625, GROUP_SIZE, dtype=np.float32) + for row in range(ROWS): + begin = row * GROUP_SIZE + values = base_row + np.float32(row) * np.float32(0.0625) + src[begin : begin + GROUP_SIZE] = values + reduction = np.sum(values, dtype=np.float32) + golden[row] = np.sum(values * reduction, dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + dst.tofile(output_dir / "v2.bin") + golden.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store/kernel.pto new file mode 100644 index 0000000000..aa20ef0c55 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store/kernel.pto @@ -0,0 +1,57 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_group_reduce_s32_broadcast_reduce_store_kernel(%src_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %x = pto.vmi.load %ub_src[%c0] : !pto.ptr -> !pto.vmi.vreg<256xf32> + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<8xf32> + %b = pto.vmi.group_broadcast %sum {num_groups = 8} + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<256xf32> + %y = pto.vmi.mulf %x, %b + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %ysum = pto.vmi.group_reduce_addf %y, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<8xf32> + pto.vmi.group_store %ysum, %ub_dst[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xf32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff 
--git a/test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store/launch.cpp b/test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store/launch.cpp new file mode 100644 index 0000000000..e8decb88f5 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store/launch.cpp @@ -0,0 +1,34 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_group_reduce_s32_broadcast_reduce_store_kernel(__gm__ float *src, + __gm__ float *dst); + +void LaunchVmi_group_reduce_s32_broadcast_reduce_store_kernel(float *src, + float *dst, + void *stream) { + vmi_group_reduce_s32_broadcast_reduce_store_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ float *)dst); +} diff --git a/test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store/main.cpp b/test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store/main.cpp new file mode 100644 index 0000000000..eba17dbdd0 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store/main.cpp @@ -0,0 +1,82 @@ +// Copyright (c) 2026 Huawei Technologies Co., 
Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_reduce_s32_broadcast_reduce_store_kernel(float *src, + float *dst, + void *stream); + +int main() { + constexpr size_t kRows = 8; + constexpr size_t kGroupSize = 32; + constexpr size_t kInputElems = kRows * kGroupSize; + constexpr size_t kOutputElems = kRows; + size_t srcBytes = kInputElems * sizeof(float); + size_t dstBytes = kOutputElems * sizeof(float); + float *srcHost = nullptr; + float *srcDevice = nullptr; + float *dstHost = nullptr; + float *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, 
srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_reduce_s32_broadcast_reduce_store_kernel(srcDevice, dstDevice, + stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store/ptoas.flags b/test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/group-reduce-s32-cf-join-store/compare.py b/test/vpto/cases/vmi/group-reduce-s32-cf-join-store/compare.py new file mode 100644 index 0000000000..28299087e5 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s32-cf-join-store/compare.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. 
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v2.bin", dtype=np.float32) + output = np.fromfile("v2.bin", dtype=np.float32) + if golden.shape != output.shape or not np.allclose(golden, output, atol=1e-4, rtol=1e-4): + diff = np.nonzero(~np.isclose(golden, output, atol=1e-4, rtol=1e-4))[0] + idx = int(diff[0]) if diff.size else -1 + print(f"[ERROR] compare failed idx={idx} golden={golden[idx] if idx >= 0 else 'n/a'} output={output[idx] if idx >= 0 else 'n/a'}") + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-s32-cf-join-store/golden.py b/test/vpto/cases/vmi/group-reduce-s32-cf-join-store/golden.py new file mode 100644 index 0000000000..409f321f7d --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s32-cf-join-store/golden.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+ +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 8 +GROUP_SIZE = 32 +INPUT_ELEMS = ROWS * GROUP_SIZE +SENTINEL = np.float32(-777.0) + + +def generate(output_dir: Path) -> None: + src = np.empty(INPUT_ELEMS, dtype=np.float32) + dst = np.full(ROWS, SENTINEL, dtype=np.float32) + golden = np.empty(ROWS, dtype=np.float32) + + base_row = np.linspace(-0.5, 0.75, GROUP_SIZE, dtype=np.float32) + for row in range(ROWS): + begin = row * GROUP_SIZE + values = base_row + np.float32(row) * np.float32(0.03125) + src[begin : begin + GROUP_SIZE] = values + golden[row] = np.sum(values, dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + dst.tofile(output_dir / "v2.bin") + golden.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-s32-cf-join-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-s32-cf-join-store/kernel.pto new file mode 100644 index 0000000000..271fa80ef7 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s32-cf-join-store/kernel.pto @@ -0,0 +1,63 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_group_reduce_s32_cf_join_store_kernel(%src_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %zero = arith.constant 0.000000e+00 : f32 + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %cond = arith.cmpi eq, %c0, %c0 : index + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %x = scf.if %cond -> (!pto.vmi.vreg<256xf32>) { + %then_x = pto.vmi.load %ub_src[%c0] + : !pto.ptr -> !pto.vmi.vreg<256xf32> + scf.yield %then_x : !pto.vmi.vreg<256xf32> + } else { + %else_x = pto.vmi.load %ub_src[%c0] + : !pto.ptr -> !pto.vmi.vreg<256xf32> + %zero_vec = pto.vmi.broadcast %zero : f32 -> !pto.vmi.vreg<256xf32> + %else_y = pto.vmi.addf %else_x, %zero_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + scf.yield %else_y : !pto.vmi.vreg<256xf32> + } + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<8xf32> + pto.vmi.group_store %sum, %ub_dst[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xf32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, 
i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/group-reduce-s32-cf-join-store/launch.cpp b/test/vpto/cases/vmi/group-reduce-s32-cf-join-store/launch.cpp new file mode 100644 index 0000000000..4204a6ca52 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s32-cf-join-store/launch.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_group_reduce_s32_cf_join_store_kernel(__gm__ float *src, + __gm__ float *dst); + +void LaunchVmi_group_reduce_s32_cf_join_store_kernel(float *src, float *dst, + void *stream) { + vmi_group_reduce_s32_cf_join_store_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ float *)dst); +} diff --git a/test/vpto/cases/vmi/group-reduce-s32-cf-join-store/main.cpp b/test/vpto/cases/vmi/group-reduce-s32-cf-join-store/main.cpp new file mode 100644 index 0000000000..a504036a2e --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s32-cf-join-store/main.cpp @@ -0,0 +1,81 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_reduce_s32_cf_join_store_kernel(float *src, float *dst, + void *stream); + +int main() { + constexpr size_t kRows = 8; + constexpr size_t kGroupSize = 32; + constexpr size_t kInputElems = kRows * kGroupSize; + constexpr size_t kOutputElems = kRows; + size_t srcBytes = kInputElems * sizeof(float); + size_t dstBytes = kOutputElems * sizeof(float); + float *srcHost = nullptr; + float *srcDevice = nullptr; + float *dstHost = nullptr; + float *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, 
ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_reduce_s32_cf_join_store_kernel(srcDevice, dstDevice, + stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/group-reduce-s32-cf-join-store/ptoas.flags b/test/vpto/cases/vmi/group-reduce-s32-cf-join-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s32-cf-join-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/group-reduce-s32-multitile-store/compare.py b/test/vpto/cases/vmi/group-reduce-s32-multitile-store/compare.py new file mode 100644 index 0000000000..28299087e5 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s32-multitile-store/compare.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. 
def main() -> None:
    """Compare ``v2.bin`` against ``golden_v2.bin`` as float32 arrays.

    Both files are read from the current working directory. Prints
    ``[INFO] compare passed`` when the arrays have the same shape and every
    element is close (atol=1e-4, rtol=1e-4). Otherwise prints an ``[ERROR]``
    line and exits with status 2.
    """
    golden = np.fromfile("golden_v2.bin", dtype=np.float32)
    output = np.fromfile("v2.bin", dtype=np.float32)

    # Check the shape first: np.isclose on arrays that cannot be broadcast
    # raises ValueError, which would replace the intended diagnostic and exit
    # code with a traceback.
    if golden.shape != output.shape:
        print(
            f"[ERROR] compare failed idx=-1 golden=n/a output=n/a "
            f"(shape mismatch: golden={golden.shape} output={output.shape})"
        )
        sys.exit(2)

    close = np.isclose(golden, output, atol=1e-4, rtol=1e-4)
    if not np.all(close):
        # Report the first mismatching element.
        idx = int(np.nonzero(~close)[0][0])
        print(f"[ERROR] compare failed idx={idx} golden={golden[idx]} output={output[idx]}")
        sys.exit(2)
    print("[INFO] compare passed")


if __name__ == "__main__":
    main()
ROWS = 16
GROUP_SIZE = 32
INPUT_ELEMS = ROWS * GROUP_SIZE
SENTINEL = np.float32(-777.0)


def generate(output_dir: Path) -> None:
    """Write the input, initial output and golden files for this case.

    ``v1.bin`` holds ROWS rows of GROUP_SIZE float32 values, ``v2.bin`` holds
    ROWS sentinel values, and ``golden_v2.bin`` holds the float32 sum of each
    row. ``output_dir`` is created if it does not exist.
    """
    ramp = np.linspace(-0.5, 0.75, GROUP_SIZE, dtype=np.float32)

    flat_input = np.empty(INPUT_ELEMS, dtype=np.float32)
    # Row-wise view onto the flat buffer; writes land in ``flat_input``.
    input_rows = flat_input.reshape(ROWS, GROUP_SIZE)
    row_sums = np.empty(ROWS, dtype=np.float32)
    for r in range(ROWS):
        shifted = ramp + np.float32(r) * np.float32(0.03125)
        input_rows[r, :] = shifted
        row_sums[r] = np.sum(shifted, dtype=np.float32)

    initial_output = np.full(ROWS, SENTINEL, dtype=np.float32)

    output_dir.mkdir(parents=True, exist_ok=True)
    payloads = (
        ("v1.bin", flat_input),
        ("v2.bin", initial_output),
        ("golden_v2.bin", row_sums),
    )
    for file_name, data in payloads:
        data.tofile(output_dir / file_name)


def main() -> None:
    """Parse ``--output-dir`` (default: current directory) and generate files."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--output-dir", type=Path, default=Path("."))
    generate(cli.parse_args().output_dir)


if __name__ == "__main__":
    main()
// Test kernel: loads 512 x f32, group-reduces with num_groups = 16 (32 lanes
// per group) and stores the 16 results.
// NOTE(review): `!pto.ptr`, `#pto.kernel_kind` and `#pto.pipe` appear here
// without parameter lists; they may have been stripped from this view - confirm
// against the checked-in file.
module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} {
  func.func @vmi_group_reduce_s32_multitile_store_kernel(%src_gm: !pto.ptr,
                                                         %dst_gm: !pto.ptr) attributes {pto.kernel} {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c512 = arith.constant 512 : index
    %c0_i64 = arith.constant 0 : i64
    %c1_i64 = arith.constant 1 : i64
    // 64 = 16 * sizeof(f32) and 2048 = 512 * sizeof(f32); presumably byte
    // counts for the output and input transfers - confirm.
    %c64_i64 = arith.constant 64 : i64
    %c2048_i64 = arith.constant 2048 : i64
    %c4096_i64 = arith.constant 4096 : i64

    // Source buffer is the pointer cast from 0, destination from 4096.
    %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr
    %ub_dst = pto.castptr %c4096_i64 : i64 -> !pto.ptr

    // Bring the input from %src_gm into %ub_src.
    pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c2048_i64
      nburst(%c1_i64, %c2048_i64, %c2048_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64

    // Presumably orders the MTE2 transfer before the vector work - confirm.
    pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]
    pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]

    pto.vecscope {
      %mask = pto.vmi.create_mask %c512 : index -> !pto.vmi.mask<512xpred>
      %x = pto.vmi.load %ub_src[%c0] : !pto.ptr -> !pto.vmi.vreg<512xf32>
      // 512 lanes reduced into 16 results.
      %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 16, reassoc}
        : !pto.vmi.vreg<512xf32>, !pto.vmi.mask<512xpred>
        -> !pto.vmi.vreg<16xf32>
      // %c1 is presumably the per-group element stride - confirm.
      pto.vmi.group_store %sum, %ub_dst[%c0], %c1 {num_groups = 16}
        : !pto.vmi.vreg<16xf32>, !pto.ptr
    }

    // Presumably orders the vector work before the MTE3 write-back - confirm.
    pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
    pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
    // Write the 16 results from %ub_dst back to %dst_gm.
    pto.mte_ub_gm %ub_dst, %dst_gm, %c64_i64
      nburst(%c1_i64, %c64_i64, %c64_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64
    pto.barrier #pto.pipe
    return
  }
}
// Device entry point for this case. Only declared here; the definition is not
// in this file (presumably produced from kernel.pto - confirm).
extern "C" __global__ [aicore] void
vmi_group_reduce_s32_multitile_store_kernel(__gm__ float *src,
                                            __gm__ float *dst);

// Host-side launcher used by main.cpp.
//   src, dst - buffers forwarded to the kernel after a cast to __gm__ float *.
//   stream   - passed as the third launch-configuration argument.
void LaunchVmi_group_reduce_s32_multitile_store_kernel(float *src, float *dst,
                                                       void *stream) {
  // Launch configuration is <<<1, nullptr, stream>>>; the first two values are
  // presumably the block count and an unused control pointer - TODO confirm.
  vmi_group_reduce_s32_multitile_store_kernel<<<1, nullptr, stream>>>(
      (__gm__ float *)src, (__gm__ float *)dst);
}
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_reduce_s32_multitile_store_kernel(float *src, float *dst, + void *stream); + +int main() { + constexpr size_t kRows = 16; + constexpr size_t kGroupSize = 32; + constexpr size_t kInputElems = kRows * kGroupSize; + constexpr size_t kOutputElems = kRows; + size_t srcBytes = kInputElems * sizeof(float); + size_t dstBytes = kOutputElems * sizeof(float); + float *srcHost = nullptr; + float *srcDevice = nullptr; + float *dstHost = nullptr; + float *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, 
ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_reduce_s32_multitile_store_kernel(srcDevice, dstDevice, + stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/group-reduce-s32-multitile-store/ptoas.flags b/test/vpto/cases/vmi/group-reduce-s32-multitile-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s32-multitile-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/group-reduce-s32-tail-full-tile-store/compare.py b/test/vpto/cases/vmi/group-reduce-s32-tail-full-tile-store/compare.py new file mode 100644 index 0000000000..8c5fc67aca --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s32-tail-full-tile-store/compare.py @@ -0,0 +1,28 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
def main() -> None:
    """Compare ``v2.bin`` against ``golden_v2.bin`` as float32 arrays.

    Both files are read from the current working directory. Prints
    ``[INFO] compare passed`` when the arrays have the same shape and every
    element is close (atol=1e-4, rtol=1e-4). Otherwise prints an ``[ERROR]``
    line and exits with status 2.
    """
    golden = np.fromfile("golden_v2.bin", dtype=np.float32)
    output = np.fromfile("v2.bin", dtype=np.float32)

    # The shape check must run before np.isclose: comparing arrays that cannot
    # be broadcast raises ValueError, which would replace the intended
    # diagnostic and exit code with a traceback.
    if golden.shape != output.shape:
        print(
            f"[ERROR] compare failed idx=-1 golden=n/a output=n/a "
            f"(shape mismatch: golden={golden.shape} output={output.shape})"
        )
        sys.exit(2)

    close = np.isclose(golden, output, atol=1e-4, rtol=1e-4)
    if not np.all(close):
        # Report the first mismatching element.
        idx = int(np.nonzero(~close)[0][0])
        print(f"[ERROR] compare failed idx={idx} golden={golden[idx]} output={output[idx]}")
        sys.exit(2)
    print("[INFO] compare passed")


if __name__ == "__main__":
    main()
PHYSICAL_ROWS = 8
ACTIVE_ROWS = 6
GROUP_SIZE = 32
INPUT_ELEMS = PHYSICAL_ROWS * GROUP_SIZE
SENTINEL = np.float32(-777.0)


def generate(output_dir: Path) -> None:
    """Write the input, initial output and golden files for this case.

    ``v1.bin`` holds PHYSICAL_ROWS rows of GROUP_SIZE float32 values and
    ``v2.bin`` holds PHYSICAL_ROWS sentinel values. ``golden_v2.bin`` holds the
    float32 row sum for the first ACTIVE_ROWS rows and 0 for the remaining
    rows. ``output_dir`` is created if it does not exist.
    """
    ramp = np.linspace(-0.875, 0.625, GROUP_SIZE, dtype=np.float32)

    flat_input = np.empty(INPUT_ELEMS, dtype=np.float32)
    # Row-wise view onto the flat buffer; writes land in ``flat_input``.
    input_rows = flat_input.reshape(PHYSICAL_ROWS, GROUP_SIZE)
    expected = np.zeros(PHYSICAL_ROWS, dtype=np.float32)
    for r in range(PHYSICAL_ROWS):
        shifted = ramp + np.float32(r) * np.float32(0.0625)
        input_rows[r, :] = shifted
        if r >= ACTIVE_ROWS:
            # Rows past the active count keep their zero golden value.
            continue
        expected[r] = np.sum(shifted, dtype=np.float32)

    initial_output = np.full(PHYSICAL_ROWS, SENTINEL, dtype=np.float32)

    output_dir.mkdir(parents=True, exist_ok=True)
    payloads = (
        ("v1.bin", flat_input),
        ("v2.bin", initial_output),
        ("golden_v2.bin", expected),
    )
    for file_name, data in payloads:
        data.tofile(output_dir / file_name)


def main() -> None:
    """Parse ``--output-dir`` (default: current directory) and generate files."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--output-dir", type=Path, default=Path("."))
    generate(cli.parse_args().output_dir)


if __name__ == "__main__":
    main()
// Test kernel: loads a full 256 x f32 tile but builds the mask from 192,
// group-reduces with num_groups = 8 and stores all 8 results.
// NOTE(review): `!pto.ptr`, `#pto.kernel_kind` and `#pto.pipe` appear here
// without parameter lists; they may have been stripped from this view - confirm
// against the checked-in file.
module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} {
  func.func @vmi_group_reduce_s32_tail_full_tile_store_kernel(%src_gm: !pto.ptr,
                                                              %dst_gm: !pto.ptr) attributes {pto.kernel} {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    // 192 = 6 * 32; presumably only the first 6 of the 8 groups are active.
    %c192 = arith.constant 192 : index
    %c0_i64 = arith.constant 0 : i64
    %c1_i64 = arith.constant 1 : i64
    // 32 = 8 * sizeof(f32) and 1024 = 256 * sizeof(f32); presumably byte counts
    // for the output and input transfers - confirm.
    %c32_i64 = arith.constant 32 : i64
    %c1024_i64 = arith.constant 1024 : i64
    %c4096_i64 = arith.constant 4096 : i64

    // Source buffer is the pointer cast from 0, destination from 4096.
    %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr
    %ub_dst = pto.castptr %c4096_i64 : i64 -> !pto.ptr

    // Bring the input from %src_gm into %ub_src.
    pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c1024_i64
      nburst(%c1_i64, %c1024_i64, %c1024_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
    // Also bring the current contents of %dst_gm into %ub_dst before compute.
    pto.mte_gm_ub %dst_gm, %ub_dst, %c0_i64, %c32_i64
      nburst(%c1_i64, %c32_i64, %c32_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64

    // Presumably orders both MTE2 transfers before the vector work - confirm.
    pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]
    pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]

    pto.vecscope {
      // Mask type is 256 wide while the bound operand is 192.
      %mask = pto.vmi.create_mask %c192 : index -> !pto.vmi.mask<256xpred>
      %x = pto.vmi.load %ub_src[%c0]
        : !pto.ptr -> !pto.vmi.vreg<256xf32>
      %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc}
        : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred>
        -> !pto.vmi.vreg<8xf32>
      // All 8 group results are stored; %c1 is presumably the element stride.
      pto.vmi.group_store %sum, %ub_dst[%c0], %c1 {num_groups = 8}
        : !pto.vmi.vreg<8xf32>, !pto.ptr
    }

    // Presumably orders the vector work before the MTE3 write-back - confirm.
    pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
    pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
    // Write the 8 results from %ub_dst back to %dst_gm.
    pto.mte_ub_gm %ub_dst, %dst_gm, %c32_i64
      nburst(%c1_i64, %c32_i64, %c32_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64
    pto.barrier #pto.pipe
    return
  }
}
// Device entry point for this case. Only declared here; the definition is not
// in this file (presumably produced from kernel.pto - confirm).
extern "C" __global__ [aicore] void
vmi_group_reduce_s32_tail_full_tile_store_kernel(__gm__ float *src,
                                                 __gm__ float *dst);

// Host-side launcher used by main.cpp.
//   src, dst - buffers forwarded to the kernel after a cast to __gm__ float *.
//   stream   - passed as the third launch-configuration argument.
void LaunchVmi_group_reduce_s32_tail_full_tile_store_kernel(float *src,
                                                            float *dst,
                                                            void *stream) {
  // Launch configuration is <<<1, nullptr, stream>>>; the first two values are
  // presumably the block count and an unused control pointer - TODO confirm.
  vmi_group_reduce_s32_tail_full_tile_store_kernel<<<1, nullptr, stream>>>(
      (__gm__ float *)src, (__gm__ float *)dst);
}
You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_reduce_s32_tail_full_tile_store_kernel(float *src, + float *dst, + void *stream); + +int main() { + constexpr size_t kPhysicalRows = 8; + constexpr size_t kGroupSize = 32; + constexpr size_t kInputElems = kPhysicalRows * kGroupSize; + constexpr size_t kOutputElems = kPhysicalRows; + size_t srcBytes = kInputElems * sizeof(float); + size_t dstBytes = kOutputElems * sizeof(float); + float *srcHost = nullptr; + float *srcDevice = nullptr; + float *dstHost = nullptr; + float *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", dstBytes, 
dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_reduce_s32_tail_full_tile_store_kernel(srcDevice, dstDevice, + stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/group-reduce-s32-tail-full-tile-store/ptoas.flags b/test/vpto/cases/vmi/group-reduce-s32-tail-full-tile-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s32-tail-full-tile-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/group-reduce-s64-broadcast-reduce-store/compare.py b/test/vpto/cases/vmi/group-reduce-s64-broadcast-reduce-store/compare.py new file mode 100644 index 0000000000..28299087e5 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s64-broadcast-reduce-store/compare.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
def main() -> None:
    """Compare ``v2.bin`` against ``golden_v2.bin`` as float32 arrays.

    Both files are read from the current working directory. Prints
    ``[INFO] compare passed`` when the arrays have the same shape and every
    element is close (atol=1e-4, rtol=1e-4). Otherwise prints an ``[ERROR]``
    line and exits with status 2.
    """
    golden = np.fromfile("golden_v2.bin", dtype=np.float32)
    output = np.fromfile("v2.bin", dtype=np.float32)

    # Check the shape first: np.isclose on arrays that cannot be broadcast
    # raises ValueError, which would replace the intended diagnostic and exit
    # code with a traceback.
    if golden.shape != output.shape:
        print(
            f"[ERROR] compare failed idx=-1 golden=n/a output=n/a "
            f"(shape mismatch: golden={golden.shape} output={output.shape})"
        )
        sys.exit(2)

    close = np.isclose(golden, output, atol=1e-4, rtol=1e-4)
    if not np.all(close):
        # Report the first mismatching element.
        idx = int(np.nonzero(~close)[0][0])
        print(f"[ERROR] compare failed idx={idx} golden={golden[idx]} output={output[idx]}")
        sys.exit(2)
    print("[INFO] compare passed")


if __name__ == "__main__":
    main()
ROWS = 8
GROUP_SIZE = 64
INPUT_ELEMS = ROWS * GROUP_SIZE
OUTPUT_STRIDE = 8
OUTPUT_ELEMS = ROWS * OUTPUT_STRIDE
SENTINEL = np.float32(-777.0)


def generate(output_dir: Path) -> None:
    """Write the input, initial output and golden files for this case.

    ``v1.bin`` holds ROWS rows of GROUP_SIZE float32 values and ``v2.bin``
    holds OUTPUT_ELEMS sentinel values. ``golden_v2.bin`` starts as all
    sentinel; for each row, element ``row * OUTPUT_STRIDE`` is replaced by the
    float32 sum of ``values * sum(values)``. ``output_dir`` is created if it
    does not exist.
    """
    ramp = np.linspace(-0.5, 0.5, GROUP_SIZE, dtype=np.float32)

    flat_input = np.empty(INPUT_ELEMS, dtype=np.float32)
    # Row-wise view onto the flat buffer; writes land in ``flat_input``.
    input_rows = flat_input.reshape(ROWS, GROUP_SIZE)
    expected = np.full(OUTPUT_ELEMS, SENTINEL, dtype=np.float32)
    # View whose column 0 addresses every OUTPUT_STRIDE-th element.
    expected_slots = expected.reshape(ROWS, OUTPUT_STRIDE)
    for r in range(ROWS):
        shifted = ramp + np.float32(r) * np.float32(0.03125)
        input_rows[r, :] = shifted
        row_total = np.sum(shifted, dtype=np.float32)
        expected_slots[r, 0] = np.sum(shifted * row_total, dtype=np.float32)

    initial_output = np.full(OUTPUT_ELEMS, SENTINEL, dtype=np.float32)

    output_dir.mkdir(parents=True, exist_ok=True)
    payloads = (
        ("v1.bin", flat_input),
        ("v2.bin", initial_output),
        ("golden_v2.bin", expected),
    )
    for file_name, data in payloads:
        data.tofile(output_dir / file_name)


def main() -> None:
    """Parse ``--output-dir`` (default: current directory) and generate files."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--output-dir", type=Path, default=Path("."))
    generate(cli.parse_args().output_dir)


if __name__ == "__main__":
    main()
// Test kernel: group-reduce 512 x f32 into 8 sums (64 lanes per group),
// broadcast each sum back over its group, multiply with the input, reduce
// again and store the 8 results with stride 8.
// NOTE(review): `!pto.ptr`, `#pto.kernel_kind` and `#pto.pipe` appear here
// without parameter lists; they may have been stripped from this view - confirm
// against the checked-in file.
module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} {
  func.func @vmi_group_reduce_s64_broadcast_reduce_store_kernel(%src_gm: !pto.ptr,
                                                                %dst_gm: !pto.ptr) attributes {pto.kernel} {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c8 = arith.constant 8 : index
    %c512 = arith.constant 512 : index
    %c0_i64 = arith.constant 0 : i64
    %c1_i64 = arith.constant 1 : i64
    // 256 = 8 * 8 * sizeof(f32) and 2048 = 512 * sizeof(f32); presumably byte
    // counts for the output and input transfers - confirm.
    %c256_i64 = arith.constant 256 : i64
    %c2048_i64 = arith.constant 2048 : i64
    %c4096_i64 = arith.constant 4096 : i64

    // Source buffer is the pointer cast from 0, destination from 4096.
    %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr
    %ub_dst = pto.castptr %c4096_i64 : i64 -> !pto.ptr

    // Bring the input from %src_gm into %ub_src.
    pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c2048_i64
      nburst(%c1_i64, %c2048_i64, %c2048_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
    // Also bring the current contents of %dst_gm into %ub_dst before compute;
    // presumably so elements the strided store skips keep their old values.
    pto.mte_gm_ub %dst_gm, %ub_dst, %c0_i64, %c256_i64
      nburst(%c1_i64, %c256_i64, %c256_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64

    // Presumably orders both MTE2 transfers before the vector work - confirm.
    pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]
    pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]

    pto.vecscope {
      %mask = pto.vmi.create_mask %c512 : index -> !pto.vmi.mask<512xpred>
      %x = pto.vmi.load %ub_src[%c0] : !pto.ptr -> !pto.vmi.vreg<512xf32>
      // First reduction: 512 lanes into 8 group sums.
      %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc}
        : !pto.vmi.vreg<512xf32>, !pto.vmi.mask<512xpred>
        -> !pto.vmi.vreg<8xf32>
      // Expand the 8 sums back to 512 lanes.
      %b = pto.vmi.group_broadcast %sum {num_groups = 8}
        : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<512xf32>
      %y = pto.vmi.mulf %x, %b
        : !pto.vmi.vreg<512xf32>, !pto.vmi.vreg<512xf32>
        -> !pto.vmi.vreg<512xf32>
      // Second reduction over the products, reusing the same mask.
      %ysum = pto.vmi.group_reduce_addf %y, %mask {num_groups = 8, reassoc}
        : !pto.vmi.vreg<512xf32>, !pto.vmi.mask<512xpred>
        -> !pto.vmi.vreg<8xf32>
      // %c8 is presumably the element stride between stored group results.
      pto.vmi.group_store %ysum, %ub_dst[%c0], %c8 {num_groups = 8}
        : !pto.vmi.vreg<8xf32>, !pto.ptr
    }

    // Presumably orders the vector work before the MTE3 write-back - confirm.
    pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
    pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
    // Write the whole 256-unit destination region back to %dst_gm.
    pto.mte_ub_gm %ub_dst, %dst_gm, %c256_i64
      nburst(%c1_i64, %c256_i64, %c256_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64
    pto.barrier #pto.pipe
    return
  }
}
// Host-side launch shim for the broadcast-reduce-store VMI kernel.

// Make sure __VEC_SCOPE__ exists (as an empty macro) when the toolchain has
// not defined it.
#ifndef __VEC_SCOPE__
#define __VEC_SCOPE__
#endif
// NOTE(review): the header name of this #include is missing in this text
// (apparently lost in extraction). The struct below needs uint16_t, so a
// fixed-width integer header is presumably what was here - TODO confirm.
#include
// Local definition of MrgSortExecutedNumList, compiled only when neither
// __CCE_AICORE__ nor TMRGSORT_HPP is defined. Why this case needs the type is
// not visible from this file.
#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP)
struct MrgSortExecutedNumList {
  uint16_t mrgSortList0;
  uint16_t mrgSortList1;
  uint16_t mrgSortList2;
  uint16_t mrgSortList3;
};
#endif
// The ACL header is skipped for __CPU_SIM builds.
#ifndef __CPU_SIM
#include "acl/acl.h"
#endif

// Device entry point; the name matches the func.func symbol in kernel.pto.
extern "C" __global__ [aicore] void
vmi_group_reduce_s64_broadcast_reduce_store_kernel(__gm__ float *src,
                                                   __gm__ float *dst);

/// Launches the kernel on `stream`.
///
/// @param src    input buffer, passed to the kernel as a __gm__ float pointer
/// @param dst    output buffer, passed to the kernel as a __gm__ float pointer
/// @param stream forwarded as the third launch parameter
///
/// The launch configuration is <<<1, nullptr, stream>>>; presumably one block
/// and no extra control argument - TODO confirm against the toolchain docs.
void LaunchVmi_group_reduce_s64_broadcast_reduce_store_kernel(float *src,
                                                              float *dst,
                                                              void *stream) {
  vmi_group_reduce_s64_broadcast_reduce_store_kernel<<<1, nullptr, stream>>>(
      (__gm__ float *)src, (__gm__ float *)dst);
}

// ---------------------------------------------------------------------------
// Patch text that shares this source line with the code above, kept verbatim:
//   diff --git a/test/vpto/cases/vmi/group-reduce-s64-broadcast-reduce-store/main.cpp b/test/vpto/cases/vmi/group-reduce-s64-broadcast-reduce-store/main.cpp
//   new file mode 100644
//   index 0000000000..91e2c97119
//   --- /dev/null
//   +++ b/test/vpto/cases/vmi/group-reduce-s64-broadcast-reduce-store/main.cpp
//   @@ -0,0 +1,83 @@
//   +// Copyright (c) 2026 Huawei Technologies Co., Ltd.
//   +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
//   +// CANN Open Software License Agreement Version 2.0 (the "License").
//   +// Please refer to the License for details. You may not use this file except in compliance with the License.
//   +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
//   +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
//   +// See LICENSE in the root of the software repository for the full text of the License.
// Host driver for the broadcast-reduce-store VMI case: loads v1.bin / v2.bin,
// runs the kernel once, and writes the device result back over v2.bin.

#include "acl/acl.h"
#include "test_common.h"
// NOTE(review): the two #include lines below have no header name in this text
// (apparently lost in extraction). The code uses std::fprintf, std::getenv and
// std::atoi, so <cstdio> and <cstdlib> are presumably what was here - TODO confirm.
#include
#include

using namespace PtoTestCommon;

// Evaluates `expr` once. On any result other than ACL_SUCCESS it logs the
// expression text, the error code and the source location to stderr, sets
// `rc = 1` and jumps to `cleanup`. It therefore only works inside a function
// that declares `rc` and a `cleanup:` label.
#define ACL_CHECK(expr)                                                        \
  do {                                                                         \
    const aclError _ret = (expr);                                              \
    if (_ret != ACL_SUCCESS) {                                                 \
      std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr,           \
                   (int)_ret, __FILE__, __LINE__);                             \
      rc = 1;                                                                  \
      goto cleanup;                                                            \
    }                                                                          \
  } while (0)

// Defined in launch.cpp.
void LaunchVmi_group_reduce_s64_broadcast_reduce_store_kernel(float *src,
                                                              float *dst,
                                                              void *stream);

/// Runs the case end to end.
/// @return 0 on success, 1 when any ACL_CHECK-wrapped call fails.
int main() {
  // 8 groups of 64 floats in; 8 output slots per group out.
  constexpr size_t kRows = 8;
  constexpr size_t kGroupSize = 64;
  constexpr size_t kOutputStride = 8;
  constexpr size_t kInputElems = kRows * kGroupSize;
  constexpr size_t kOutputElems = kRows * kOutputStride;
  size_t srcBytes = kInputElems * sizeof(float);
  size_t dstBytes = kOutputElems * sizeof(float);
  // Every local is declared (and initialised) before the first ACL_CHECK, so
  // the macro's `goto cleanup` never jumps over an initialisation.
  float *srcHost = nullptr;
  float *srcDevice = nullptr;
  float *dstHost = nullptr;
  float *dstDevice = nullptr;
  int rc = 0;
  // Track how far start-up got so cleanup only undoes what actually happened.
  bool aclInited = false;
  bool deviceSet = false;
  int deviceId = 0;
  aclrtStream stream = nullptr;

  ACL_CHECK(aclInit(nullptr));
  aclInited = true;
  // ACL_DEVICE_ID overrides the default device 0.
  if (const char *envDevice = std::getenv("ACL_DEVICE_ID"))
    deviceId = std::atoi(envDevice);
  ACL_CHECK(aclrtSetDevice(deviceId));
  deviceSet = true;
  ACL_CHECK(aclrtCreateStream(&stream));
  ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes));
  ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes));
  ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST));
  ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST));

  // NOTE(review): whatever status ReadFile reports is ignored here, so a
  // missing or short input file is not detected - confirm this is intended.
  ReadFile("./v1.bin", srcBytes, srcHost, srcBytes);
  ReadFile("./v2.bin", dstBytes, dstHost, dstBytes);
  // dst is uploaded too: the kernel only overwrites some slots and the rest
  // must keep the values from v2.bin.
  ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE));
  ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE));
  LaunchVmi_group_reduce_s64_broadcast_reduce_store_kernel(srcDevice, dstDevice,
                                                           stream);
  ACL_CHECK(aclrtSynchronizeStream(stream));
  ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST));
  // The result replaces v2.bin in place.
  WriteFile("./v2.bin", dstHost, dstBytes);

cleanup:
  // Reached on success and on every ACL_CHECK failure. Return codes of the
  // release calls are not inspected.
  // NOTE(review): after an early failure these pointers are still nullptr;
  // confirm aclrtFree / aclrtFreeHost tolerate that.
  aclrtFree(srcDevice);
  aclrtFree(dstDevice);
  aclrtFreeHost(srcHost);
  aclrtFreeHost(dstHost);
  if (stream)
    aclrtDestroyStream(stream);
  if (deviceSet)
    aclrtResetDevice(deviceId);
  if (aclInited)
    aclFinalize();
  return rc;
}

// ---------------------------------------------------------------------------
// Patch text that shares this source line with the code above, kept verbatim:
//   diff --git a/test/vpto/cases/vmi/group-reduce-s64-broadcast-reduce-store/ptoas.flags b/test/vpto/cases/vmi/group-reduce-s64-broadcast-reduce-store/ptoas.flags
//   new file mode 100644
//   index 0000000000..a79aede1ca
//   --- /dev/null
//   +++ b/test/vpto/cases/vmi/group-reduce-s64-broadcast-reduce-store/ptoas.flags
//   @@ -0,0 +1 @@
//   +--pto-arch a5 --pto-backend=vpto --enable-vmi
//   diff --git a/test/vpto/cases/vmi/group-reduce-s64-slot-add-store/compare.py b/test/vpto/cases/vmi/group-reduce-s64-slot-add-store/compare.py
//   new file mode 100644
//   index 0000000000..be861f3da8
//   --- /dev/null
//   +++ b/test/vpto/cases/vmi/group-reduce-s64-slot-add-store/compare.py
//   @@ -0,0 +1,36 @@
//   +#!/usr/bin/env python3
//   +# Copyright (c) 2026 Huawei Technologies Co., Ltd.
//   +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
//   +# CANN Open Software License Agreement Version 2.0 (the "License").
//   +# Please refer to the License for details. You may not use this file except in compliance with the License.
//   +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
//   +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
//   +# See LICENSE in the root of the software repository for the full text of the License.
import sys

import numpy as np


def main() -> None:
    """Check ``v3.bin`` against ``golden_v3.bin`` (both float32, in the cwd).

    Prints ``[INFO] compare passed`` when the two arrays have the same shape
    and agree within atol=1e-4 / rtol=1e-4. Otherwise prints an ``[ERROR]``
    line (shape mismatch, or the first differing index with both values) and
    exits with status 2.
    """
    expected = np.fromfile("golden_v3.bin", dtype=np.float32)
    actual = np.fromfile("v3.bin", dtype=np.float32)

    # Guard: element-wise comparison is only meaningful for equal shapes.
    if expected.shape != actual.shape:
        print(f"[ERROR] compare failed: shape golden={expected.shape} output={actual.shape}")
        sys.exit(2)

    if np.allclose(expected, actual, atol=1e-4, rtol=1e-4):
        print("[INFO] compare passed")
        return

    # Report only the first offending element.
    mismatched = np.flatnonzero(~np.isclose(expected, actual, atol=1e-4, rtol=1e-4))
    idx = int(mismatched[0]) if mismatched.size else -1
    shown_golden = expected[idx] if idx >= 0 else 'n/a'
    shown_output = actual[idx] if idx >= 0 else 'n/a'
    print(
        f"[ERROR] compare failed idx={idx} "
        f"golden={shown_golden} "
        f"output={shown_output}"
    )
    sys.exit(2)


if __name__ == "__main__":
    main()

# ---------------------------------------------------------------------------
# Patch text that shares this source line with the code above, kept verbatim:
#   diff --git a/test/vpto/cases/vmi/group-reduce-s64-slot-add-store/golden.py b/test/vpto/cases/vmi/group-reduce-s64-slot-add-store/golden.py
#   new file mode 100644
#   index 0000000000..6d0d25229a
#   --- /dev/null
#   +++ b/test/vpto/cases/vmi/group-reduce-s64-slot-add-store/golden.py
#   @@ -0,0 +1,51 @@
#   +#!/usr/bin/env python3
#   +# Copyright (c) 2026 Huawei Technologies Co., Ltd.
#   +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
#   +# CANN Open Software License Agreement Version 2.0 (the "License").
#   +# Please refer to the License for details. You may not use this file except in compliance with the License.
#   +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
#   +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
#   +# See LICENSE in the root of the software repository for the full text of the License.
import argparse
from pathlib import Path

import numpy as np

# ROWS groups of GROUP_SIZE float32 lanes. The right-hand operand and the
# output each reserve a fixed number of slots per group; only slot 0 is used.
ROWS = 8
GROUP_SIZE = 64
RHS_STRIDE = 8
OUTPUT_STRIDE = 8
# Value pre-loaded into every output slot so untouched slots are detectable.
SENTINEL = np.float32(-777.0)


def generate(output_dir: Path) -> None:
    """Emit the binary fixtures for the slot-add-store case.

    Files written into ``output_dir`` (created when missing):

    * ``v1.bin``        - the float32 input, ROWS x GROUP_SIZE, row-major
    * ``v2.bin``        - the float32 right-hand operand, ROWS * RHS_STRIDE values
    * ``v3.bin``        - the initial output buffer, filled with SENTINEL
    * ``golden_v3.bin`` - the expected output: SENTINEL everywhere except the
      first slot of each group, which holds ``sum(row) + rhs[row * RHS_STRIDE]``
    """
    ramp = np.linspace(-0.5, 0.5, GROUP_SIZE, dtype=np.float32)
    addend = np.linspace(-0.75, 0.75, ROWS * RHS_STRIDE, dtype=np.float32)
    inputs = np.empty((ROWS, GROUP_SIZE), dtype=np.float32)
    initial = np.full(ROWS * OUTPUT_STRIDE, SENTINEL, dtype=np.float32)
    expected = initial.copy()

    for group in range(ROWS):
        # Every group is the same ramp shifted by a per-group offset.
        inputs[group, :] = ramp + np.float32(group) * np.float32(0.03125)
        group_sum = np.sum(inputs[group, :], dtype=np.float32)
        expected[group * OUTPUT_STRIDE] = group_sum + addend[group * RHS_STRIDE]

    output_dir.mkdir(parents=True, exist_ok=True)
    for file_name, payload in (
        ("v1.bin", inputs.ravel()),
        ("v2.bin", addend),
        ("v3.bin", initial),
        ("golden_v3.bin", expected),
    ):
        payload.tofile(output_dir / file_name)


def main() -> None:
    """Command-line entry point: ``--output-dir`` selects the target directory."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--output-dir", type=Path, default=Path("."))
    generate(cli.parse_args().output_dir)


if __name__ == "__main__":
    main()

# ---------------------------------------------------------------------------
# Patch text that shares this source line with the code above, kept verbatim:
#   diff --git a/test/vpto/cases/vmi/group-reduce-s64-slot-add-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-s64-slot-add-store/kernel.pto
#   new file mode 100644
#   index 0000000000..5765b56274
#   --- /dev/null
#   +++ b/test/vpto/cases/vmi/group-reduce-s64-slot-add-store/kernel.pto
#   @@ -0,0 +1,64 @@
#   +// Copyright (c) 2026 Huawei Technologies Co., Ltd.
#   +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
#   +// CANN Open Software License Agreement Version 2.0 (the "License").
#   +// Please refer to the License for details. You may not use this file except in compliance with the License.
// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
// See LICENSE in the root of the software repository for the full text of the License.

// VMI test kernel: 512 f32 lanes are treated as 8 groups. Each group is reduced
// to a sum, one right-hand value per group is loaded from a strided buffer and
// added to that sum, and one result per group is stored.
// The companion golden.py models this as sum(row) + rhs[row * RHS_STRIDE].
//
// NOTE(review): `#pto.kernel_kind`, `#pto.pipe` and every `!pto.ptr` below carry
// no parameters in this text; they look like they were lost in extraction and
// should be restored from the original file before this is parsed.
module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} {
  func.func @vmi_group_reduce_s64_slot_add_store_kernel(%src_gm: !pto.ptr,
                                                        %rhs_gm: !pto.ptr,
                                                        %dst_gm: !pto.ptr) attributes {pto.kernel} {
    %c0 = arith.constant 0 : index
    %c8 = arith.constant 8 : index
    %c512 = arith.constant 512 : index
    %c0_i64 = arith.constant 0 : i64
    %c1_i64 = arith.constant 1 : i64
    // NOTE(review): %c64_i64 is defined but never used in this function.
    %c64_i64 = arith.constant 64 : i64
    // 256 and 2048 equal the byte sizes the host uses for rhs/dst (64 floats
    // each) and src (512 floats).
    %c256_i64 = arith.constant 256 : i64
    %c2048_i64 = arith.constant 2048 : i64
    %c4096_i64 = arith.constant 4096 : i64
    %c8192_i64 = arith.constant 8192 : i64

    // On-chip buffers are addressed by fixed integer addresses: src at 0,
    // rhs at 4096, dst at 8192.
    %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr
    %ub_rhs = pto.castptr %c4096_i64 : i64 -> !pto.ptr
    %ub_dst = pto.castptr %c8192_i64 : i64 -> !pto.ptr

    // Bring the input, the right-hand operand and the sentinel-filled output
    // into the on-chip buffers. Presumably GM -> UB copies; the exact meaning
    // of the trailing operands and of nburst(...) is not visible here - TODO confirm.
    pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c2048_i64
      nburst(%c1_i64, %c2048_i64, %c2048_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
    pto.mte_gm_ub %rhs_gm, %ub_rhs, %c0_i64, %c256_i64
      nburst(%c1_i64, %c256_i64, %c256_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
    pto.mte_gm_ub %dst_gm, %ub_dst, %c0_i64, %c256_i64
      nburst(%c1_i64, %c256_i64, %c256_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64

    // Presumably orders the inbound copies before the vector work - TODO confirm.
    pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]
    pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]

    pto.vecscope {
      %mask = pto.vmi.create_mask %c512 : index -> !pto.vmi.mask<512xpred>
      %x = pto.vmi.load %ub_src[%c0] : !pto.ptr -> !pto.vmi.vreg<512xf32>
      // One right-hand value per group; %c8 matches RHS_STRIDE in golden.py,
      // presumably the element stride between group slots - TODO confirm.
      %rhs = pto.vmi.group_slot_load %ub_rhs[%c0], %c8 {num_groups = 8}
        : !pto.ptr -> !pto.vmi.vreg<8xf32>
      // 512 lanes -> 8 per-group sums.
      %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc}
        : !pto.vmi.vreg<512xf32>, !pto.vmi.mask<512xpred>
        -> !pto.vmi.vreg<8xf32>
      %out = pto.vmi.addf %sum, %rhs
        : !pto.vmi.vreg<8xf32>, !pto.vmi.vreg<8xf32>
        -> !pto.vmi.vreg<8xf32>
      // %c8 matches OUTPUT_STRIDE in golden.py.
      pto.vmi.group_store %out, %ub_dst[%c0], %c8 {num_groups = 8}
        : !pto.vmi.vreg<8xf32>, !pto.ptr
    }

    // Presumably orders the vector stores before the outbound copy - TODO confirm.
    pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
    pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
    // Copy the whole 256-byte output buffer (results plus untouched sentinels) back.
    pto.mte_ub_gm %ub_dst, %dst_gm, %c256_i64
      nburst(%c1_i64, %c256_i64, %c256_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64
    pto.barrier #pto.pipe
    return
  }
}

// ---------------------------------------------------------------------------
// Patch text that shares this source line with the code above, kept verbatim:
//   diff --git a/test/vpto/cases/vmi/group-reduce-s64-slot-add-store/launch.cpp b/test/vpto/cases/vmi/group-reduce-s64-slot-add-store/launch.cpp
//   new file mode 100644
//   index 0000000000..7225148ff7
//   --- /dev/null
//   +++ b/test/vpto/cases/vmi/group-reduce-s64-slot-add-store/launch.cpp
//   @@ -0,0 +1,35 @@
//   +// Copyright (c) 2026 Huawei Technologies Co., Ltd.
//   +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
//   +// CANN Open Software License Agreement Version 2.0 (the "License").
//   +// Please refer to the License for details. You may not use this file except in compliance with the License.
//   +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
//   +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
//   +// See LICENSE in the root of the software repository for the full text of the License.
// Host-side launch shim for the slot-add-store VMI kernel.

// Make sure __VEC_SCOPE__ exists (as an empty macro) when the toolchain has
// not defined it.
#ifndef __VEC_SCOPE__
#define __VEC_SCOPE__
#endif
// NOTE(review): the header name of this #include is missing in this text
// (apparently lost in extraction). The struct below needs uint16_t, so a
// fixed-width integer header is presumably what was here - TODO confirm.
#include
// Local definition of MrgSortExecutedNumList, compiled only when neither
// __CCE_AICORE__ nor TMRGSORT_HPP is defined. Why this case needs the type is
// not visible from this file.
#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP)
struct MrgSortExecutedNumList {
  uint16_t mrgSortList0;
  uint16_t mrgSortList1;
  uint16_t mrgSortList2;
  uint16_t mrgSortList3;
};
#endif
// The ACL header is skipped for __CPU_SIM builds.
#ifndef __CPU_SIM
#include "acl/acl.h"
#endif

// Device entry point; the name matches the func.func symbol in kernel.pto.
extern "C" __global__ [aicore] void
vmi_group_reduce_s64_slot_add_store_kernel(__gm__ float *src,
                                           __gm__ float *rhs,
                                           __gm__ float *dst);

/// Launches the kernel on `stream`.
///
/// @param src    input buffer, passed to the kernel as a __gm__ float pointer
/// @param rhs    per-group addend buffer, passed as a __gm__ float pointer
/// @param dst    output buffer, passed to the kernel as a __gm__ float pointer
/// @param stream forwarded as the third launch parameter
///
/// The launch configuration is <<<1, nullptr, stream>>>; presumably one block
/// and no extra control argument - TODO confirm against the toolchain docs.
void LaunchVmi_group_reduce_s64_slot_add_store_kernel(float *src, float *rhs,
                                                      float *dst,
                                                      void *stream) {
  vmi_group_reduce_s64_slot_add_store_kernel<<<1, nullptr, stream>>>(
      (__gm__ float *)src, (__gm__ float *)rhs, (__gm__ float *)dst);
}

// ---------------------------------------------------------------------------
// Patch text that shares this source line with the code above, kept verbatim:
//   diff --git a/test/vpto/cases/vmi/group-reduce-s64-slot-add-store/main.cpp b/test/vpto/cases/vmi/group-reduce-s64-slot-add-store/main.cpp
//   new file mode 100644
//   index 0000000000..1f5acfaa5c
//   --- /dev/null
//   +++ b/test/vpto/cases/vmi/group-reduce-s64-slot-add-store/main.cpp
//   @@ -0,0 +1,94 @@
//   +// Copyright (c) 2026 Huawei Technologies Co., Ltd.
//   +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
//   +// CANN Open Software License Agreement Version 2.0 (the "License").
//   +// Please refer to the License for details. You may not use this file except in compliance with the License.
//   +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
//   +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
//   +// See LICENSE in the root of the software repository for the full text of the License.
// Host driver for the slot-add-store VMI case: loads v1.bin / v2.bin / v3.bin,
// runs the kernel once, and writes the device result back over v3.bin.

#include "acl/acl.h"
#include "test_common.h"
// NOTE(review): the two #include lines below have no header name in this text
// (apparently lost in extraction). The code uses std::fprintf, std::getenv and
// std::atoi, so <cstdio> and <cstdlib> are presumably what was here - TODO confirm.
#include
#include

using namespace PtoTestCommon;

// Evaluates `expr` once. On any result other than ACL_SUCCESS it logs the
// expression text, the error code and the source location to stderr, sets
// `rc = 1` and jumps to `cleanup`. It therefore only works inside a function
// that declares `rc` and a `cleanup:` label.
#define ACL_CHECK(expr)                                                        \
  do {                                                                         \
    const aclError _ret = (expr);                                              \
    if (_ret != ACL_SUCCESS) {                                                 \
      std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr,           \
                   (int)_ret, __FILE__, __LINE__);                             \
      rc = 1;                                                                  \
      goto cleanup;                                                            \
    }                                                                          \
  } while (0)

// Defined in launch.cpp.
void LaunchVmi_group_reduce_s64_slot_add_store_kernel(float *src, float *rhs,
                                                      float *dst,
                                                      void *stream);

/// Runs the case end to end.
/// @return 0 on success, 1 when any ACL_CHECK-wrapped call fails.
int main() {
  // 8 groups of 64 floats in; 8 rhs slots and 8 output slots per group.
  constexpr size_t kRows = 8;
  constexpr size_t kGroupSize = 64;
  constexpr size_t kRhsStride = 8;
  constexpr size_t kOutputStride = 8;
  constexpr size_t kInputElems = kRows * kGroupSize;
  constexpr size_t kRhsElems = kRows * kRhsStride;
  constexpr size_t kOutputElems = kRows * kOutputStride;
  size_t srcBytes = kInputElems * sizeof(float);
  size_t rhsBytes = kRhsElems * sizeof(float);
  size_t dstBytes = kOutputElems * sizeof(float);
  // Every local is declared (and initialised) before the first ACL_CHECK, so
  // the macro's `goto cleanup` never jumps over an initialisation.
  float *srcHost = nullptr;
  float *rhsHost = nullptr;
  float *dstHost = nullptr;
  float *srcDevice = nullptr;
  float *rhsDevice = nullptr;
  float *dstDevice = nullptr;
  int rc = 0;
  // Track how far start-up got so cleanup only undoes what actually happened.
  bool aclInited = false;
  bool deviceSet = false;
  int deviceId = 0;
  aclrtStream stream = nullptr;

  ACL_CHECK(aclInit(nullptr));
  aclInited = true;
  // ACL_DEVICE_ID overrides the default device 0.
  if (const char *envDevice = std::getenv("ACL_DEVICE_ID"))
    deviceId = std::atoi(envDevice);
  ACL_CHECK(aclrtSetDevice(deviceId));
  deviceSet = true;
  ACL_CHECK(aclrtCreateStream(&stream));
  ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes));
  ACL_CHECK(aclrtMallocHost((void **)(&rhsHost), rhsBytes));
  ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes));
  ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST));
  ACL_CHECK(aclrtMalloc((void **)&rhsDevice, rhsBytes, ACL_MEM_MALLOC_HUGE_FIRST));
  ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST));

  // NOTE(review): whatever status ReadFile reports is ignored here, so a
  // missing or short input file is not detected - confirm this is intended.
  ReadFile("./v1.bin", srcBytes, srcHost, srcBytes);
  ReadFile("./v2.bin", rhsBytes, rhsHost, rhsBytes);
  ReadFile("./v3.bin", dstBytes, dstHost, dstBytes);
  // dst is uploaded too: the kernel only overwrites some slots and the rest
  // must keep the values from v3.bin.
  ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE));
  ACL_CHECK(aclrtMemcpy(rhsDevice, rhsBytes, rhsHost, rhsBytes, ACL_MEMCPY_HOST_TO_DEVICE));
  ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE));
  LaunchVmi_group_reduce_s64_slot_add_store_kernel(srcDevice, rhsDevice,
                                                   dstDevice, stream);
  ACL_CHECK(aclrtSynchronizeStream(stream));
  ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST));
  // The result replaces v3.bin in place.
  WriteFile("./v3.bin", dstHost, dstBytes);

cleanup:
  // Reached on success and on every ACL_CHECK failure. Return codes of the
  // release calls are not inspected.
  // NOTE(review): after an early failure these pointers are still nullptr;
  // confirm aclrtFree / aclrtFreeHost tolerate that.
  aclrtFree(srcDevice);
  aclrtFree(rhsDevice);
  aclrtFree(dstDevice);
  aclrtFreeHost(srcHost);
  aclrtFreeHost(rhsHost);
  aclrtFreeHost(dstHost);
  if (stream)
    aclrtDestroyStream(stream);
  if (deviceSet)
    aclrtResetDevice(deviceId);
  if (aclInited)
    aclFinalize();
  return rc;
}

// ---------------------------------------------------------------------------
// Patch text that shares this source line with the code above, kept verbatim:
//   diff --git a/test/vpto/cases/vmi/group-reduce-s64-slot-add-store/ptoas.flags b/test/vpto/cases/vmi/group-reduce-s64-slot-add-store/ptoas.flags
//   new file mode 100644
//   index 0000000000..a79aede1ca
//   --- /dev/null
//   +++ b/test/vpto/cases/vmi/group-reduce-s64-slot-add-store/ptoas.flags
//   @@ -0,0 +1 @@
//   +--pto-arch a5 --pto-backend=vpto --enable-vmi
//   diff --git a/test/vpto/cases/vmi/group-reduce-s64-tail-store/compare.py b/test/vpto/cases/vmi/group-reduce-s64-tail-store/compare.py
//   new file mode 100644
//   index 0000000000..17b5e600cc
//   --- /dev/null
//   +++ b/test/vpto/cases/vmi/group-reduce-s64-tail-store/compare.py
//   @@ -0,0 +1,30 @@
//   +#!/usr/bin/env python3
//   +# Copyright (c) 2026 Huawei Technologies Co., Ltd.
//   +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
//   +# CANN Open Software License Agreement Version 2.0 (the "License").
//   +# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.

import sys

import numpy as np


def main() -> None:
    """Check ``v2.bin`` against ``golden_v2.bin`` (both float32, in the cwd).

    Prints ``[INFO] compare passed`` when the two arrays have the same shape
    and agree within atol=1e-4 / rtol=1e-4. Otherwise prints an ``[ERROR]``
    line and exits with status 2.

    The shape is validated *before* the element-wise comparison: calling
    ``np.isclose`` on arrays whose shapes cannot be broadcast raises
    ``ValueError``, which previously turned a truncated or oversized output
    file into a traceback instead of a reported compare failure.
    """
    golden = np.fromfile("golden_v2.bin", dtype=np.float32)
    output = np.fromfile("v2.bin", dtype=np.float32)

    if golden.shape != output.shape:
        print(f"[ERROR] compare failed: shape golden={golden.shape} output={output.shape}")
        sys.exit(2)

    close = np.isclose(golden, output, atol=1e-4, rtol=1e-4)
    if not np.all(close):
        # Shapes match here, so the first mismatch is a valid index into both.
        idx = int(np.flatnonzero(~close)[0])
        print(f"[ERROR] compare failed idx={idx} golden={golden[idx]} output={output[idx]}")
        sys.exit(2)
    print("[INFO] compare passed")


if __name__ == "__main__":
    main()

# ---------------------------------------------------------------------------
# Patch text that shares this source line with the code above, kept verbatim:
#   diff --git a/test/vpto/cases/vmi/group-reduce-s64-tail-store/golden.py b/test/vpto/cases/vmi/group-reduce-s64-tail-store/golden.py
#   new file mode 100644
#   index 0000000000..83ac2d015e
#   --- /dev/null
#   +++ b/test/vpto/cases/vmi/group-reduce-s64-tail-store/golden.py
#   @@ -0,0 +1,46 @@
#   +#!/usr/bin/env python3
#   +# Copyright (c) 2026 Huawei Technologies Co., Ltd.
#   +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
#   +# CANN Open Software License Agreement Version 2.0 (the "License").
#   +# Please refer to the License for details. You may not use this file except in compliance with the License.
#   +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
#   +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
#   +# See LICENSE in the root of the software repository for the full text of the License.
import argparse
from pathlib import Path

import numpy as np

# ROWS groups of GROUP_SIZE float32 lanes; each group owns OUTPUT_STRIDE
# output slots, of which only the first one is written.
ROWS = 6
GROUP_SIZE = 64
OUTPUT_STRIDE = 8
# Value pre-loaded into every output slot so untouched slots are detectable.
SENTINEL = np.float32(-777.0)


def generate(output_dir: Path) -> None:
    """Emit the binary fixtures for the tail-store case.

    Files written into ``output_dir`` (created when missing):

    * ``v1.bin``        - the float32 input, ROWS x GROUP_SIZE, row-major
    * ``v2.bin``        - the initial output buffer, filled with SENTINEL
    * ``golden_v2.bin`` - the expected output: SENTINEL everywhere except the
      first slot of each group, which holds the float32 sum of that group
    """
    ramp = np.linspace(-0.625, 0.875, GROUP_SIZE, dtype=np.float32)
    inputs = np.empty((ROWS, GROUP_SIZE), dtype=np.float32)
    initial = np.full(ROWS * OUTPUT_STRIDE, SENTINEL, dtype=np.float32)
    expected = initial.copy()

    for group in range(ROWS):
        # Every group is the same ramp shifted by a per-group offset.
        inputs[group, :] = ramp + np.float32(group) * np.float32(0.046875)
        expected[group * OUTPUT_STRIDE] = np.sum(inputs[group, :], dtype=np.float32)

    output_dir.mkdir(parents=True, exist_ok=True)
    for file_name, payload in (
        ("v1.bin", inputs.ravel()),
        ("v2.bin", initial),
        ("golden_v2.bin", expected),
    ):
        payload.tofile(output_dir / file_name)


def main() -> None:
    """Command-line entry point: ``--output-dir`` selects the target directory."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--output-dir", type=Path, default=Path("."))
    generate(cli.parse_args().output_dir)


if __name__ == "__main__":
    main()

# ---------------------------------------------------------------------------
# Patch text that shares this source line with the code above, kept verbatim:
#   diff --git a/test/vpto/cases/vmi/group-reduce-s64-tail-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-s64-tail-store/kernel.pto
#   new file mode 100644
#   index 0000000000..1073e351c8
#   --- /dev/null
#   +++ b/test/vpto/cases/vmi/group-reduce-s64-tail-store/kernel.pto
#   @@ -0,0 +1,52 @@
#   +// Copyright (c) 2026 Huawei Technologies Co., Ltd.
#   +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
#   +// CANN Open Software License Agreement Version 2.0 (the "License").
#   +// Please refer to the License for details. You may not use this file except in compliance with the License.
#   +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
#   +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
#   +// See LICENSE in the root of the software repository for the full text of the License.
// VMI test kernel: 384 f32 lanes are treated as 6 groups. Each group is reduced
// to a sum and one result per group is stored.
// The companion golden.py models this as the float32 sum of each 64-value row.
//
// NOTE(review): `#pto.kernel_kind`, `#pto.pipe` and every `!pto.ptr` below carry
// no parameters in this text; they look like they were lost in extraction and
// should be restored from the original file before this is parsed.
module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} {
  func.func @vmi_group_reduce_s64_tail_store_kernel(%src_gm: !pto.ptr,
                                                    %dst_gm: !pto.ptr) attributes {pto.kernel} {
    %c0 = arith.constant 0 : index
    %c8 = arith.constant 8 : index
    %c384 = arith.constant 384 : index
    %c0_i64 = arith.constant 0 : i64
    %c1_i64 = arith.constant 1 : i64
    // 192 and 1536 equal the byte sizes the host uses for dst (48 floats) and
    // src (384 floats).
    %c192_i64 = arith.constant 192 : i64
    %c1536_i64 = arith.constant 1536 : i64
    %c4096_i64 = arith.constant 4096 : i64

    // On-chip buffers are addressed by fixed integer addresses: src at 0,
    // dst at 4096.
    %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr
    %ub_dst = pto.castptr %c4096_i64 : i64 -> !pto.ptr

    // Bring both the input and the sentinel-filled output into the on-chip
    // buffers. Presumably GM -> UB copies; the exact meaning of the trailing
    // operands and of nburst(...) is not visible here - TODO confirm.
    pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c1536_i64
      nburst(%c1_i64, %c1536_i64, %c1536_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
    pto.mte_gm_ub %dst_gm, %ub_dst, %c0_i64, %c192_i64
      nburst(%c1_i64, %c192_i64, %c192_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64

    // Presumably orders the inbound copies before the vector work - TODO confirm.
    pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]
    pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]

    pto.vecscope {
      %mask = pto.vmi.create_mask %c384 : index -> !pto.vmi.mask<384xpred>
      %x = pto.vmi.load %ub_src[%c0] : !pto.ptr -> !pto.vmi.vreg<384xf32>
      // 384 lanes -> 6 per-group sums.
      %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 6, reassoc}
        : !pto.vmi.vreg<384xf32>, !pto.vmi.mask<384xpred>
        -> !pto.vmi.vreg<6xf32>
      // %c8 matches OUTPUT_STRIDE in golden.py; presumably the element stride
      // between consecutive group results - TODO confirm.
      pto.vmi.group_store %sum, %ub_dst[%c0], %c8 {num_groups = 6}
        : !pto.vmi.vreg<6xf32>, !pto.ptr
    }

    // Presumably orders the vector stores before the outbound copy - TODO confirm.
    pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
    pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
    // Copy the whole 192-byte output buffer (results plus untouched sentinels) back.
    pto.mte_ub_gm %ub_dst, %dst_gm, %c192_i64
      nburst(%c1_i64, %c192_i64, %c192_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64
    pto.barrier #pto.pipe
    return
  }
}

// ---------------------------------------------------------------------------
// Patch text that shares this source line with the code above, kept verbatim
// (the last line is cut off by the source line break):
//   diff --git a/test/vpto/cases/vmi/group-reduce-s64-tail-store/launch.cpp b/test/vpto/cases/vmi/group-reduce-s64-tail-store/launch.cpp
//   new file mode 100644
//   index 0000000000..afdf98b76d
//   --- /dev/null
//   +++ b/test/vpto/cases/vmi/group-reduce-s64-tail-store/launch.cpp
//   @@ -0,0 +1,32 @@
//   +// Copyright (c) 2026 Huawei Technologies
// Co., Ltd.
// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
// CANN Open Software License Agreement Version 2.0 (the "License").
// Please refer to the License for details. You may not use this file except in compliance with the License.
// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
// See LICENSE in the root of the software repository for the full text of the License.

// Host-side launch shim for the tail-store VMI kernel.

// Make sure __VEC_SCOPE__ exists (as an empty macro) when the toolchain has
// not defined it.
#ifndef __VEC_SCOPE__
#define __VEC_SCOPE__
#endif
// NOTE(review): the header name of this #include is missing in this text
// (apparently lost in extraction). The struct below needs uint16_t, so a
// fixed-width integer header is presumably what was here - TODO confirm.
#include
// Local definition of MrgSortExecutedNumList, compiled only when neither
// __CCE_AICORE__ nor TMRGSORT_HPP is defined. Why this case needs the type is
// not visible from this file.
#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP)
struct MrgSortExecutedNumList {
  uint16_t mrgSortList0;
  uint16_t mrgSortList1;
  uint16_t mrgSortList2;
  uint16_t mrgSortList3;
};
#endif
// The ACL header is skipped for __CPU_SIM builds.
#ifndef __CPU_SIM
#include "acl/acl.h"
#endif

// Device entry point; the name matches the func.func symbol in kernel.pto.
extern "C" __global__ [aicore] void
vmi_group_reduce_s64_tail_store_kernel(__gm__ float *src, __gm__ float *dst);

/// Launches the kernel on `stream`.
///
/// @param src    input buffer, passed to the kernel as a __gm__ float pointer
/// @param dst    output buffer, passed to the kernel as a __gm__ float pointer
/// @param stream forwarded as the third launch parameter
///
/// The launch configuration is <<<1, nullptr, stream>>>; presumably one block
/// and no extra control argument - TODO confirm against the toolchain docs.
void LaunchVmi_group_reduce_s64_tail_store_kernel(float *src, float *dst,
                                                  void *stream) {
  vmi_group_reduce_s64_tail_store_kernel<<<1, nullptr, stream>>>(
      (__gm__ float *)src, (__gm__ float *)dst);
}

// ---------------------------------------------------------------------------
// Patch text that shares this source line with the code above, kept verbatim:
//   diff --git a/test/vpto/cases/vmi/group-reduce-s64-tail-store/main.cpp b/test/vpto/cases/vmi/group-reduce-s64-tail-store/main.cpp
//   new file mode 100644
//   index 0000000000..3223b3561b
//   --- /dev/null
//   +++ b/test/vpto/cases/vmi/group-reduce-s64-tail-store/main.cpp
//   @@ -0,0 +1,81 @@
//   +// Copyright (c) 2026 Huawei Technologies Co., Ltd.
//   +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
//   +// CANN Open Software License Agreement Version 2.0 (the "License").
//   +// Please refer to the License for details. You may not use this file except in compliance with the License.
// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
// See LICENSE in the root of the software repository for the full text of the License.

// Host driver for the tail-store VMI case: loads v1.bin / v2.bin, runs the
// kernel once, and writes the device result back over v2.bin.

#include "acl/acl.h"
#include "test_common.h"
// NOTE(review): the two #include lines below have no header name in this text
// (apparently lost in extraction). The code uses std::fprintf, std::getenv and
// std::atoi, so <cstdio> and <cstdlib> are presumably what was here - TODO confirm.
#include
#include

using namespace PtoTestCommon;

// Evaluates `expr` once. On any result other than ACL_SUCCESS it logs the
// expression text, the error code and the source location to stderr, sets
// `rc = 1` and jumps to `cleanup`. It therefore only works inside a function
// that declares `rc` and a `cleanup:` label.
#define ACL_CHECK(expr)                                                        \
  do {                                                                         \
    const aclError _ret = (expr);                                              \
    if (_ret != ACL_SUCCESS) {                                                 \
      std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr,           \
                   (int)_ret, __FILE__, __LINE__);                             \
      rc = 1;                                                                  \
      goto cleanup;                                                            \
    }                                                                          \
  } while (0)

// Defined in launch.cpp.
void LaunchVmi_group_reduce_s64_tail_store_kernel(float *src, float *dst,
                                                  void *stream);

/// Runs the case end to end.
/// @return 0 on success, 1 when any ACL_CHECK-wrapped call fails.
int main() {
  // 6 groups of 64 floats in; 8 output slots per group out.
  constexpr size_t kRows = 6;
  constexpr size_t kGroupSize = 64;
  constexpr size_t kOutputStride = 8;
  constexpr size_t kInputElems = kRows * kGroupSize;
  constexpr size_t kOutputElems = kRows * kOutputStride;
  size_t srcBytes = kInputElems * sizeof(float);
  size_t dstBytes = kOutputElems * sizeof(float);
  // Every local is declared (and initialised) before the first ACL_CHECK, so
  // the macro's `goto cleanup` never jumps over an initialisation.
  float *srcHost = nullptr;
  float *dstHost = nullptr;
  float *srcDevice = nullptr;
  float *dstDevice = nullptr;
  int rc = 0;
  // Track how far start-up got so cleanup only undoes what actually happened.
  bool aclInited = false;
  bool deviceSet = false;
  int deviceId = 0;
  aclrtStream stream = nullptr;

  ACL_CHECK(aclInit(nullptr));
  aclInited = true;
  // ACL_DEVICE_ID overrides the default device 0.
  if (const char *envDevice = std::getenv("ACL_DEVICE_ID"))
    deviceId = std::atoi(envDevice);
  ACL_CHECK(aclrtSetDevice(deviceId));
  deviceSet = true;
  ACL_CHECK(aclrtCreateStream(&stream));
  ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes));
  ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes));
  ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST));
  ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST));

  // NOTE(review): whatever status ReadFile reports is ignored here, so a
  // missing or short input file is not detected - confirm this is intended.
  ReadFile("./v1.bin", srcBytes, srcHost, srcBytes);
  ReadFile("./v2.bin", dstBytes, dstHost, dstBytes);
  // dst is uploaded too: the kernel only overwrites some slots and the rest
  // must keep the values from v2.bin.
  ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE));
  ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE));
  LaunchVmi_group_reduce_s64_tail_store_kernel(srcDevice, dstDevice, stream);
  ACL_CHECK(aclrtSynchronizeStream(stream));
  ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST));
  // The result replaces v2.bin in place.
  WriteFile("./v2.bin", dstHost, dstBytes);

cleanup:
  // Reached on success and on every ACL_CHECK failure. Return codes of the
  // release calls are not inspected.
  // NOTE(review): after an early failure these pointers are still nullptr;
  // confirm aclrtFree / aclrtFreeHost tolerate that.
  aclrtFree(srcDevice);
  aclrtFree(dstDevice);
  aclrtFreeHost(srcHost);
  aclrtFreeHost(dstHost);
  if (stream)
    aclrtDestroyStream(stream);
  if (deviceSet)
    aclrtResetDevice(deviceId);
  if (aclInited)
    aclFinalize();
  return rc;
}

// ---------------------------------------------------------------------------
// Patch text that shares this source line with the code above, kept verbatim:
//   diff --git a/test/vpto/cases/vmi/group-reduce-s64-tail-store/ptoas.flags b/test/vpto/cases/vmi/group-reduce-s64-tail-store/ptoas.flags
//   new file mode 100644
//   index 0000000000..a79aede1ca
//   --- /dev/null
//   +++ b/test/vpto/cases/vmi/group-reduce-s64-tail-store/ptoas.flags
//   @@ -0,0 +1 @@
//   +--pto-arch a5 --pto-backend=vpto --enable-vmi
//   diff --git a/test/vpto/cases/vmi/group-reduce-s64-truncf-store/compare.py b/test/vpto/cases/vmi/group-reduce-s64-truncf-store/compare.py
//   new file mode 100644
//   index 0000000000..cce2c778b9
//   --- /dev/null
//   +++ b/test/vpto/cases/vmi/group-reduce-s64-truncf-store/compare.py
//   @@ -0,0 +1,32 @@
//   +#!/usr/bin/env python3
//   +# Copyright (c) 2026 Huawei Technologies Co., Ltd.
//   +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
//   +# CANN Open Software License Agreement Version 2.0 (the "License").
//   +# Please refer to the License for details. You may not use this file except in compliance with the License.
//   +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
//   +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
//   +# See LICENSE in the root of the software repository for the full text of the License.
import sys

import numpy as np


def main() -> None:
    """Check ``v2.bin`` against ``golden_v2.bin`` (both float16, in the cwd).

    The comparison is exact: the arrays must have the same shape and
    ``np.array_equal`` must hold. On success prints ``[INFO] compare passed``;
    otherwise prints an ``[ERROR]`` line and exits with status 2.

    The shape is validated up front. Without that guard a size mismatch fell
    through to an element-wise ``!=`` on differently-shaped arrays, which
    raises on current NumPy (and on older releases produced a meaningless
    ``idx=-1`` report).
    """
    golden = np.fromfile("golden_v2.bin", dtype=np.float16)
    output = np.fromfile("v2.bin", dtype=np.float16)

    if golden.shape != output.shape:
        print(f"[ERROR] compare failed: shape golden={golden.shape} output={output.shape}")
        sys.exit(2)

    if np.array_equal(golden, output):
        print("[INFO] compare passed")
        return

    # Locate the first element whose bit pattern differs. idx stays -1 when the
    # value comparison failed without any bit difference (e.g. identical NaNs).
    diff = np.nonzero(golden.view(np.uint16) != output.view(np.uint16))[0]
    idx = int(diff[0]) if diff.size else -1
    print(
        f"[ERROR] compare failed idx={idx} "
        f"golden={golden[idx] if idx >= 0 else 'n/a'} "
        f"output={output[idx] if idx >= 0 else 'n/a'}"
    )
    sys.exit(2)


if __name__ == "__main__":
    main()

# ---------------------------------------------------------------------------
# Patch text that shares this source line with the code above, kept verbatim:
#   diff --git a/test/vpto/cases/vmi/group-reduce-s64-truncf-store/golden.py b/test/vpto/cases/vmi/group-reduce-s64-truncf-store/golden.py
#   new file mode 100644
#   index 0000000000..62b6de2d6e
#   --- /dev/null
#   +++ b/test/vpto/cases/vmi/group-reduce-s64-truncf-store/golden.py
#   @@ -0,0 +1,47 @@
#   +#!/usr/bin/env python3
#   +# Copyright (c) 2026 Huawei Technologies Co., Ltd.
#   +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
#   +# CANN Open Software License Agreement Version 2.0 (the "License").
#   +# Please refer to the License for details. You may not use this file except in compliance with the License.
#   +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
#   +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
#   +# See LICENSE in the root of the software repository for the full text of the License.
+ +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 8 +GROUP_SIZE = 64 +OUTPUT_STRIDE = 16 +SENTINEL = np.float16(-17.5) + + +def generate(output_dir: Path) -> None: + src = np.empty((ROWS, GROUP_SIZE), dtype=np.float32) + base = np.linspace(-0.625, 0.875, GROUP_SIZE, dtype=np.float32) + for row in range(ROWS): + src[row, :] = base + np.float32(row) * np.float32(0.046875) + + dst = np.full(ROWS * OUTPUT_STRIDE, SENTINEL, dtype=np.float16) + golden = np.full(ROWS * OUTPUT_STRIDE, SENTINEL, dtype=np.float16) + for row in range(ROWS): + row_sum = np.sum(src[row, :], dtype=np.float32) + golden[row * OUTPUT_STRIDE] = np.float16(row_sum) + + output_dir.mkdir(parents=True, exist_ok=True) + src.reshape(-1).tofile(output_dir / "v1.bin") + dst.tofile(output_dir / "v2.bin") + golden.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-s64-truncf-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-s64-truncf-store/kernel.pto new file mode 100644 index 0000000000..b909f1f66c --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s64-truncf-store/kernel.pto @@ -0,0 +1,54 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_group_reduce_s64_truncf_store_kernel(%src_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c16 = arith.constant 16 : index + %c512 = arith.constant 512 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c256_i64 = arith.constant 256 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c2048_i64 + nburst(%c1_i64, %c2048_i64, %c2048_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %dst_gm, %ub_dst, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.vmi.create_mask %c512 : index -> !pto.vmi.mask<512xpred> + %x = pto.vmi.load %ub_src[%c0] : !pto.ptr -> !pto.vmi.vreg<512xf32> + %sum32 = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<512xf32>, !pto.vmi.mask<512xpred> + -> !pto.vmi.vreg<8xf32> + %sum16 = pto.vmi.truncf %sum32 + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<8xf16> + pto.vmi.group_store %sum16, %ub_dst[%c0], %c16 {num_groups = 8} + : !pto.vmi.vreg<8xf16>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/group-reduce-s64-truncf-store/launch.cpp b/test/vpto/cases/vmi/group-reduce-s64-truncf-store/launch.cpp new 
file mode 100644 index 0000000000..bd0c1e4fa2 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s64-truncf-store/launch.cpp @@ -0,0 +1,40 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_group_reduce_s64_truncf_store_kernel(__gm__ float *src, __gm__ half *dst); + +void LaunchVmi_group_reduce_s64_truncf_store_kernel(float *src, uint16_t *dst, + void *stream) { + vmi_group_reduce_s64_truncf_store_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ half *)dst); +} diff --git a/test/vpto/cases/vmi/group-reduce-s64-truncf-store/main.cpp b/test/vpto/cases/vmi/group-reduce-s64-truncf-store/main.cpp new file mode 100644 
index 0000000000..941a7d4622 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s64-truncf-store/main.cpp @@ -0,0 +1,79 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_reduce_s64_truncf_store_kernel(float *src, uint16_t *dst, + void *stream); + +int main() { + constexpr size_t kSrcElems = 512; + constexpr size_t kDstElems = 128; + size_t srcBytes = kSrcElems * sizeof(float); + size_t dstBytes = kDstElems * sizeof(uint16_t); + float *srcHost = nullptr; + float *srcDevice = nullptr; + uint16_t *dstHost = nullptr; + uint16_t *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + 
ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_reduce_s64_truncf_store_kernel(srcDevice, dstDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/group-reduce-s64-truncf-store/ptoas.flags b/test/vpto/cases/vmi/group-reduce-s64-truncf-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s64-truncf-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/group-reduce-slot-add-store/compare.py b/test/vpto/cases/vmi/group-reduce-slot-add-store/compare.py new file mode 100644 index 0000000000..edcf881e8d --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-slot-add-store/compare.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. 
You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def check(name: str, golden_name: str) -> None: + golden = np.fromfile(golden_name, dtype=np.float32) + output = np.fromfile(name, dtype=np.float32) + if golden.shape == output.shape and np.allclose(golden, output, atol=1e-4, rtol=1e-4): + return + + if golden.shape != output.shape: + print(f"[ERROR] compare failed {name}: shape golden={golden.shape} output={output.shape}") + sys.exit(2) + diff = np.nonzero(~np.isclose(golden, output, atol=1e-4, rtol=1e-4))[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed {name} idx={idx} " + f"golden={golden[idx] if idx >= 0 else 'n/a'} " + f"output={output[idx] if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + +def main() -> None: + check("v4.bin", "golden_v4.bin") + check("v5.bin", "golden_v5.bin") + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-slot-add-store/golden.py b/test/vpto/cases/vmi/group-reduce-slot-add-store/golden.py new file mode 100644 index 0000000000..7e57da8318 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-slot-add-store/golden.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. 
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 8 +S16 = 16 +S32 = 32 +SENTINEL = np.float32(-777.0) + + +def fill_matrix(rows: int, cols: int, base_start: float, row_step: float) -> np.ndarray: + base = np.linspace(base_start, base_start + 1.0, cols, dtype=np.float32) + out = np.empty((rows, cols), dtype=np.float32) + for row in range(rows): + out[row, :] = base + np.float32(row) * np.float32(row_step) + return out + + +def generate(output_dir: Path) -> None: + src16 = fill_matrix(ROWS, S16, -0.75, 0.03125) + src32 = fill_matrix(ROWS, S32, -0.875, 0.0625) + rhs = np.linspace(-0.25, 0.625, ROWS, dtype=np.float32) + dst16 = np.full(ROWS, SENTINEL, dtype=np.float32) + dst32 = np.full(ROWS, SENTINEL, dtype=np.float32) + + golden16 = np.sum(src16, axis=1, dtype=np.float32).astype(np.float32) + rhs + golden32 = np.sum(src32, axis=1, dtype=np.float32).astype(np.float32) + rhs + + output_dir.mkdir(parents=True, exist_ok=True) + src16.reshape(-1).tofile(output_dir / "v1.bin") + src32.reshape(-1).tofile(output_dir / "v2.bin") + rhs.tofile(output_dir / "v3.bin") + dst16.tofile(output_dir / "v4.bin") + dst32.tofile(output_dir / "v5.bin") + golden16.astype(np.float32).tofile(output_dir / "golden_v4.bin") + golden32.astype(np.float32).tofile(output_dir / "golden_v5.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-slot-add-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-slot-add-store/kernel.pto new file mode 100644 index 
0000000000..6cbe6b01fc --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-slot-add-store/kernel.pto @@ -0,0 +1,86 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_group_reduce_slot_add_store_kernel(%src16_gm: !pto.ptr, + %src32_gm: !pto.ptr, + %rhs_gm: !pto.ptr, + %dst16_gm: !pto.ptr, + %dst32_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c512_i64 = arith.constant 512 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c12288_i64 = arith.constant 12288 : i64 + + %ub_src16 = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_src32 = pto.castptr %c1024_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_dst16 = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_dst32 = pto.castptr %c12288_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src16_gm, %ub_src16, %c0_i64, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %src32_gm, %ub_src32, %c0_i64, %c1024_i64 + nburst(%c1_i64, %c1024_i64, 
%c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %rhs_gm, %ub_rhs, %c0_i64, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask16 = pto.vmi.create_mask %c128 : index -> !pto.vmi.mask<128xpred> + %x16 = pto.vmi.load %ub_src16[%c0] : !pto.ptr -> !pto.vmi.vreg<128xf32> + %rhs16 = pto.vmi.group_slot_load %ub_rhs[%c0], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<8xf32> + %sum16 = pto.vmi.group_reduce_addf %x16, %mask16 {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<8xf32> + %out16 = pto.vmi.addf %sum16, %rhs16 + : !pto.vmi.vreg<8xf32>, !pto.vmi.vreg<8xf32> + -> !pto.vmi.vreg<8xf32> + pto.vmi.group_store %out16, %ub_dst16[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xf32>, !pto.ptr + + %mask32 = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %x32 = pto.vmi.load %ub_src32[%c0] : !pto.ptr -> !pto.vmi.vreg<256xf32> + %rhs32 = pto.vmi.group_slot_load %ub_rhs[%c0], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<8xf32> + %sum32 = pto.vmi.group_reduce_addf %x32, %mask32 {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<8xf32> + %out32 = pto.vmi.addf %sum32, %rhs32 + : !pto.vmi.vreg<8xf32>, !pto.vmi.vreg<8xf32> + -> !pto.vmi.vreg<8xf32> + pto.vmi.group_store %out32, %ub_dst32[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xf32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst16, %dst16_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_dst32, %dst32_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git 
a/test/vpto/cases/vmi/group-reduce-slot-add-store/launch.cpp b/test/vpto/cases/vmi/group-reduce-slot-add-store/launch.cpp new file mode 100644 index 0000000000..ba7b786e51 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-slot-add-store/launch.cpp @@ -0,0 +1,38 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_group_reduce_slot_add_store_kernel(__gm__ float *src16, + __gm__ float *src32, + __gm__ float *rhs, + __gm__ float *dst16, + __gm__ float *dst32); + +void LaunchVmi_group_reduce_slot_add_store_kernel(float *src16, float *src32, + float *rhs, float *dst16, + float *dst32, void *stream) { + vmi_group_reduce_slot_add_store_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src16, (__gm__ float *)src32, (__gm__ float *)rhs, + (__gm__ float *)dst16, (__gm__ float *)dst32); +} diff --git a/test/vpto/cases/vmi/group-reduce-slot-add-store/main.cpp b/test/vpto/cases/vmi/group-reduce-slot-add-store/main.cpp new file mode 100644 index 0000000000..111426c192 --- /dev/null +++ 
b/test/vpto/cases/vmi/group-reduce-slot-add-store/main.cpp @@ -0,0 +1,113 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_reduce_slot_add_store_kernel(float *src16, float *src32, + float *rhs, float *dst16, + float *dst32, void *stream); + +int main() { + constexpr size_t kRows = 8; + constexpr size_t kS16 = 16; + constexpr size_t kS32 = 32; + constexpr size_t kSrc16Elems = kRows * kS16; + constexpr size_t kSrc32Elems = kRows * kS32; + constexpr size_t kRhsElems = kRows; + constexpr size_t kOutputElems = kRows; + size_t src16Bytes = kSrc16Elems * sizeof(float); + size_t src32Bytes = kSrc32Elems * sizeof(float); + size_t rhsBytes = kRhsElems * sizeof(float); + size_t dstBytes = kOutputElems * sizeof(float); + float *src16Host = nullptr; + float *src32Host = nullptr; + float *rhsHost = nullptr; + float *dst16Host = nullptr; + float *dst32Host = nullptr; + float *src16Device = nullptr; + float *src32Device = nullptr; + float *rhsDevice = nullptr; + float *dst16Device = nullptr; + float *dst32Device = 
nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&src16Host), src16Bytes)); + ACL_CHECK(aclrtMallocHost((void **)(&src32Host), src32Bytes)); + ACL_CHECK(aclrtMallocHost((void **)(&rhsHost), rhsBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dst16Host), dstBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dst32Host), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&src16Device, src16Bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&src32Device, src32Bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&rhsDevice, rhsBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dst16Device, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dst32Device, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", src16Bytes, src16Host, src16Bytes); + ReadFile("./v2.bin", src32Bytes, src32Host, src32Bytes); + ReadFile("./v3.bin", rhsBytes, rhsHost, rhsBytes); + ReadFile("./v4.bin", dstBytes, dst16Host, dstBytes); + ReadFile("./v5.bin", dstBytes, dst32Host, dstBytes); + ACL_CHECK(aclrtMemcpy(src16Device, src16Bytes, src16Host, src16Bytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(src32Device, src32Bytes, src32Host, src32Bytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(rhsDevice, rhsBytes, rhsHost, rhsBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dst16Device, dstBytes, dst16Host, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dst32Device, dstBytes, dst32Host, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_reduce_slot_add_store_kernel( + src16Device, src32Device, rhsDevice, dst16Device, 
dst32Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dst16Host, dstBytes, dst16Device, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(dst32Host, dstBytes, dst32Device, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v4.bin", dst16Host, dstBytes); + WriteFile("./v5.bin", dst32Host, dstBytes); + +cleanup: + aclrtFree(src16Device); + aclrtFree(src32Device); + aclrtFree(rhsDevice); + aclrtFree(dst16Device); + aclrtFree(dst32Device); + aclrtFreeHost(src16Host); + aclrtFreeHost(src32Host); + aclrtFreeHost(rhsHost); + aclrtFreeHost(dst16Host); + aclrtFreeHost(dst32Host); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/group-reduce-slot-add-store/ptoas.flags b/test/vpto/cases/vmi/group-reduce-slot-add-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-slot-add-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/group-slots-cf-join-store/compare.py b/test/vpto/cases/vmi/group-slots-cf-join-store/compare.py new file mode 100644 index 0000000000..60aeab3da6 --- /dev/null +++ b/test/vpto/cases/vmi/group-slots-cf-join-store/compare.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def check(name: str) -> bool: + golden = np.fromfile(f"golden_{name}.bin", dtype=np.float32) + output = np.fromfile(f"{name}.bin", dtype=np.float32) + if golden.shape == output.shape and np.allclose(golden, output, atol=1e-4, rtol=1e-4): + return True + close = np.isclose(golden, output, atol=1e-4, rtol=1e-4) + diff = np.nonzero(~close)[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed {name} idx={idx} " + f"golden={golden[idx] if idx >= 0 else 'n/a'} " + f"output={output[idx] if idx >= 0 else 'n/a'}" + ) + return False + + +def main() -> None: + if not check("v3") or not check("v4"): + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-slots-cf-join-store/golden.py b/test/vpto/cases/vmi/group-slots-cf-join-store/golden.py new file mode 100644 index 0000000000..fa1fc04fe6 --- /dev/null +++ b/test/vpto/cases/vmi/group-slots-cf-join-store/golden.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+ +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 8 +GROUP_SIZE = 16 +INPUT_ELEMS = ROWS * GROUP_SIZE +SENTINEL = np.float32(-777.0) + + +def generate(output_dir: Path) -> None: + src = np.empty(INPUT_ELEMS, dtype=np.float32) + rhs = np.linspace(-0.375, 0.5, ROWS, dtype=np.float32) + dst_reduce = np.full(ROWS, SENTINEL, dtype=np.float32) + dst_slot = np.full(ROWS, SENTINEL, dtype=np.float32) + golden_reduce = np.empty(ROWS, dtype=np.float32) + golden_slot = rhs + rhs + + base_row = np.linspace(-0.625, 0.875, GROUP_SIZE, dtype=np.float32) + for row in range(ROWS): + begin = row * GROUP_SIZE + values = base_row + np.float32(row) * np.float32(0.125) + src[begin : begin + GROUP_SIZE] = values + golden_reduce[row] = np.sum(values, dtype=np.float32) + rhs[row] + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + rhs.tofile(output_dir / "v2.bin") + dst_reduce.tofile(output_dir / "v3.bin") + dst_slot.tofile(output_dir / "v4.bin") + golden_reduce.tofile(output_dir / "golden_v3.bin") + golden_slot.astype(np.float32).tofile(output_dir / "golden_v4.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-slots-cf-join-store/kernel.pto b/test/vpto/cases/vmi/group-slots-cf-join-store/kernel.pto new file mode 100644 index 0000000000..5d66d5ff13 --- /dev/null +++ b/test/vpto/cases/vmi/group-slots-cf-join-store/kernel.pto @@ -0,0 +1,97 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_group_slots_cf_join_store_kernel(%src_gm: !pto.ptr, + %rhs_gm: !pto.ptr, + %dst_reduce_gm: !pto.ptr, + %dst_slot_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c512_i64 = arith.constant 512 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c12288_i64 = arith.constant 12288 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_dst_reduce = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_dst_slot = pto.castptr %c12288_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %rhs_gm, %ub_rhs, %c0_i64, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %cond_true = arith.cmpi eq, %c0, %c0 : index + %cond_false = arith.cmpi ne, %c0, %c0 : index + %mask = pto.vmi.create_mask %c128 : index -> !pto.vmi.mask<128xpred> + + %reduce_join = scf.if %cond_true -> !pto.vmi.vreg<8xf32> { + %x = pto.vmi.load %ub_src[%c0] + : !pto.ptr -> !pto.vmi.vreg<128xf32> + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> 
!pto.vmi.vreg<8xf32> + scf.yield %sum : !pto.vmi.vreg<8xf32> + } else { + %slot = pto.vmi.group_slot_load %ub_rhs[%c0], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<8xf32> + scf.yield %slot : !pto.vmi.vreg<8xf32> + } + %bias0 = pto.vmi.group_slot_load %ub_rhs[%c0], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<8xf32> + %reduce_out = pto.vmi.addf %reduce_join, %bias0 + : !pto.vmi.vreg<8xf32>, !pto.vmi.vreg<8xf32> + -> !pto.vmi.vreg<8xf32> + pto.vmi.group_store %reduce_out, %ub_dst_reduce[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xf32>, !pto.ptr + + %slot_join = scf.if %cond_false -> !pto.vmi.vreg<8xf32> { + %x = pto.vmi.load %ub_src[%c0] + : !pto.ptr -> !pto.vmi.vreg<128xf32> + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<8xf32> + scf.yield %sum : !pto.vmi.vreg<8xf32> + } else { + %slot = pto.vmi.group_slot_load %ub_rhs[%c0], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<8xf32> + scf.yield %slot : !pto.vmi.vreg<8xf32> + } + %bias1 = pto.vmi.group_slot_load %ub_rhs[%c0], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<8xf32> + %slot_out = pto.vmi.addf %slot_join, %bias1 + : !pto.vmi.vreg<8xf32>, !pto.vmi.vreg<8xf32> + -> !pto.vmi.vreg<8xf32> + pto.vmi.group_store %slot_out, %ub_dst_slot[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xf32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst_reduce, %dst_reduce_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_dst_slot, %dst_slot_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/group-slots-cf-join-store/launch.cpp b/test/vpto/cases/vmi/group-slots-cf-join-store/launch.cpp new file mode 100644 index 0000000000..add61550a6 --- /dev/null 
+++ b/test/vpto/cases/vmi/group-slots-cf-join-store/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_group_slots_cf_join_store_kernel(__gm__ float *src, __gm__ float *rhs, + __gm__ float *dstReduce, + __gm__ float *dstSlot); + +void LaunchVmi_group_slots_cf_join_store_kernel(float *src, float *rhs, + float *dstReduce, + float *dstSlot, void *stream) { + vmi_group_slots_cf_join_store_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ float *)rhs, (__gm__ float *)dstReduce, + (__gm__ float *)dstSlot); +} diff --git a/test/vpto/cases/vmi/group-slots-cf-join-store/main.cpp 
b/test/vpto/cases/vmi/group-slots-cf-join-store/main.cpp new file mode 100644 index 0000000000..fb8d6ace69 --- /dev/null +++ b/test/vpto/cases/vmi/group-slots-cf-join-store/main.cpp @@ -0,0 +1,102 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_slots_cf_join_store_kernel(float *src, float *rhs, + float *dstReduce, + float *dstSlot, void *stream); + +int main() { + constexpr size_t kRows = 8; + constexpr size_t kGroupSize = 16; + constexpr size_t kInputElems = kRows * kGroupSize; + constexpr size_t kOutputElems = kRows; + size_t srcBytes = kInputElems * sizeof(float); + size_t rhsBytes = kOutputElems * sizeof(float); + size_t dstBytes = kOutputElems * sizeof(float); + float *srcHost = nullptr; + float *rhsHost = nullptr; + float *dstReduceHost = nullptr; + float *dstSlotHost = nullptr; + float *srcDevice = nullptr; + float *rhsDevice = nullptr; + float *dstReduceDevice = nullptr; + float *dstSlotDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + 
aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&rhsHost), rhsBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstReduceHost), dstBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstSlotHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&rhsDevice, rhsBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstReduceDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstSlotDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", rhsBytes, rhsHost, rhsBytes); + ReadFile("./v3.bin", dstBytes, dstReduceHost, dstBytes); + ReadFile("./v4.bin", dstBytes, dstSlotHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(rhsDevice, rhsBytes, rhsHost, rhsBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstReduceDevice, dstBytes, dstReduceHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstSlotDevice, dstBytes, dstSlotHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_slots_cf_join_store_kernel(srcDevice, rhsDevice, + dstReduceDevice, dstSlotDevice, + stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstReduceHost, dstBytes, dstReduceDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(dstSlotHost, dstBytes, dstSlotDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", dstReduceHost, dstBytes); + WriteFile("./v4.bin", dstSlotHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + 
aclrtFree(rhsDevice); + aclrtFree(dstReduceDevice); + aclrtFree(dstSlotDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(rhsHost); + aclrtFreeHost(dstReduceHost); + aclrtFreeHost(dstSlotHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/group-slots-cf-join-store/ptoas.flags b/test/vpto/cases/vmi/group-slots-cf-join-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/group-slots-cf-join-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/group-slots-fanout-store-broadcast/compare.py b/test/vpto/cases/vmi/group-slots-fanout-store-broadcast/compare.py new file mode 100644 index 0000000000..49180d97de --- /dev/null +++ b/test/vpto/cases/vmi/group-slots-fanout-store-broadcast/compare.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+
+import sys
+
+import numpy as np
+
+
+def check(name: str) -> bool:
+    """Return True if {name}.bin matches golden_{name}.bin in shape and within 1e-4 tolerance."""
+    golden = np.fromfile(f"golden_{name}.bin", dtype=np.float32)
+    output = np.fromfile(f"{name}.bin", dtype=np.float32)
+    if golden.shape != output.shape:
+        print(f"[ERROR] compare failed {name}: shape golden={golden.shape} output={output.shape}")
+        return False
+    close = np.isclose(golden, output, atol=1e-4, rtol=1e-4)
+    if close.all():
+        return True
+    idx = int(np.nonzero(~close)[0][0])
+    print(f"[ERROR] compare failed {name} idx={idx} golden={golden[idx]} output={output[idx]}")
+    return False
+
+
+def main() -> None:
+    """Check v2 then v3 (stopping at the first failure); exit with status 2 on mismatch."""
+    if not check("v2") or not check("v3"):
+        sys.exit(2)
+    print("[INFO] compare passed")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/test/vpto/cases/vmi/group-slots-fanout-store-broadcast/golden.py b/test/vpto/cases/vmi/group-slots-fanout-store-broadcast/golden.py
new file mode 100644
index 0000000000..146d0d1fd2
--- /dev/null
+++ b/test/vpto/cases/vmi/group-slots-fanout-store-broadcast/golden.py
@@ -0,0 +1,53 @@
+#!/usr/bin/env python3
+# Copyright (c) 2026 Huawei Technologies Co., Ltd.
+# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+# CANN Open Software License Agreement Version 2.0 (the "License").
+# Please refer to the License for details. You may not use this file except in compliance with the License.
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+# See LICENSE in the root of the software repository for the full text of the License.
+
+import argparse
+from pathlib import Path
+
+import numpy as np
+
+ROWS = 8
+GROUP_SIZE = 16
+INPUT_ELEMS = ROWS * GROUP_SIZE
+SENTINEL = np.float32(-777.0)
+
+
+def generate(output_dir: Path) -> None:
+    """Write the input, sentinel-filled output buffers and golden results as .bin files."""
+    src = np.empty(INPUT_ELEMS, dtype=np.float32).reshape(ROWS, GROUP_SIZE)
+    golden_sum = np.empty(ROWS, dtype=np.float32)
+    golden_out = np.empty(ROWS, dtype=np.float32)
+    ramp = np.linspace(-0.625, 0.875, GROUP_SIZE, dtype=np.float32)
+    for row in range(ROWS):
+        values = ramp + np.float32(row) * np.float32(0.125)
+        src[row, :] = values
+        row_sum = np.sum(values, dtype=np.float32)
+        golden_sum[row] = row_sum
+        golden_out[row] = np.sum(values * row_sum, dtype=np.float32)
+    blobs = {
+        "v1.bin": src.reshape(-1),
+        "v2.bin": np.full(ROWS, SENTINEL, dtype=np.float32),
+        "v3.bin": np.full(ROWS, SENTINEL, dtype=np.float32),
+        "golden_v2.bin": golden_sum,
+        "golden_v3.bin": golden_out,
+    }
+    output_dir.mkdir(parents=True, exist_ok=True)
+    for filename, data in blobs.items():
+        data.tofile(output_dir / filename)
+
+
+def main() -> None:
+    """Parse --output-dir (default: current directory) and generate the files there."""
+    cli = argparse.ArgumentParser()
+    cli.add_argument("--output-dir", type=Path, default=Path("."))
+    generate(cli.parse_args().output_dir)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/test/vpto/cases/vmi/group-slots-fanout-store-broadcast/kernel.pto b/test/vpto/cases/vmi/group-slots-fanout-store-broadcast/kernel.pto
new file mode 100644
index 0000000000..636db6de38
--- /dev/null
+++ b/test/vpto/cases/vmi/group-slots-fanout-store-broadcast/kernel.pto
@@ -0,0 +1,71 @@
+// Copyright (c) 2026 Huawei Technologies Co., Ltd.
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+// CANN Open Software License Agreement Version 2.0 (the "License").
+// Please refer to the License for details. You may not use this file except in compliance with the License.
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_group_slots_fanout_store_broadcast_kernel(%src_gm: !pto.ptr, + %sum_gm: !pto.ptr, + %out_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c512_i64 = arith.constant 512 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_sum = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %sum_gm, %ub_sum, %c0_i64, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %out_gm, %ub_out, %c0_i64, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.vmi.create_mask %c128 : index -> !pto.vmi.mask<128xpred> + %x = pto.vmi.load %ub_src[%c0] : !pto.ptr -> !pto.vmi.vreg<128xf32> + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<8xf32> + pto.vmi.group_store %sum, %ub_sum[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xf32>, !pto.ptr + %b = pto.vmi.group_broadcast %sum 
{num_groups = 8} + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<128xf32> + %y = pto.vmi.mulf %x, %b + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + %ysum = pto.vmi.group_reduce_addf %y, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<8xf32> + pto.vmi.group_store %ysum, %ub_out[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xf32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_sum, %sum_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_out, %out_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/group-slots-fanout-store-broadcast/launch.cpp b/test/vpto/cases/vmi/group-slots-fanout-store-broadcast/launch.cpp new file mode 100644 index 0000000000..9a0667aae1 --- /dev/null +++ b/test/vpto/cases/vmi/group-slots-fanout-store-broadcast/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_group_slots_fanout_store_broadcast_kernel(__gm__ float *src, + __gm__ float *sum, + __gm__ float *out); + +void LaunchVmi_group_slots_fanout_store_broadcast_kernel(float *src, + float *sum, + float *out, + void *stream) { + vmi_group_slots_fanout_store_broadcast_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ float *)sum, (__gm__ float *)out); +} diff --git a/test/vpto/cases/vmi/group-slots-fanout-store-broadcast/main.cpp b/test/vpto/cases/vmi/group-slots-fanout-store-broadcast/main.cpp new file mode 100644 index 0000000000..f7b0fee4b8 --- /dev/null +++ b/test/vpto/cases/vmi/group-slots-fanout-store-broadcast/main.cpp @@ -0,0 +1,93 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_slots_fanout_store_broadcast_kernel(float *src, + float *sum, + float *out, + void *stream); + +int main() { + constexpr size_t kSrcElems = 128; + constexpr size_t kOutElems = 8; + size_t srcBytes = kSrcElems * sizeof(float); + size_t sumBytes = kOutElems * sizeof(float); + size_t outBytes = kOutElems * sizeof(float); + float *srcHost = nullptr; + float *sumHost = nullptr; + float *outHost = nullptr; + float *srcDevice = nullptr; + float *sumDevice = nullptr; + float *outDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&sumHost), sumBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&outHost), outBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&sumDevice, sumBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outDevice, outBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", sumBytes, sumHost, sumBytes); + ReadFile("./v3.bin", outBytes, outHost, outBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, 
srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(sumDevice, sumBytes, sumHost, sumBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outDevice, outBytes, outHost, outBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_slots_fanout_store_broadcast_kernel(srcDevice, sumDevice, + outDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(sumHost, sumBytes, sumDevice, sumBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outHost, outBytes, outDevice, outBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", sumHost, sumBytes); + WriteFile("./v3.bin", outHost, outBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(sumDevice); + aclrtFree(outDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(sumHost); + aclrtFreeHost(outHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/group-slots-fanout-store-broadcast/ptoas.flags b/test/vpto/cases/vmi/group-slots-fanout-store-broadcast/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/group-slots-fanout-store-broadcast/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/group-slots-scf-for-store/compare.py b/test/vpto/cases/vmi/group-slots-scf-for-store/compare.py new file mode 100644 index 0000000000..be861f3da8 --- /dev/null +++ b/test/vpto/cases/vmi/group-slots-scf-for-store/compare.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. 
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+# See LICENSE in the root of the software repository for the full text of the License.
+
+import sys
+
+import numpy as np
+
+
+def main() -> None:
+    """Compare v3.bin against golden_v3.bin; exit with status 2 on any mismatch."""
+    expected = np.fromfile("golden_v3.bin", dtype=np.float32)
+    actual = np.fromfile("v3.bin", dtype=np.float32)
+    if expected.shape != actual.shape:
+        print(f"[ERROR] compare failed: shape golden={expected.shape} output={actual.shape}")
+        sys.exit(2)
+
+    matches = np.isclose(expected, actual, atol=1e-4, rtol=1e-4)
+    if matches.all():
+        print("[INFO] compare passed")
+        return
+
+    # Report only the first mismatching element.
+    idx = int(np.flatnonzero(~matches)[0])
+    golden_val, output_val = expected[idx], actual[idx]
+    print(f"[ERROR] compare failed idx={idx} golden={golden_val} output={output_val}")
+    sys.exit(2)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/test/vpto/cases/vmi/group-slots-scf-for-store/golden.py b/test/vpto/cases/vmi/group-slots-scf-for-store/golden.py
new file mode 100644
index 0000000000..a62c83071c
--- /dev/null
+++ b/test/vpto/cases/vmi/group-slots-scf-for-store/golden.py
@@ -0,0 +1,44 @@
+#!/usr/bin/env python3
+# Copyright (c) 2026 Huawei Technologies Co., Ltd.
+# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+# CANN Open Software License Agreement Version 2.0 (the "License").
+# Please refer to the License for details. You may not use this file except in compliance with the License.
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+# See LICENSE in the root of the software repository for the full text of the License.
+
+import argparse
+from pathlib import Path
+
+import numpy as np
+
+ROWS = 8
+COLS = 16
+SENTINEL = np.float32(-777.0)
+
+
+def generate(output_dir: Path) -> None:
+    """Write the init vector, source matrix, sentinel-filled output and golden result."""
+    init = np.linspace(-0.25, 0.625, ROWS, dtype=np.float32)
+    base = np.linspace(-0.75, 0.25, COLS, dtype=np.float32)
+    # Each row is the base ramp shifted by row * 0.03125.
+    src = np.stack(
+        [base + np.float32(row) * np.float32(0.03125) for row in range(ROWS)]
+    )
+    golden = init + np.float32(2.0) * np.sum(src, axis=1, dtype=np.float32)
+    output_dir.mkdir(parents=True, exist_ok=True)
+    init.tofile(output_dir / "v1.bin")
+    src.reshape(-1).tofile(output_dir / "v2.bin")
+    np.full(ROWS, SENTINEL, dtype=np.float32).tofile(output_dir / "v3.bin")
+    golden.astype(np.float32).tofile(output_dir / "golden_v3.bin")
+
+
+def main() -> None:
+    """Parse --output-dir (default: current directory) and generate the files there."""
+    cli = argparse.ArgumentParser()
+    cli.add_argument("--output-dir", type=Path, default=Path("."))
+    generate(cli.parse_args().output_dir)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/test/vpto/cases/vmi/group-slots-scf-for-store/kernel.pto b/test/vpto/cases/vmi/group-slots-scf-for-store/kernel.pto
new file mode 100644
index 0000000000..f0e6dc5e25
--- /dev/null
+++ b/test/vpto/cases/vmi/group-slots-scf-for-store/kernel.pto
@@ -0,0 +1,68 @@
+// Copyright (c) 2026 Huawei Technologies Co., Ltd.
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+// CANN Open Software License Agreement Version 2.0 (the "License").
+// Please refer to the License for details. You may not use this file except in compliance with the License.
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+// See LICENSE in the root of the software repository for the full text of the License.
+ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_group_slots_scf_for_store_kernel(%init_gm: !pto.ptr, + %src_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c16 = arith.constant 16 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c512_i64 = arith.constant 512 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c2048_i64 = arith.constant 2048 : i64 + + %ub_init = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_src = pto.castptr %c1024_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c2048_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %init_gm, %ub_init, %c0_i64, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %acc0 = pto.vmi.group_slot_load %ub_init[%c0], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<8xf32> + %acc = scf.for %i = %c0 to %c2 step %c1 + iter_args(%arg = %acc0) -> (!pto.vmi.vreg<8xf32>) { + %x = pto.vmi.group_load %ub_src[%c0], %c16 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<128xf32> + %mask = pto.vmi.create_group_mask %c16 + {num_groups = 8, group_size = 16} + : index -> !pto.vmi.mask<128xpred> + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<8xf32> + %next = pto.vmi.addf %arg, %sum + : !pto.vmi.vreg<8xf32>, !pto.vmi.vreg<8xf32> + -> !pto.vmi.vreg<8xf32> + scf.yield %next : !pto.vmi.vreg<8xf32> + } + pto.vmi.group_store %acc, %ub_dst[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xf32>, !pto.ptr + } + + pto.set_flag["PIPE_V", 
"PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/group-slots-scf-for-store/launch.cpp b/test/vpto/cases/vmi/group-slots-scf-for-store/launch.cpp new file mode 100644 index 0000000000..6837a88fd4 --- /dev/null +++ b/test/vpto/cases/vmi/group-slots-scf-for-store/launch.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_group_slots_scf_for_store_kernel(__gm__ float *init, __gm__ float *src, + __gm__ float *dst); + +void LaunchVmi_group_slots_scf_for_store_kernel(float *init, float *src, + float *dst, void *stream) { + vmi_group_slots_scf_for_store_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)init, (__gm__ float *)src, (__gm__ float *)dst); +} diff --git a/test/vpto/cases/vmi/group-slots-scf-for-store/main.cpp b/test/vpto/cases/vmi/group-slots-scf-for-store/main.cpp new file mode 100644 index 0000000000..555d105f43 --- /dev/null +++ b/test/vpto/cases/vmi/group-slots-scf-for-store/main.cpp @@ -0,0 +1,95 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#include "acl/acl.h" +#include "test_common.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_slots_scf_for_store_kernel(float *init, float *src, + float *dst, void *stream); + +int main() { + constexpr size_t kRows = 8; + constexpr size_t kCols = 16; + constexpr size_t kInitElems = kRows; + constexpr size_t kSrcElems = kRows * kCols; + constexpr size_t kDstElems = kRows; + size_t initBytes = kInitElems * sizeof(float); + size_t srcBytes = kSrcElems * sizeof(float); + size_t dstBytes = kDstElems * sizeof(float); + float *initHost = nullptr; + float *srcHost = nullptr; + float *dstHost = nullptr; + float *initDevice = nullptr; + float *srcDevice = nullptr; + float *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&initHost), initBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&initDevice, initBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", initBytes, initHost, initBytes); + ReadFile("./v2.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v3.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(initDevice, initBytes, 
initHost, initBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_slots_scf_for_store_kernel(initDevice, srcDevice, dstDevice, + stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(initDevice); + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(initHost); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/group-slots-scf-for-store/ptoas.flags b/test/vpto/cases/vmi/group-slots-scf-for-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/group-slots-scf-for-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/kernels/CCE_CASE_SCOPE.md b/test/vpto/cases/vmi/kernels/CCE_CASE_SCOPE.md new file mode 100644 index 0000000000..dd5df9a0bc --- /dev/null +++ b/test/vpto/cases/vmi/kernels/CCE_CASE_SCOPE.md @@ -0,0 +1,237 @@ + + +# VMI Kernels 的 CCE Case 范围 + +本文档从目标仓库 `.work/external/a5-kernel-standalone/cce` 的 raw CCE 测试入口列出 +VMI kernel 迁移范围。当前审计快照为本地 clone 的 +`main@ee81c3660d6336ecaecd805f02ffb2d69446984e`。不要从当前 +`test/vpto/cases/vmi/kernels` 已存在目录反推目标范围:历史目录包含额外 probe、 +历史 VMI coverage 和尚未对齐 CCE 数据流的 case。 + +## 统计规则 + +“必须支持”只统计目标仓库中的正确性、等价性和 minimum 测试入口。 +`smoke`、`timing`、`bench`、`debug`、`experiments` 和 bandwidth sweep 只用于补充代表 shape +或性能验证,不自动扩展为必须支持的语义 case。 + +| CCE family | 正确性来源 | 必须支持数量 | 首批 VMI 目标 | 暂缓或非首批 | +| --- | --- | ---: | --- | --- | +| `quant_minimum` | `quant_minimum/test/test_tquant.py` | 4 | 全部 4 个 | 该 
suite 无暂缓项 | +| `block_quant` | `block_quant/test/test_equivalence.py` | 7 | 全部 7 个 | 除非 raw CCE 新增正确性入口,否则 HIF8 只算 VMI/compiler probe | +| `dynamic_quant` | `dynamic_quant/test/test_dq_equivalence.py` | 9 | 全部 9 个 | subset wrapper 不新增语义 case | +| `dequant/anti_mx_quant` | `dequant/anti_mx_quant/test/test_equivalence.py` | 16 | 先支持 FP8 case | FP4 输入因 VMI FP4 surface 未设计而暂缓 | +| `block_mx_quant` | `block_mx_quant/test/test_equivalence.py`; `test_cce.py` 是更宽的 smoke/correctness surface | 14 canonical,30 full union | 先支持 canonical FP8/OCP/rint 行 | FP4、DDR `scale_alg=2` 和额外 `test_cce.py` union 行暂缓 | +| `swiglu_mx_quant` | `swiglu_mx_quant/test/test_equivalence.py` | 14 | 先支持 FP8/OCP/rint f16/bf16 行 | FP4 暂缓;CCE 源码中 `scale_alg=1` CUBLAS 路径异常 | +| `simdvf_per_block_cast` | PTOAS PR #488 | 1 | 支持 16x256 f16 + 4x8 f32 scale -> fp8 cast | 当前只迁移该 PR 提供的确定 shape | +| `tutorial/block_mx_quant` | `tutorial/block_mx_quant/README.md` | 已由 `block_mx_quant` 覆盖 | BF16 FP8 tutorial shape 作为代表覆盖 | tutorial FP4 与主 `block_mx_quant` 共用 FP4 blocker | + +## quant_minimum + +来源:`.work/external/a5-kernel-standalone/cce/quant_minimum/test/test_tquant.py`。 + +| 目标 case | VMI 支持状态 | +| --- | --- | +| `mxfp8_32x32_nd` | 必须支持 | +| `mxfp8_32x64_nz` | 必须支持 | +| `int8_sym_64x128_nd` | 必须支持 | +| `int8_asym_64x128_nd` | 必须支持 | + +`test_cycle_match.py` 只对同一组 `MINIMUM_CASES` 做 PTO/CCE cycle 对比,不新增语义 case。 + +## block_quant + +来源:`.work/external/a5-kernel-standalone/cce/block_quant/test/test_equivalence.py`。 +所有 case 都使用 `row_block_size=1`、`col_block_size=128`、`dst_type=292`。 + +| 目标 case | VMI 支持状态 | +| --- | --- | +| BF16 `(4,128)` | 必须支持 | +| FP16 `(8,128)` | 必须支持 | +| BF16 `(32,128)` | 必须支持 | +| FP16 `(16,256)` | 必须支持 | +| BF16 `(2,128)` | 必须支持 | +| FP16 `(4,256)` | 必须支持 | +| 带 `min_scale` 的 BF16 `(4,128)` | 必须支持 | + +`minimal_test.py` 是 smoke 子集。`test_hardware_scale.py`、`large_shape_correctness.py` +和 bandwidth sweep 主要验证更大 streaming shape;只有当它们暴露新的 VMI memory/layout +约束时,才增加代表性 runtime case。 + +## 
simdvf per-block cast to FP8 + +来源:PTOAS PR #488。 + +| 目标 case | VMI 支持状态 | +| --- | --- | +| FP16 input `(16,256)`,FP32 scale `(4,8)`,每个 scale 覆盖 4 行 x 32 列,输出 FP8 E4M3 | 必须支持 | + +## dynamic_quant + +来源:`.work/external/a5-kernel-standalone/cce/dynamic_quant/test/test_dq_equivalence.py`。 + +| 目标 case | VMI 支持状态 | +| --- | --- | +| per-token,无 smooth,FP16 `(4,32)` | 必须支持 | +| per-token,无 smooth,FP16 `(16,128)` | 必须支持 | +| per-token,smooth,FP16 `(8,64)` | 必须支持 | +| per-token,smooth,FP16 `(16,128)` | 必须支持 | +| per-channel,FP16 `(128,128)` | 必须支持 | +| per-channel,FP16 `(256,256)` | 必须支持 | +| per-token,无 smooth,BF16 `(4,32)` | 必须支持 | +| per-token,smooth,BF16 `(8,64)` | 必须支持 | +| per-channel,BF16 `(128,128)` | 必须支持 | + +`test_pertoken_only.py`、`test_perchannel_128.py` 和 `test_perchannel_all.py` +只是该表的子集或重新分组。 + +## dequant / anti_mx_quant + +来源:`.work/external/a5-kernel-standalone/cce/dequant/anti_mx_quant/test/test_equivalence.py`。 + +| 目标 case | VMI 支持状态 | +| --- | --- | +| FP8 E4M3 -> BF16 `(4,128)` | 首批必须支持 | +| FP8 E4M3 -> FP32 `(4,128)` | 首批必须支持 | +| FP8 E4M3 -> FP16 `(4,128)` | 首批必须支持 | +| FP8 E4M3 -> BF16 `(16,512)` | 首批必须支持 | +| FP8 E4M3 -> BF16 `(64,2048)` | 首批必须支持 | +| FP8 E4M3 -> BF16 `(1024,2048)` | 代表 large/perf;若 medium 已覆盖相同 lowering,首批 runtime shape 不必加入 | +| FP8 E5M2 -> BF16 `(4,128)` | 首批必须支持 | +| FP8 E5M2 -> BF16 `(16,512)` | 首批必须支持 | +| FP4 E2M1 -> BF16 `(4,64)` | 暂缓 | +| FP4 E2M1 -> BF16 `(16,256)` | 暂缓 | +| FP4 E2M1 -> BF16 `(4096,512)` | 暂缓 | +| FP4 E2M1 -> BF16 `(65536,2048)` | 暂缓 | +| FP4 E1M2 -> BF16 `(4,64)` | 暂缓 | +| FP4 E1M2 -> BF16 `(16,256)` | 暂缓 | +| FP4 E1M2 -> BF16 `(4096,512)` | 暂缓 | +| FP4 E1M2 -> BF16 `(65536,2048)` | 暂缓 | + +这些 FP4 行是真实目标仓库 case,但当前 VMI 尚未定义 logical FP4 packed input lane +或 packed-byte load/store 语义。不要用临时 byte trick 模拟。 + +## block_mx_quant + +canonical 来源:`.work/external/a5-kernel-standalone/cce/block_mx_quant/test/test_equivalence.py`。 +这是 `HW_RESULTS.md` 中报告的默认 14-case 正确性 suite。 + +| 目标 case | VMI 支持状态 | +| --- 
| --- | +| FP8 E4M3 BF16 `(4,128)`, `scale_alg=0`, `rint` | 首批必须支持 | +| FP8 E4M3 FP16 `(64,256)`, `scale_alg=0`, `rint` | 首批必须支持 | +| FP8 E5M2 BF16 `(4,128)`, `scale_alg=0`, `rint` | 首批必须支持 | +| FP8 E5M2 FP16 `(8,256)`, `scale_alg=0`, `rint` | 首批必须支持 | +| FP4 E2M1 BF16 `(4,128)`, `scale_alg=0`, `rint` | 暂缓 | +| FP4 E2M1 BF16 `(256,512)`, `scale_alg=0`, `round` | 暂缓 | +| FP4 E2M1 FP16 `(4,128)`, `scale_alg=0`, `floor` | 暂缓 | +| FP4 E2M1 BF16 `(1,2,2)`, `scale_alg=2`, `floor`, `dst_type_max=0` | 暂缓 | +| FP4 E2M1 BF16 `(4,128)`, `scale_alg=2`, `floor`, `dst_type_max=6` | 暂缓 | +| FP4 E2M1 BF16 `(4,128)`, `scale_alg=2`, `rint`, `dst_type_max=7` | 暂缓 | +| FP4 E1M2 BF16 `(4,128)`, `scale_alg=0`, `rint` | 暂缓 | +| FP4 E1M2 FP16 `(8,256)`, `scale_alg=0`, `round` | 暂缓 | +| FP4 E1M2 BF16 `(4,128)`, `scale_alg=0`, `floor` | 暂缓 | +| FP4 tail BF16 `(100,300)`, E2M1, `scale_alg=0`, `rint` | 暂缓 | + +`test_cce.py` 额外枚举完整 small-shape type/rounding union: + +| Surface family | 额外覆盖 | +| --- | --- | +| FP8 OCP | FP16/BF16 x E4M3/E5M2, shape `(4,128)`, `rint` | +| FP4 E2M1 OCP | FP16/BF16 x `rint/round/floor`, shape `(4,128)` | +| FP4 E2M1 DDR | FP16/BF16 x `rint/round/floor`, shape `(4,128)`, `scale_alg=2` | +| FP4 E1M2 OCP | FP16/BF16 x `rint/round/floor`, shape `(4,128)` | + +VMI 实现以 canonical 14-case suite 作为迁移 checklist;`test_cce.py` union +在 FP4 设计完成后作为 surface 完整性 checklist。 + +## swiglu_mx_quant + +来源:`.work/external/a5-kernel-standalone/cce/swiglu_mx_quant/test/test_equivalence.py`。 + +| 目标 case | VMI 支持状态 | +| --- | --- | +| FP8 E4M3 BF16 `(4,8)`, `scale_alg=0`, `rint` | 首批必须支持 | +| FP8 E4M3 FP16 `(64,512)`, `scale_alg=0`, `rint` | 首批必须支持 | +| FP8 E5M2 BF16 `(4,8)`, `scale_alg=0`, `rint` | 首批必须支持 | +| FP8 E5M2 FP16 `(128,256)`, `scale_alg=0`, `rint` | 首批必须支持 | +| FP8 E4M3 BF16 `(64,512)`, `scale_alg=1`, `rint` | 暂缓;CCE 标记 CUBLAS 路径异常 | +| FP8 E5M2 FP16 `(64,512)`, `scale_alg=1`, `rint` | 暂缓;CCE 标记 CUBLAS 路径异常 | +| FP4 E2M1 BF16 `(4,8)`, `scale_alg=0`, `rint` | 暂缓 | +| FP4 E2M1 
FP16 `(4,8)`, `scale_alg=0`, `rint` | 暂缓 | +| FP4 E2M1 BF16 `(64,512)`, `scale_alg=0`, `round` | 暂缓 | +| FP4 E2M1 BF16 `(4,8)`, `scale_alg=0`, `floor` | 暂缓 | +| FP4 E2M1 BF16 `(128,256)`, `scale_alg=0`, `rint` | 暂缓 | +| FP4 E1M2 BF16 `(4,8)`, `scale_alg=0`, `rint` | 暂缓 | +| FP4 E1M2 FP16 `(64,512)`, `scale_alg=0`, `round` | 暂缓 | +| FP4 E1M2 BF16 `(4,8)`, `scale_alg=0`, `floor` | 暂缓 | + +`test_smoke.py` 在 shape `(4,8)`、`(64,512)`、`(128,256)`、dtype BF16/FP16、 +FP4/FP8 输出模式和 FP8 `scale_alg=1` 上跑 48 个执行面。它不是正确性 oracle; +只有在等价性 case 覆盖后才使用。 + +`test_constant_input.py` 用于诊断 `(4,8)` 和 `(64,512)` 上 BF16 E4M3 OCP +constant input。它可以支撑 tiny deterministic VMI case,但除非 equivalence suite +新增对应项,否则不应为每个 dtype/output type 建一套并行矩阵。 + +## tutorial / block_mx_quant + +来源:`.work/external/a5-kernel-standalone/cce/tutorial/block_mx_quant/README.md`。 + +tutorial kernel 是教学用途,和主 `block_mx_quant` 共享算法:BF16 输入、`scale_alg=0`、 +FP8/FP4 输出、32x32 block scale 和 scale2 interleaving。README 说明 smoke 有 9 个 +BF16 output-type case,cross-check 有 7 个 byte-exact case,但当前快照中没有详细测试文件。 +因此 tutorial 覆盖只作为主 `block_mx_quant` 表的代表 shape,不作为独立 family。 + +## 当前 VMI 目录裁剪结果 + +该目录已裁剪为 target-scoped runtime case。删除的 case 只有在目标仓库新增匹配的正确性入口, +或迁移到独立的非目标 probe suite 后,才应重新引入。 + +当前 `test/vpto/cases/vmi/kernels` 已缩减为 36 个 case 目录。上面的目标 CCE canonical +正确性范围在 `block_mx_quant` 采用 14-case canonical suite 时有 65 行;如果把 +`block_mx_quant/test_cce.py` 作为完整 small-shape surface union 计入,则有 81 行。 +这些数量不能直接和当前支持集比较,因为目标列表仍包含当前 VMI 有意暂缓的 FP4 行。 + +| Area | 当前 VMI 目录数 | 目标 canonical 正确性 | 差异 | +| --- | ---: | ---: | --- | +| `quant_minimum` / `tquant` | 4 | 4 | 对齐 `MINIMUM_CASES` | +| `block_quant` | 7 | 7 | 对齐 `test_equivalence.py` | +| `dynamic_quant` | 9 | 9 | 对齐 `test_dq_equivalence.py` | +| `anti_mx_quant` | 7 | 16 | 保留当前 FP8 目标行;暂缓的 FP4 行不表达 | +| `block_mx_quant` | 4 | 14 canonical / 30 full union | 保留 canonical FP8 目标行;暂缓的 FP4/DDR 和额外 `test_cce.py` union 行不表达 | +| `swiglu_mx_quant` | 4 | 14 | 保留当前 FP8/OCP 目标行;暂缓的 FP4 和异常 CUBLAS 
行不表达 | +| `simdvf_per_block_cast` | 1 | 1 | 对齐 PR #488 | +| historical `anti_quant` | 0 | 0 | 已从 target-scoped 目录移除 | +| historical `swiglu_quant` | 0 | 0 | 已从 target-scoped 目录移除 | +| other probe | 0 | 0 | 已从 target-scoped 目录移除 | + +## 当前支持目录清单 + +当前 target-scoped runtime 目录精确包含以下 36 个 VMI case: + +| CCE family | VMI case 目录 | +| --- | --- | +| `quant_minimum` / `tquant` | `tquant-mxfp8-32x32-nd`, `tquant-mxfp8-32x64-nz`, `tquant-int8-sym-64x128`, `tquant-int8-asym-64x128` | +| `block_quant` | `block-quant-bf16-fp8-2x128`, `block-quant-bf16-fp8-4x128`, `block-quant-bf16-fp8-4x128-min-scale`, `block-quant-bf16-fp8-32x128`, `block-quant-f16-fp8-4x256`, `block-quant-f16-fp8-8x128`, `block-quant-f16-fp8-16x256` | +| `dynamic_quant` | `dynamic-quant-pertoken-f16-4x32`, `dynamic-quant-pertoken-f16-16x128`, `dynamic-quant-pertoken-smooth-f16-8x64`, `dynamic-quant-pertoken-smooth-f16-16x128`, `dynamic-quant-perchannel-f16-128x128`, `dynamic-quant-perchannel-f16-256x256`, `dynamic-quant-pertoken-bf16-4x32`, `dynamic-quant-pertoken-smooth-bf16-8x64`, `dynamic-quant-perchannel-bf16-128x128` | +| `dequant/anti_mx_quant` | `anti-mx-f8-bf16-scaled-4x128`, `anti-mx-f8-f32-scaled-4x128`, `anti-mx-f8-f16-scaled-4x128`, `anti-mx-f8-bf16-scaled-16x512`, `anti-mx-f8-bf16-scaled-64x2048`, `anti-mx-f8e5m2-bf16-scaled-4x128`, `anti-mx-f8e5m2-bf16-scaled-16x512` | +| `block_mx_quant` | `block-mx-quant-bf16-e4m3-4x128`, `block-mx-quant-f16-e4m3-64x256`, `block-mx-quant-bf16-e5m2-4x128`, `block-mx-quant-f16-e5m2-8x256` | +| `swiglu_mx_quant` | `swiglu-mx-quant-bf16-e4m3-4x8`, `swiglu-mx-quant-f16-e4m3-64x512`, `swiglu-mx-quant-bf16-e5m2-4x8`, `swiglu-mx-quant-f16-e5m2-128x256` | +| `simdvf_per_block_cast` | `simdvf-per-block-cast-to-fp8` | + +| 已移除的 VMI 区域 | 范围说明 | +| --- | --- | +| 额外 `anti-mx` FP8 E5M2 -> FP16/FP32 large-shape case | 对称 decode 覆盖;未列入目标 `anti_mx_quant/test_equivalence.py` | +| 额外 `dynamic_quant` BF16 larger smooth/no-smooth case | 实现扩展覆盖;不在 9-case 目标 equivalence 列表中 | +| 
额外 `block_mx_quant` random/shared-scale case | 对独立 golden 覆盖有用;不是直接目标测试名 | +| 额外 `block_mx_quant` FP16 `(4,128)` E4M3/E5M2 行 | 只存在于更宽的 `test_cce.py` union;不是 canonical equivalence checklist 的一部分 | +| 额外 `swiglu_mx_quant` constant BF16 4x8 proxy | 诊断输入模式;除非迁移为精确 equivalence 行,否则不保留在 target-scoped runtime case 中 | +| HIF8 `block_quant` probe | compiler/runtime surface probe;不是该目标仓库中的 raw CCE 正确性 case | diff --git a/test/vpto/cases/vmi/kernels/README.md b/test/vpto/cases/vmi/kernels/README.md new file mode 100644 index 0000000000..cab387fdb4 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/README.md @@ -0,0 +1,50 @@ + + +# VMI Kernel 用例 + +本目录只保留从目标仓库 `.work/external/a5-kernel-standalone/cce` 的 raw CCE +正确性、等价性和 minimum 测试入口迁移而来的 VMI runtime case。 +范围定义见 [CCE_CASE_SCOPE.md](CCE_CASE_SCOPE.md)。 + +不要从历史 VMI probe、benchmark sweep、debug shape、random stress 或当前目录外的 +实验脚本反推支持范围。新增 case 前先确认目标 CCE 测试入口是否提供对应正确性语义。 + +## 当前目录范围 + +当前目录保留 36 个 runtime case: + +| CCE family | 当前 case 数 | 范围 | +| --- | ---: | --- | +| `quant_minimum` / `tquant` | 4 | 对齐 `MINIMUM_CASES` | +| `block_quant` | 7 | 对齐 `test_equivalence.py` | +| `dynamic_quant` | 9 | 对齐 `test_dq_equivalence.py` | +| `dequant/anti_mx_quant` | 7 | 当前保留 VMI 能表达的 FP8 行;FP4 输入暂缓 | +| `block_mx_quant` | 4 | 当前保留 canonical FP8/OCP 等价性行;FP4、DDR `scale_alg=2` 和额外 `test_cce.py` union 行暂缓 | +| `swiglu_mx_quant` | 4 | 当前保留 FP8/OCP 等价性行;FP4 和 CCE 已标记异常的 CUBLAS `scale_alg=1` 暂缓 | +| `simdvf_per_block_cast` | 1 | 对齐 PTOAS PR #488 中的 16x256 f16 + 4x8 scale -> fp8 per-block cast case | + +## 设计上暂缓 + +下列目标 CCE 行是真实存在的,但在对应 VMI 语义设计清楚前,不应通过临时拼凑的 +runtime case 表达: + +| Case 类别 | 原因 | +| --- | --- | +| FP4 packed input/output | VMI 尚未定义 logical FP4 lane、packed-byte layout 和 FP4 load/store 语义 | +| `block_mx_quant` FP4 DDR `scale_alg=2` | 依赖 FP4 语义和 DDR scale 规则 | +| `swiglu_mx_quant` FP8 CUBLAS `scale_alg=1` | CCE 源码已标记该路径异常 | +| HIF8 `block_quant` | 只是 compiler/runtime surface probe,不是目标仓库里的 raw CCE 正确性 case | + +## 验证策略 + +每个保留的 runtime 
case 都应通过 `test/vpto/scripts/run_host_vpto_validation.sh`。 +新增 CCE 迁移 case 时,先在 [CCE_CASE_SCOPE.md](CCE_CASE_SCOPE.md) 记录对应的 +目标 CCE 来源行。除非新 case 覆盖不同的 VMI 语义或 lowering 约束,否则不要重复添加同构 shape。 diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-16x512/compare.py b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-16x512/compare.py new file mode 100644 index 0000000000..f9d83f2328 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-16x512/compare.py @@ -0,0 +1,34 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+ +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v3.bin", dtype=np.uint16) + out = np.fromfile("v3.bin", dtype=np.uint16) + + if golden.shape != out.shape or not np.array_equal(golden, out): + diff = np.nonzero(golden != out)[0] if golden.shape == out.shape else [] + idx = int(diff[0]) if len(diff) else -1 + golden_value = f"0x{int(golden[idx]):04x}" if idx >= 0 else "n/a" + out_value = f"0x{int(out[idx]):04x}" if idx >= 0 else "n/a" + print( + f"[ERROR] bf16 compare failed idx={idx} " + f"golden={golden_value} output={out_value}" + ) + sys.exit(2) + + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-16x512/golden.py b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-16x512/golden.py new file mode 100644 index 0000000000..09a91b695c --- /dev/null +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-16x512/golden.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+ +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 16 +COLS = 512 +ELEMS = ROWS * COLS +MXSCALE_BYTES = ROWS * (COLS // 32) +VALUES = np.array( + [0.0, 1.0, -1.0, 0.5, 2.0, -2.0, 4.0, -4.0], + dtype=np.float32, +) +F8E4M3FN_BYTES = np.array([0x00, 0x38, 0xB8, 0x30, 0x40, 0xC0, 0x48, 0xC8], dtype=np.uint8) +E8M0_BYTES = np.array( + [0x7E, 0x7F, 0x80, 0x81, 0x7D, 0x82, 0x7C, 0x83], dtype=np.uint8 +) +SENTINEL_BF16 = np.uint16(0x7FC0) + + +def f32_to_bf16_bits(values: np.ndarray) -> np.ndarray: + bits = values.astype(np.float32).view(np.uint32) + lsb = (bits >> 16) & 1 + rounded = bits + np.uint32(0x7FFF) + lsb + return (rounded >> 16).astype(np.uint16) + + +def generate(output_dir: Path) -> None: + repeats = (ELEMS + len(VALUES) - 1) // len(VALUES) + src = np.tile(F8E4M3FN_BYTES, repeats)[:ELEMS].astype(np.uint8).reshape(ROWS, COLS) + decoded = np.tile(VALUES, repeats)[:ELEMS].astype(np.float32).reshape(ROWS, COLS) + mxscale = np.full(MXSCALE_BYTES, np.uint8(0x7F), dtype=np.uint8) + groups = COLS // 32 + scale_repeats = (groups + len(E8M0_BYTES) - 1) // len(E8M0_BYTES) + scale_row = np.tile(E8M0_BYTES, scale_repeats)[:groups].astype(np.uint8) + mxscale_matrix = np.tile(scale_row, (ROWS, 1)).astype(np.uint8) + mxscale[:] = mxscale_matrix.reshape(-1) + scale_values = np.ldexp( + np.ones_like(mxscale_matrix, dtype=np.float32), + mxscale_matrix.astype(np.int32) - 127, + ) + scaled = decoded.copy() + for row in range(ROWS): + for group in range(groups): + start = group * 32 + stop = start + 32 + scaled[row, start:stop] *= scale_values[row, group] + dst = np.full(ELEMS, SENTINEL_BF16, dtype=np.uint16) + golden = f32_to_bf16_bits(scaled.reshape(-1)) + + output_dir.mkdir(parents=True, exist_ok=True) + src.reshape(-1).tofile(output_dir / "v1.bin") + mxscale.tofile(output_dir / "v2.bin") + dst.tofile(output_dir / "v3.bin") + golden.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + 
parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-16x512/kernel.pto b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-16x512/kernel.pto new file mode 100644 index 0000000000..391687c3a5 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-16x512/kernel.pto @@ -0,0 +1,96 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_anti_mx_f8_bf16_scaled_16x512_kernel( + %src_gm: !pto.ptr, %mxscale_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c256 = arith.constant 256 : index + %c512 = arith.constant 512 : index + %c23_i32 = arith.constant 23 : i32 + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c2_i64 = arith.constant 2 : i64 + %c8_i64 = arith.constant 8 : i64 + %c32_i64 = arith.constant 32 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c16384_i64 = arith.constant 16384 : i64 + %c32768_i64 = arith.constant 32768 : i64 + + %ub_src_u8 = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_src_f8 = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_mxscale = pto.castptr %c16384_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c32768_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src_u8, %c0_i64, %c8192_i64 + nburst(%c1_i64, %c8192_i64, %c8192_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %mxscale_gm, %ub_mxscale, %c0_i64, %c8_i64 + nburst(%c32_i64, %c8_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + scf.for %row = %c0 to %c16 step %c1 { + %row_elem_off = arith.muli %row, %c512 : index + %row_scale_off = arith.muli %row, %c64 : index + scf.for %col = %c0 to %c512 step %c256 { + %offset = arith.addi %row_elem_off, %col : index + %packed = pto.vmi.load %ub_src_f8[%offset] + : !pto.ptr -> !pto.vmi.vreg<256xf8E4M3FN> + %wide = pto.vmi.extf %packed + : !pto.vmi.vreg<256xf8E4M3FN> -> 
!pto.vmi.vreg<256xf32> + %scale_tile = arith.divui %col, %c256 : index + %scale_tile_off = arith.muli %scale_tile, %c32 : index + %scale_off = arith.addi %row_scale_off, %scale_tile_off : index + %scale_u8 = pto.vmi.group_slot_load %ub_mxscale[%scale_off], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<8xui8> + %scale_u32 = pto.vmi.extui %scale_u8 + : !pto.vmi.vreg<8xui8> -> !pto.vmi.vreg<8xui32> + %scale_i32 = pto.vmi.bitcast %scale_u32 + : !pto.vmi.vreg<8xui32> -> !pto.vmi.vreg<8xi32> + %shift = pto.vmi.broadcast %c23_i32 + : i32 -> !pto.vmi.vreg<8xi32> + %scale_bits = pto.vmi.shli %scale_i32, %shift + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> + %scale = pto.vmi.bitcast %scale_bits + : !pto.vmi.vreg<8xi32> -> !pto.vmi.vreg<8xf32> + %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 8} + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<256xf32> + %scaled = pto.vmi.mulf %wide, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %out = pto.vmi.truncf %scaled + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xbf16> + pto.vmi.store %out, %ub_dst[%offset] + : !pto.vmi.vreg<256xbf16>, !pto.ptr + } + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c16384_i64 + nburst(%c1_i64, %c16384_i64, %c16384_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-16x512/launch.cpp b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-16x512/launch.cpp new file mode 100644 index 0000000000..bb37bcb8fa --- /dev/null +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-16x512/launch.cpp @@ -0,0 +1,45 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_anti_mx_f8_bf16_scaled_16x512_kernel(__gm__ uint8_t *src, + __gm__ uint8_t *mxscale, + __gm__ bfloat16_t *dst); + +void LaunchVmi_anti_mx_f8_bf16_scaled_16x512_kernel(uint8_t *src, + uint8_t *mxscale, + uint16_t *dst, + void *stream) { + vmi_anti_mx_f8_bf16_scaled_16x512_kernel<<<1, nullptr, stream>>>( + (__gm__ uint8_t *)src, (__gm__ uint8_t *)mxscale, + (__gm__ bfloat16_t *)dst); +} diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-16x512/main.cpp b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-16x512/main.cpp new file mode 100644 index 0000000000..b9080bacc5 --- /dev/null +++ 
b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-16x512/main.cpp @@ -0,0 +1,93 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_anti_mx_f8_bf16_scaled_16x512_kernel(uint8_t *src, + uint8_t *mxscale, + uint16_t *dst, + void *stream); + +int main() { + constexpr size_t kRows = 16; + constexpr size_t kCols = 512; + constexpr size_t kElems = kRows * kCols; + constexpr size_t kMxScaleBytes = kRows * (kCols / 32); + size_t srcBytes = kElems * sizeof(uint8_t); + size_t mxscaleBytes = kMxScaleBytes; + size_t dstBytes = kElems * sizeof(uint16_t); + uint8_t *srcHost = nullptr; + uint8_t *mxscaleHost = nullptr; + uint16_t *dstHost = nullptr; + uint8_t *srcDevice = nullptr; + uint8_t *mxscaleDevice = nullptr; + uint16_t *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); 
+ ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&mxscaleHost), mxscaleBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&mxscaleDevice, mxscaleBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", mxscaleBytes, mxscaleHost, mxscaleBytes); + ReadFile("./v3.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(mxscaleDevice, mxscaleBytes, mxscaleHost, mxscaleBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_anti_mx_f8_bf16_scaled_16x512_kernel(srcDevice, mxscaleDevice, + dstDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(mxscaleDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(mxscaleHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-16x512/ptoas.flags b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-16x512/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-16x512/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git 
a/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-4x128/compare.py b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-4x128/compare.py new file mode 100644 index 0000000000..f9d83f2328 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-4x128/compare.py @@ -0,0 +1,34 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v3.bin", dtype=np.uint16) + out = np.fromfile("v3.bin", dtype=np.uint16) + + if golden.shape != out.shape or not np.array_equal(golden, out): + diff = np.nonzero(golden != out)[0] if golden.shape == out.shape else [] + idx = int(diff[0]) if len(diff) else -1 + golden_value = f"0x{int(golden[idx]):04x}" if idx >= 0 else "n/a" + out_value = f"0x{int(out[idx]):04x}" if idx >= 0 else "n/a" + print( + f"[ERROR] bf16 compare failed idx={idx} " + f"golden={golden_value} output={out_value}" + ) + sys.exit(2) + + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-4x128/golden.py b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-4x128/golden.py new file mode 100644 index 0000000000..b714f6d4db --- /dev/null +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-4x128/golden.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei 
Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 4 +COLS = 128 +ELEMS = ROWS * COLS +MXSCALE_BYTES = 32 +VALUES = np.array([0.0, 1.0, -1.0, 0.5, 2.0, -2.0, 4.0, -4.0], dtype=np.float32) +F8E4M3FN_BYTES = np.array([0x00, 0x38, 0xB8, 0x30, 0x40, 0xC0, 0x48, 0xC8], dtype=np.uint8) +E8M0_BYTES = np.array([0x7E, 0x7F, 0x80, 0x81], dtype=np.uint8) +SENTINEL_BF16 = np.uint16(0x7FC0) + + +def f32_to_bf16_bits(values: np.ndarray) -> np.ndarray: + bits = values.astype(np.float32).view(np.uint32) + lsb = (bits >> 16) & 1 + rounded = bits + np.uint32(0x7FFF) + lsb + return (rounded >> 16).astype(np.uint16) + + +def generate(output_dir: Path) -> None: + repeats = (ELEMS + len(VALUES) - 1) // len(VALUES) + src = np.tile(F8E4M3FN_BYTES, repeats)[:ELEMS].astype(np.uint8).reshape(ROWS, COLS) + decoded = np.tile(VALUES, repeats)[:ELEMS].astype(np.float32).reshape(ROWS, COLS) + mxscale = np.full(MXSCALE_BYTES, np.uint8(0x7F), dtype=np.uint8) + mxscale_matrix = np.tile(E8M0_BYTES, (ROWS, 1)).astype(np.uint8) + mxscale[: ROWS * 4] = mxscale_matrix.reshape(-1) + scale_values = np.ldexp( + np.ones_like(mxscale_matrix, dtype=np.float32), + mxscale_matrix.astype(np.int32) - 127, + ) + scaled = decoded.copy() + for row in range(ROWS): + for group in range(4): + start = group * 32 + stop = start + 32 + scaled[row, start:stop] *= scale_values[row, group] + dst = 
np.full(ELEMS, SENTINEL_BF16, dtype=np.uint16) + golden = f32_to_bf16_bits(scaled.reshape(-1)) + + output_dir.mkdir(parents=True, exist_ok=True) + src.reshape(-1).tofile(output_dir / "v1.bin") + mxscale.tofile(output_dir / "v2.bin") + dst.tofile(output_dir / "v3.bin") + golden.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-4x128/kernel.pto b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-4x128/kernel.pto new file mode 100644 index 0000000000..f0def9ef3f --- /dev/null +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-4x128/kernel.pto @@ -0,0 +1,85 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_anti_mx_f8_bf16_scaled_4x128_kernel( + %src_gm: !pto.ptr, %mxscale_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c256 = arith.constant 256 : index + %c512 = arith.constant 512 : index + %c23_i32 = arith.constant 23 : i32 + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c2_i64 = arith.constant 2 : i64 + %c8_i64 = arith.constant 8 : i64 + %c32_i64 = arith.constant 32 : i64 + %c512_i64 = arith.constant 512 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c2048_i64 = arith.constant 2048 : i64 + + %ub_src_u8 = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_src_f8 = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_mxscale = pto.castptr %c1024_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c2048_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src_u8, %c0_i64, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %mxscale_gm, %ub_mxscale, %c0_i64, %c8_i64 + nburst(%c2_i64, %c8_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c512 step %c256 iter_args(%dummy = %c0) -> (index) { + %packed = pto.vmi.load %ub_src_f8[%offset] + : !pto.ptr -> !pto.vmi.vreg<256xf8E4M3FN> + %wide = pto.vmi.extf %packed + : !pto.vmi.vreg<256xf8E4M3FN> -> !pto.vmi.vreg<256xf32> + %scale_tile = arith.divui %offset, %c256 : index + %scale_off = arith.muli %scale_tile, %c32 : index + %scale_u8 = pto.vmi.group_slot_load %ub_mxscale[%scale_off], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<8xui8> + %scale_u32 = pto.vmi.extui %scale_u8 + : !pto.vmi.vreg<8xui8> -> !pto.vmi.vreg<8xui32> + %scale_i32 = pto.vmi.bitcast %scale_u32 + : 
!pto.vmi.vreg<8xui32> -> !pto.vmi.vreg<8xi32> + %shift = pto.vmi.broadcast %c23_i32 + : i32 -> !pto.vmi.vreg<8xi32> + %scale_bits = pto.vmi.shli %scale_i32, %shift + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> + %scale = pto.vmi.bitcast %scale_bits + : !pto.vmi.vreg<8xi32> -> !pto.vmi.vreg<8xf32> + %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 8} + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<256xf32> + %scaled = pto.vmi.mulf %wide, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %out = pto.vmi.truncf %scaled + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xbf16> + pto.vmi.store %out, %ub_dst[%offset] + : !pto.vmi.vreg<256xbf16>, !pto.ptr + scf.yield %dummy : index + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-4x128/launch.cpp b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-4x128/launch.cpp new file mode 100644 index 0000000000..8fd6bdeaab --- /dev/null +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-4x128/launch.cpp @@ -0,0 +1,45 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_anti_mx_f8_bf16_scaled_4x128_kernel(__gm__ uint8_t *src, + __gm__ uint8_t *mxscale, + __gm__ bfloat16_t *dst); + +void LaunchVmi_anti_mx_f8_bf16_scaled_4x128_kernel(uint8_t *src, + uint8_t *mxscale, + uint16_t *dst, + void *stream) { + vmi_anti_mx_f8_bf16_scaled_4x128_kernel<<<1, nullptr, stream>>>( + (__gm__ uint8_t *)src, (__gm__ uint8_t *)mxscale, + (__gm__ bfloat16_t *)dst); +} diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-4x128/main.cpp b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-4x128/main.cpp new file mode 100644 index 0000000000..414bb6bd53 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-4x128/main.cpp @@ -0,0 +1,93 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_anti_mx_f8_bf16_scaled_4x128_kernel(uint8_t *src, + uint8_t *mxscale, + uint16_t *dst, + void *stream); + +int main() { + constexpr size_t kRows = 4; + constexpr size_t kCols = 128; + constexpr size_t kElems = kRows * kCols; + constexpr size_t kMxScaleBytes = 32; + size_t srcBytes = kElems * sizeof(uint8_t); + size_t mxscaleBytes = kMxScaleBytes; + size_t dstBytes = kElems * sizeof(uint16_t); + uint8_t *srcHost = nullptr; + uint8_t *mxscaleHost = nullptr; + uint16_t *dstHost = nullptr; + uint8_t *srcDevice = nullptr; + uint8_t *mxscaleDevice = nullptr; + uint16_t *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&mxscaleHost), mxscaleBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&mxscaleDevice, mxscaleBytes, 
ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", mxscaleBytes, mxscaleHost, mxscaleBytes); + ReadFile("./v3.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(mxscaleDevice, mxscaleBytes, mxscaleHost, mxscaleBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_anti_mx_f8_bf16_scaled_4x128_kernel(srcDevice, mxscaleDevice, + dstDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(mxscaleDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(mxscaleHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-4x128/ptoas.flags b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-4x128/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-4x128/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-64x2048/compare.py b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-64x2048/compare.py new file mode 100644 index 0000000000..f9d83f2328 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-64x2048/compare.py @@ -0,0 +1,34 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v3.bin", dtype=np.uint16) + out = np.fromfile("v3.bin", dtype=np.uint16) + + if golden.shape != out.shape or not np.array_equal(golden, out): + diff = np.nonzero(golden != out)[0] if golden.shape == out.shape else [] + idx = int(diff[0]) if len(diff) else -1 + golden_value = f"0x{int(golden[idx]):04x}" if idx >= 0 else "n/a" + out_value = f"0x{int(out[idx]):04x}" if idx >= 0 else "n/a" + print( + f"[ERROR] bf16 compare failed idx={idx} " + f"golden={golden_value} output={out_value}" + ) + sys.exit(2) + + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-64x2048/golden.py b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-64x2048/golden.py new file mode 100644 index 0000000000..abce51a142 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-64x2048/golden.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. 
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 64 +COLS = 2048 +ELEMS = ROWS * COLS +MXSCALE_BYTES = ROWS * (COLS // 32) +VALUES = np.array( + [0.0, 1.0, -1.0, 0.5, 2.0, -2.0, 4.0, -4.0], + dtype=np.float32, +) +F8E4M3FN_BYTES = np.array([0x00, 0x38, 0xB8, 0x30, 0x40, 0xC0, 0x48, 0xC8], dtype=np.uint8) +E8M0_BYTES = np.array( + [0x7E, 0x7F, 0x80, 0x81, 0x7D, 0x82, 0x7C, 0x83], dtype=np.uint8 +) +SENTINEL_BF16 = np.uint16(0x7FC0) + + +def f32_to_bf16_bits(values: np.ndarray) -> np.ndarray: + bits = values.astype(np.float32).view(np.uint32) + lsb = (bits >> 16) & 1 + rounded = bits + np.uint32(0x7FFF) + lsb + return (rounded >> 16).astype(np.uint16) + + +def generate(output_dir: Path) -> None: + repeats = (ELEMS + len(VALUES) - 1) // len(VALUES) + src = np.tile(F8E4M3FN_BYTES, repeats)[:ELEMS].astype(np.uint8).reshape(ROWS, COLS) + decoded = np.tile(VALUES, repeats)[:ELEMS].astype(np.float32).reshape(ROWS, COLS) + mxscale = np.full(MXSCALE_BYTES, np.uint8(0x7F), dtype=np.uint8) + groups = COLS // 32 + scale_repeats = (groups + len(E8M0_BYTES) - 1) // len(E8M0_BYTES) + scale_row = np.tile(E8M0_BYTES, scale_repeats)[:groups].astype(np.uint8) + mxscale_matrix = np.tile(scale_row, (ROWS, 1)).astype(np.uint8) + mxscale[:] = mxscale_matrix.reshape(-1) + scale_values = np.ldexp( + np.ones_like(mxscale_matrix, dtype=np.float32), + mxscale_matrix.astype(np.int32) - 127, + ) + scaled = decoded.copy() + for row in range(ROWS): + for group in range(groups): + start = group * 32 + stop = start + 32 + scaled[row, start:stop] *= scale_values[row, group] + dst = np.full(ELEMS, SENTINEL_BF16, dtype=np.uint16) + golden = f32_to_bf16_bits(scaled.reshape(-1)) + 
+ output_dir.mkdir(parents=True, exist_ok=True) + src.reshape(-1).tofile(output_dir / "v1.bin") + mxscale.tofile(output_dir / "v2.bin") + dst.tofile(output_dir / "v3.bin") + golden.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-64x2048/kernel.pto b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-64x2048/kernel.pto new file mode 100644 index 0000000000..2e73913362 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-64x2048/kernel.pto @@ -0,0 +1,109 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_anti_mx_f8_bf16_scaled_64x2048_kernel( + %src_gm: !pto.ptr, %mxscale_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + %c2048 = arith.constant 2048 : index + %c32768 = arith.constant 32768 : index + %c23_i32 = arith.constant 23 : i32 + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c8_i64 = arith.constant 8 : i64 + %c16_i64 = arith.constant 16 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c32768_i64 = arith.constant 32768 : i64 + %c65536_i64 = arith.constant 65536 : i64 + + %ub_src_u8 = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_src_f8 = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_mxscale = pto.castptr %c32768_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c65536_i64 : i64 -> !pto.ptr + + scf.for %tile = %c0 to %c4 step %c1 { + %src_tile_off = arith.muli %tile, %c32768 : index + %scale_tile_base_off = arith.muli %tile, %c1024 : index + %dst_tile_off = arith.muli %tile, %c32768 : index + %src_tile_gm = pto.addptr %src_gm, %src_tile_off + : !pto.ptr -> !pto.ptr + %scale_tile_gm = pto.addptr %mxscale_gm, %scale_tile_base_off + : !pto.ptr -> !pto.ptr + %dst_tile_gm = pto.addptr %dst_gm, %dst_tile_off + : !pto.ptr -> !pto.ptr + + pto.mte_gm_ub %src_tile_gm, %ub_src_u8, %c0_i64, %c2048_i64 + nburst(%c16_i64, %c2048_i64, %c2048_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %scale_tile_gm, %ub_mxscale, %c0_i64, %c8_i64 + nburst(%c128_i64, %c8_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + 
pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + scf.for %row = %c0 to %c16 step %c1 { + %row_elem_off = arith.muli %row, %c2048 : index + %row_scale_off = arith.muli %row, %c256 : index + scf.for %col = %c0 to %c2048 step %c256 { + %offset = arith.addi %row_elem_off, %col : index + %packed = pto.vmi.load %ub_src_f8[%offset] + : !pto.ptr -> !pto.vmi.vreg<256xf8E4M3FN> + %wide = pto.vmi.extf %packed + : !pto.vmi.vreg<256xf8E4M3FN> -> !pto.vmi.vreg<256xf32> + %scale_tile = arith.divui %col, %c256 : index + %scale_tile_off = arith.muli %scale_tile, %c32 : index + %scale_off = arith.addi %row_scale_off, %scale_tile_off : index + %scale_u8 = pto.vmi.group_slot_load %ub_mxscale[%scale_off], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<8xui8> + %scale_u32 = pto.vmi.extui %scale_u8 + : !pto.vmi.vreg<8xui8> -> !pto.vmi.vreg<8xui32> + %scale_i32 = pto.vmi.bitcast %scale_u32 + : !pto.vmi.vreg<8xui32> -> !pto.vmi.vreg<8xi32> + %shift = pto.vmi.broadcast %c23_i32 + : i32 -> !pto.vmi.vreg<8xi32> + %scale_bits = pto.vmi.shli %scale_i32, %shift + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> + %scale = pto.vmi.bitcast %scale_bits + : !pto.vmi.vreg<8xi32> -> !pto.vmi.vreg<8xf32> + %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 8} + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<256xf32> + %scaled = pto.vmi.mulf %wide, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %out = pto.vmi.truncf %scaled + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xbf16> + pto.vmi.store %out, %ub_dst[%offset] + : !pto.vmi.vreg<256xbf16>, !pto.ptr + } + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_tile_gm, %c65536_i64 + nburst(%c1_i64, %c65536_i64, %c65536_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + } + pto.barrier #pto.pipe + return + } +} diff --git 
a/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-64x2048/launch.cpp b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-64x2048/launch.cpp new file mode 100644 index 0000000000..3768c6fd0c --- /dev/null +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-64x2048/launch.cpp @@ -0,0 +1,45 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_anti_mx_f8_bf16_scaled_64x2048_kernel(__gm__ uint8_t *src, + __gm__ uint8_t *mxscale, + __gm__ bfloat16_t *dst); + +void LaunchVmi_anti_mx_f8_bf16_scaled_64x2048_kernel(uint8_t *src, + uint8_t *mxscale, + uint16_t *dst, + void *stream) { + vmi_anti_mx_f8_bf16_scaled_64x2048_kernel<<<1, 
nullptr, stream>>>( + (__gm__ uint8_t *)src, (__gm__ uint8_t *)mxscale, + (__gm__ bfloat16_t *)dst); +} diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-64x2048/main.cpp b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-64x2048/main.cpp new file mode 100644 index 0000000000..068ff83a6b --- /dev/null +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-64x2048/main.cpp @@ -0,0 +1,93 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_anti_mx_f8_bf16_scaled_64x2048_kernel(uint8_t *src, + uint8_t *mxscale, + uint16_t *dst, + void *stream); + +int main() { + constexpr size_t kRows = 64; + constexpr size_t kCols = 2048; + constexpr size_t kElems = kRows * kCols; + constexpr size_t kMxScaleBytes = kRows * (kCols / 32); + size_t srcBytes = kElems * sizeof(uint8_t); + size_t mxscaleBytes = kMxScaleBytes; + size_t dstBytes = kElems * sizeof(uint16_t); + uint8_t *srcHost = nullptr; + uint8_t *mxscaleHost = nullptr; + uint16_t *dstHost = nullptr; + uint8_t *srcDevice = nullptr; + uint8_t *mxscaleDevice = nullptr; + uint16_t *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&mxscaleHost), mxscaleBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&mxscaleDevice, mxscaleBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", mxscaleBytes, mxscaleHost, mxscaleBytes); + ReadFile("./v3.bin", dstBytes, dstHost, dstBytes); + 
ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(mxscaleDevice, mxscaleBytes, mxscaleHost, mxscaleBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_anti_mx_f8_bf16_scaled_64x2048_kernel(srcDevice, mxscaleDevice, + dstDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(mxscaleDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(mxscaleHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-64x2048/ptoas.flags b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-64x2048/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-64x2048/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8-f16-scaled-4x128/compare.py b/test/vpto/cases/vmi/kernels/anti-mx-f8-f16-scaled-4x128/compare.py new file mode 100644 index 0000000000..d5dcd9b576 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8-f16-scaled-4x128/compare.py @@ -0,0 +1,34 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. 
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v3.bin", dtype=np.uint16) + out = np.fromfile("v3.bin", dtype=np.uint16) + + if golden.shape != out.shape or not np.array_equal(golden, out): + diff = np.nonzero(golden != out)[0] if golden.shape == out.shape else [] + idx = int(diff[0]) if len(diff) else -1 + golden_value = f"0x{int(golden[idx]):04x}" if idx >= 0 else "n/a" + out_value = f"0x{int(out[idx]):04x}" if idx >= 0 else "n/a" + print( + f"[ERROR] f16 compare failed idx={idx} " + f"golden={golden_value} output={out_value}" + ) + sys.exit(2) + + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8-f16-scaled-4x128/golden.py b/test/vpto/cases/vmi/kernels/anti-mx-f8-f16-scaled-4x128/golden.py new file mode 100644 index 0000000000..fc1502fce0 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8-f16-scaled-4x128/golden.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+ +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 4 +COLS = 128 +ELEMS = ROWS * COLS +MXSCALE_BYTES = 32 +VALUES = np.array([0.0, 1.0, -1.0, 0.5, 2.0, -2.0, 4.0, -4.0], dtype=np.float32) +F8E4M3FN_BYTES = np.array([0x00, 0x38, 0xB8, 0x30, 0x40, 0xC0, 0x48, 0xC8], dtype=np.uint8) +E8M0_BYTES = np.array([0x7E, 0x7F, 0x80, 0x81], dtype=np.uint8) +SENTINEL_F16 = np.uint16(0x7E00) + + +def f32_to_f16_bits(values: np.ndarray) -> np.ndarray: + return values.astype(np.float16).view(np.uint16) + + +def generate(output_dir: Path) -> None: + repeats = (ELEMS + len(VALUES) - 1) // len(VALUES) + src = np.tile(F8E4M3FN_BYTES, repeats)[:ELEMS].astype(np.uint8).reshape(ROWS, COLS) + decoded = np.tile(VALUES, repeats)[:ELEMS].astype(np.float32).reshape(ROWS, COLS) + mxscale = np.full(MXSCALE_BYTES, np.uint8(0x7F), dtype=np.uint8) + mxscale_matrix = np.tile(E8M0_BYTES, (ROWS, 1)).astype(np.uint8) + mxscale[: ROWS * 4] = mxscale_matrix.reshape(-1) + scale_values = np.ldexp( + np.ones_like(mxscale_matrix, dtype=np.float32), + mxscale_matrix.astype(np.int32) - 127, + ) + scaled = decoded.copy() + for row in range(ROWS): + for group in range(4): + start = group * 32 + stop = start + 32 + scaled[row, start:stop] *= scale_values[row, group] + dst = np.full(ELEMS, SENTINEL_F16, dtype=np.uint16) + golden = f32_to_f16_bits(scaled.reshape(-1)) + + output_dir.mkdir(parents=True, exist_ok=True) + src.reshape(-1).tofile(output_dir / "v1.bin") + mxscale.tofile(output_dir / "v2.bin") + dst.tofile(output_dir / "v3.bin") + golden.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8-f16-scaled-4x128/kernel.pto b/test/vpto/cases/vmi/kernels/anti-mx-f8-f16-scaled-4x128/kernel.pto new file mode 100644 index 
0000000000..d346c28a3c --- /dev/null +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8-f16-scaled-4x128/kernel.pto @@ -0,0 +1,85 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_anti_mx_f8_f16_scaled_4x128_kernel( + %src_gm: !pto.ptr, %mxscale_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c256 = arith.constant 256 : index + %c512 = arith.constant 512 : index + %c23_i32 = arith.constant 23 : i32 + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c2_i64 = arith.constant 2 : i64 + %c8_i64 = arith.constant 8 : i64 + %c32_i64 = arith.constant 32 : i64 + %c512_i64 = arith.constant 512 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c2048_i64 = arith.constant 2048 : i64 + + %ub_src_u8 = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_src_f8 = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_mxscale = pto.castptr %c1024_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c2048_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src_u8, %c0_i64, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %mxscale_gm, %ub_mxscale, %c0_i64, %c8_i64 + nburst(%c2_i64, %c8_i64, %c32_i64) + : !pto.ptr, !pto.ptr, 
i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c512 step %c256 iter_args(%dummy = %c0) -> (index) { + %packed = pto.vmi.load %ub_src_f8[%offset] + : !pto.ptr -> !pto.vmi.vreg<256xf8E4M3FN> + %wide = pto.vmi.extf %packed + : !pto.vmi.vreg<256xf8E4M3FN> -> !pto.vmi.vreg<256xf32> + %scale_tile = arith.divui %offset, %c256 : index + %scale_off = arith.muli %scale_tile, %c32 : index + %scale_u8 = pto.vmi.group_slot_load %ub_mxscale[%scale_off], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<8xui8> + %scale_u32 = pto.vmi.extui %scale_u8 + : !pto.vmi.vreg<8xui8> -> !pto.vmi.vreg<8xui32> + %scale_i32 = pto.vmi.bitcast %scale_u32 + : !pto.vmi.vreg<8xui32> -> !pto.vmi.vreg<8xi32> + %shift = pto.vmi.broadcast %c23_i32 + : i32 -> !pto.vmi.vreg<8xi32> + %scale_bits = pto.vmi.shli %scale_i32, %shift + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> + %scale = pto.vmi.bitcast %scale_bits + : !pto.vmi.vreg<8xi32> -> !pto.vmi.vreg<8xf32> + %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 8} + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<256xf32> + %scaled = pto.vmi.mulf %wide, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %out = pto.vmi.truncf %scaled + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf16> + pto.vmi.store %out, %ub_dst[%offset] + : !pto.vmi.vreg<256xf16>, !pto.ptr + scf.yield %dummy : index + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8-f16-scaled-4x128/launch.cpp b/test/vpto/cases/vmi/kernels/anti-mx-f8-f16-scaled-4x128/launch.cpp new file mode 100644 index 
// Copyright (c) 2026 Huawei Technologies Co., Ltd.
// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
// CANN Open Software License Agreement Version 2.0 (the "License").
// Please refer to the License for details. You may not use this file except in compliance with the License.
// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
// See LICENSE in the root of the software repository for the full text of the License.

// Ensure __VEC_SCOPE__ is defined (as empty) before the headers below.
#ifndef __VEC_SCOPE__
#define __VEC_SCOPE__
#endif
// Single-byte placeholder types for the 8-bit and packed 4-bit float names,
// provided only when compiling device code for NPU arch 2201.
#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201)
typedef struct { unsigned char v; } hifloat8_t;
typedef struct { unsigned char v; } float8_e4m3_t;
typedef struct { unsigned char v; } float8_e5m2_t;
typedef struct { unsigned char v; } float8_e8m0_t;
typedef struct { unsigned char v; } float4_e1m2x2_t;
typedef struct { unsigned char v; } float4_e2m1x2_t;
#endif
// NOTE(review): the header name of the following #include is missing in this
// copy of the file; restore it from the original before building.
#include
// Host-side fallback definition, used when TMRGSORT_HPP has not been defined.
#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP)
struct MrgSortExecutedNumList {
  uint16_t mrgSortList0;
  uint16_t mrgSortList1;
  uint16_t mrgSortList2;
  uint16_t mrgSortList3;
};
#endif
#ifndef __CPU_SIM
#include "acl/acl.h"
#endif

// Device kernel entry point; defined outside this file (presumably generated
// from the sibling kernel.pto - confirm).
extern "C" __global__ [aicore] void
vmi_anti_mx_f8_f16_scaled_4x128_kernel(__gm__ uint8_t *src,
                                       __gm__ uint8_t *mxscale,
                                       __gm__ half *dst);

// Host wrapper: launches the kernel with a block count of 1 on `stream`.
//   src     - device buffer with the packed f8 input bytes
//   mxscale - device buffer with the per-group scale bytes
//   dst     - device buffer for the f16 output, passed as raw 16-bit words
//   stream  - stream handle forwarded to the launch
// The call only enqueues the launch; it does not synchronize the stream here.
void LaunchVmi_anti_mx_f8_f16_scaled_4x128_kernel(uint8_t *src,
                                                  uint8_t *mxscale,
                                                  uint16_t *dst,
                                                  void *stream) {
  vmi_anti_mx_f8_f16_scaled_4x128_kernel<<<1, nullptr, stream>>>(
      (__gm__ uint8_t *)src, (__gm__ uint8_t *)mxscale, (__gm__ half *)dst);
}
b/test/vpto/cases/vmi/kernels/anti-mx-f8-f16-scaled-4x128/main.cpp new file mode 100644 index 0000000000..c357020a90 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8-f16-scaled-4x128/main.cpp @@ -0,0 +1,93 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_anti_mx_f8_f16_scaled_4x128_kernel(uint8_t *src, + uint8_t *mxscale, + uint16_t *dst, + void *stream); + +int main() { + constexpr size_t kRows = 4; + constexpr size_t kCols = 128; + constexpr size_t kElems = kRows * kCols; + constexpr size_t kMxScaleBytes = 32; + size_t srcBytes = kElems * sizeof(uint8_t); + size_t mxscaleBytes = kMxScaleBytes; + size_t dstBytes = kElems * sizeof(uint16_t); + uint8_t *srcHost = nullptr; + uint8_t *mxscaleHost = nullptr; + uint16_t *dstHost = nullptr; + uint8_t *srcDevice = nullptr; + uint8_t *mxscaleDevice = nullptr; + uint16_t *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + 
aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&mxscaleHost), mxscaleBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&mxscaleDevice, mxscaleBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", mxscaleBytes, mxscaleHost, mxscaleBytes); + ReadFile("./v3.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(mxscaleDevice, mxscaleBytes, mxscaleHost, mxscaleBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_anti_mx_f8_f16_scaled_4x128_kernel(srcDevice, mxscaleDevice, + dstDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(mxscaleDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(mxscaleHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8-f16-scaled-4x128/ptoas.flags b/test/vpto/cases/vmi/kernels/anti-mx-f8-f16-scaled-4x128/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ 
b/test/vpto/cases/vmi/kernels/anti-mx-f8-f16-scaled-4x128/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8-f32-scaled-4x128/compare.py b/test/vpto/cases/vmi/kernels/anti-mx-f8-f32-scaled-4x128/compare.py new file mode 100644 index 0000000000..cc852efbc2 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8-f32-scaled-4x128/compare.py @@ -0,0 +1,34 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
import sys

import numpy as np


def main() -> None:
    """Compare ``v3.bin`` with ``golden_v3.bin`` as float32, bit-exact.

    Both files are read from the current directory. On success an info line
    is printed. Otherwise the first differing index and its two values are
    reported (index -1 and "n/a" when the element counts differ) and the
    process exits with status 2.
    """
    expected = np.fromfile("golden_v3.bin", dtype=np.float32)
    actual = np.fromfile("v3.bin", dtype=np.float32)

    shapes_match = expected.shape == actual.shape
    if shapes_match and np.array_equal(expected, actual):
        print("[INFO] compare passed")
        return

    # Element-wise comparison is only meaningful for equally sized arrays.
    mismatches = np.nonzero(expected != actual)[0] if shapes_match else []
    if len(mismatches):
        idx = int(mismatches[0])
        expected_text = expected[idx]
        actual_text = actual[idx]
    else:
        idx = -1
        expected_text = "n/a"
        actual_text = "n/a"
    print(
        f"[ERROR] f32 compare failed idx={idx} "
        f"golden={expected_text} output={actual_text}"
    )
    sys.exit(2)


if __name__ == "__main__":
    main()
import argparse
from pathlib import Path

import numpy as np

ROWS = 4
COLS = 128
ELEMS = ROWS * COLS
MXSCALE_BYTES = 32
# VALUES[i] is the float value encoded by F8E4M3FN_BYTES[i].
VALUES = np.array([0.0, 1.0, -1.0, 0.5, 2.0, -2.0, 4.0, -4.0], dtype=np.float32)
F8E4M3FN_BYTES = np.array([0x00, 0x38, 0xB8, 0x30, 0x40, 0xC0, 0x48, 0xC8], dtype=np.uint8)
# Biased power-of-two exponents: a byte e stands for the scale 2**(e - 127).
E8M0_BYTES = np.array([0x7E, 0x7F, 0x80, 0x81], dtype=np.uint8)
# Initial content of the output buffer.
SENTINEL_F32 = np.float32(-777.0)


def generate(output_dir: Path) -> None:
    """Write the input and golden files of this case into ``output_dir``.

    Files written:
      * ``v1.bin``        - ELEMS source bytes, the F8E4M3FN pattern repeated
      * ``v2.bin``        - MXSCALE_BYTES scale bytes; the first ROWS * 4 hold
                            one scale per 32-column group, the rest 0x7F
      * ``v3.bin``        - ELEMS float32 sentinels (initial output buffer)
      * ``golden_v3.bin`` - expected float32 output: decoded value times the
                            scale of its group

    The directory is created if it does not exist.
    """
    group_width = 32
    groups_per_row = 4

    cycles = (ELEMS + len(VALUES) - 1) // len(VALUES)
    packed = np.tile(F8E4M3FN_BYTES, cycles)[:ELEMS].astype(np.uint8).reshape(ROWS, COLS)
    values = np.tile(VALUES, cycles)[:ELEMS].astype(np.float32).reshape(ROWS, COLS)

    # Same scale pattern on every row, stored row-major at the buffer start.
    scale_codes = np.tile(E8M0_BYTES, (ROWS, 1)).astype(np.uint8)
    scale_buffer = np.full(MXSCALE_BYTES, np.uint8(0x7F), dtype=np.uint8)
    scale_buffer[: ROWS * groups_per_row] = scale_codes.reshape(-1)

    factors = np.ldexp(
        np.ones_like(scale_codes, dtype=np.float32),
        scale_codes.astype(np.int32) - 127,
    )
    # Stretch each group factor over its 32 columns and apply it in one step.
    expected = values * np.repeat(factors, group_width, axis=1)

    initial_output = np.full(ELEMS, SENTINEL_F32, dtype=np.float32)
    golden = expected.reshape(-1).astype(np.float32)

    output_dir.mkdir(parents=True, exist_ok=True)
    packed.reshape(-1).tofile(output_dir / "v1.bin")
    scale_buffer.tofile(output_dir / "v2.bin")
    initial_output.tofile(output_dir / "v3.bin")
    golden.tofile(output_dir / "golden_v3.bin")


def main() -> None:
    """Command-line entry: ``--output-dir`` selects the target directory."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--output-dir", type=Path, default=Path("."))
    options = cli.parse_args()
    generate(options.output_dir)


if __name__ == "__main__":
    main()
// Copyright (c) 2026 Huawei Technologies Co., Ltd.
// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
// CANN Open Software License Agreement Version 2.0 (the "License").
// Please refer to the License for details. You may not use this file except in compliance with the License.
// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
// See LICENSE in the root of the software repository for the full text of the License.

// VMI test kernel: reads 512 f8E4M3FN elements plus per-group scale bytes,
// multiplies every 32-element group by 2^(scale - 127) and writes the result
// as f32.
//
// NOTE(review): `!pto.ptr`, `#pto.kernel_kind` and `#pto.pipe` carry no
// parameters in this copy; they look truncated. Confirm against the original.
module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} {
  func.func @vmi_anti_mx_f8_f32_scaled_4x128_kernel(
      %src_gm: !pto.ptr, %mxscale_gm: !pto.ptr,
      %dst_gm: !pto.ptr) attributes {pto.kernel} {
    // Index constants used for loop bounds and offsets.
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c32 = arith.constant 32 : index
    %c256 = arith.constant 256 : index
    %c512 = arith.constant 512 : index
    // 23 is the shift that places an 8-bit exponent in the f32 exponent field.
    %c23_i32 = arith.constant 23 : i32
    // i64 constants used as addresses, lengths and strides of the copies.
    %c0_i64 = arith.constant 0 : i64
    %c1_i64 = arith.constant 1 : i64
    %c2_i64 = arith.constant 2 : i64
    %c8_i64 = arith.constant 8 : i64
    %c32_i64 = arith.constant 32 : i64
    %c512_i64 = arith.constant 512 : i64
    %c1024_i64 = arith.constant 1024 : i64
    %c2048_i64 = arith.constant 2048 : i64

    // Local buffers at fixed addresses: source at 0 (two pointers to the same
    // address, one for the copy and one for the f8 vector load), scales at
    // 1024, destination at 2048.
    %ub_src_u8 = pto.castptr %c0_i64 : i64 -> !pto.ptr
    %ub_src_f8 = pto.castptr %c0_i64 : i64 -> !pto.ptr
    %ub_mxscale = pto.castptr %c1024_i64 : i64 -> !pto.ptr
    %ub_dst = pto.castptr %c2048_i64 : i64 -> !pto.ptr

    // Copy the 512 source bytes in a single burst.
    pto.mte_gm_ub %src_gm, %ub_src_u8, %c0_i64, %c512_i64
      nburst(%c1_i64, %c512_i64, %c512_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
    // Copy the scale bytes as 2 bursts of 8 bytes.
    // NOTE(review): presumably the nburst operands are (count, source stride,
    // destination stride), which would put the bursts 32 bytes apart locally
    // and match the %scale_off computation below - confirm.
    pto.mte_gm_ub %mxscale_gm, %ub_mxscale, %c0_i64, %c8_i64
      nburst(%c2_i64, %c8_i64, %c32_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64

    // Make the vector pipe wait for the inbound copies.
    pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]
    pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]

    pto.vecscope {
      // Two iterations (offset 0 and 256), 256 elements each. %dummy is only
      // threaded through unchanged.
      %_:1 = scf.for %offset = %c0 to %c512 step %c256 iter_args(%dummy = %c0) -> (index) {
        // Load 256 f8 values and widen them to f32.
        %packed = pto.vmi.load %ub_src_f8[%offset]
          : !pto.ptr -> !pto.vmi.vreg<256xf8E4M3FN>
        %wide = pto.vmi.extf %packed
          : !pto.vmi.vreg<256xf8E4M3FN> -> !pto.vmi.vreg<256xf32>
        // Scale offset for this tile: (offset / 256) * 32.
        %scale_tile = arith.divui %offset, %c256 : index
        %scale_off = arith.muli %scale_tile, %c32 : index
        // One scale byte for each of the 8 groups of this tile.
        // NOTE(review): the meaning of the %c1 operand is not visible here;
        // presumably a slot stride - confirm.
        %scale_u8 = pto.vmi.group_slot_load %ub_mxscale[%scale_off], %c1 {num_groups = 8}
          : !pto.ptr -> !pto.vmi.vreg<8xui8>
        // Turn each byte e into the f32 bit pattern e << 23, i.e. an exponent
        // field of e with zero sign and mantissa (2^(e - 127) for normal e).
        %scale_u32 = pto.vmi.extui %scale_u8
          : !pto.vmi.vreg<8xui8> -> !pto.vmi.vreg<8xui32>
        %scale_i32 = pto.vmi.bitcast %scale_u32
          : !pto.vmi.vreg<8xui32> -> !pto.vmi.vreg<8xi32>
        %shift = pto.vmi.broadcast %c23_i32
          : i32 -> !pto.vmi.vreg<8xi32>
        %scale_bits = pto.vmi.shli %scale_i32, %shift
          : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32>
          -> !pto.vmi.vreg<8xi32>
        %scale = pto.vmi.bitcast %scale_bits
          : !pto.vmi.vreg<8xi32> -> !pto.vmi.vreg<8xf32>
        // Expand the 8 group scales to 256 lanes and apply them.
        %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 8}
          : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<256xf32>
        %scaled = pto.vmi.mulf %wide, %scale_vec
          : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32>
          -> !pto.vmi.vreg<256xf32>
        // The f32 product is stored directly, without narrowing.
        pto.vmi.store %scaled, %ub_dst[%offset]
          : !pto.vmi.vreg<256xf32>, !pto.ptr
        scf.yield %dummy : index
      }
    }

    // Make the outbound copy wait for the vector pipe, then write back
    // 2048 bytes (512 f32 values) in a single burst.
    pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
    pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
    pto.mte_ub_gm %ub_dst, %dst_gm, %c2048_i64
      nburst(%c1_i64, %c2048_i64, %c2048_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64
    pto.barrier #pto.pipe
    return
  }
}
// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
// CANN Open Software License Agreement Version 2.0 (the "License").
// Please refer to the License for details. You may not use this file except in compliance with the License.
// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
// See LICENSE in the root of the software repository for the full text of the License.

// Ensure __VEC_SCOPE__ is defined (as empty) before the headers below.
#ifndef __VEC_SCOPE__
#define __VEC_SCOPE__
#endif
// Single-byte placeholder types for the 8-bit and packed 4-bit float names,
// provided only when compiling device code for NPU arch 2201.
#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201)
typedef struct { unsigned char v; } hifloat8_t;
typedef struct { unsigned char v; } float8_e4m3_t;
typedef struct { unsigned char v; } float8_e5m2_t;
typedef struct { unsigned char v; } float8_e8m0_t;
typedef struct { unsigned char v; } float4_e1m2x2_t;
typedef struct { unsigned char v; } float4_e2m1x2_t;
#endif
// NOTE(review): the header name of the following #include is missing in this
// copy of the file; restore it from the original before building.
#include
// Host-side fallback definition, used when TMRGSORT_HPP has not been defined.
#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP)
struct MrgSortExecutedNumList {
  uint16_t mrgSortList0;
  uint16_t mrgSortList1;
  uint16_t mrgSortList2;
  uint16_t mrgSortList3;
};
#endif
#ifndef __CPU_SIM
#include "acl/acl.h"
#endif

// Device kernel entry point; defined outside this file (presumably generated
// from the sibling kernel.pto - confirm).
extern "C" __global__ [aicore] void
vmi_anti_mx_f8_f32_scaled_4x128_kernel(__gm__ uint8_t *src,
                                       __gm__ uint8_t *mxscale,
                                       __gm__ float *dst);

// Host wrapper: launches the kernel with a block count of 1 on `stream`.
//   src     - device buffer with the packed f8 input bytes
//   mxscale - device buffer with the per-group scale bytes
//   dst     - device buffer for the f32 output
//   stream  - stream handle forwarded to the launch
// The call only enqueues the launch; it does not synchronize the stream here.
void LaunchVmi_anti_mx_f8_f32_scaled_4x128_kernel(uint8_t *src,
                                                  uint8_t *mxscale,
                                                  float *dst,
                                                  void *stream) {
  vmi_anti_mx_f8_f32_scaled_4x128_kernel<<<1, nullptr, stream>>>(
      (__gm__ uint8_t *)src, (__gm__ uint8_t *)mxscale, (__gm__ float *)dst);
}
b/test/vpto/cases/vmi/kernels/anti-mx-f8-f32-scaled-4x128/main.cpp @@ -0,0 +1,93 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_anti_mx_f8_f32_scaled_4x128_kernel(uint8_t *src, + uint8_t *mxscale, + float *dst, + void *stream); + +int main() { + constexpr size_t kRows = 4; + constexpr size_t kCols = 128; + constexpr size_t kElems = kRows * kCols; + constexpr size_t kMxScaleBytes = 32; + size_t srcBytes = kElems * sizeof(uint8_t); + size_t mxscaleBytes = kMxScaleBytes; + size_t dstBytes = kElems * sizeof(float); + uint8_t *srcHost = nullptr; + uint8_t *mxscaleHost = nullptr; + float *dstHost = nullptr; + uint8_t *srcDevice = nullptr; + uint8_t *mxscaleDevice = nullptr; + float *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + 
ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&mxscaleHost), mxscaleBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&mxscaleDevice, mxscaleBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", mxscaleBytes, mxscaleHost, mxscaleBytes); + ReadFile("./v3.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(mxscaleDevice, mxscaleBytes, mxscaleHost, mxscaleBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_anti_mx_f8_f32_scaled_4x128_kernel(srcDevice, mxscaleDevice, + dstDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(mxscaleDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(mxscaleHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8-f32-scaled-4x128/ptoas.flags b/test/vpto/cases/vmi/kernels/anti-mx-f8-f32-scaled-4x128/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8-f32-scaled-4x128/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git 
#!/usr/bin/env python3
# Copyright (c) 2026 Huawei Technologies Co., Ltd.
# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
# CANN Open Software License Agreement Version 2.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.

import sys

import numpy as np


def main() -> None:
    """Compare ``v3.bin`` with ``golden_v3.bin`` as raw 16-bit words.

    Both files are read from the current directory. On success an info line
    is printed. Otherwise the first differing index and its two words (as
    4-digit hex) are reported (index -1 and "n/a" when the element counts
    differ) and the process exits with status 2.
    """
    expected = np.fromfile("golden_v3.bin", dtype=np.uint16)
    actual = np.fromfile("v3.bin", dtype=np.uint16)

    shapes_match = expected.shape == actual.shape
    if shapes_match and np.array_equal(expected, actual):
        print("[INFO] compare passed")
        return

    # Element-wise comparison is only meaningful for equally sized arrays.
    mismatches = np.nonzero(expected != actual)[0] if shapes_match else []
    if len(mismatches):
        idx = int(mismatches[0])
        expected_text = f"0x{int(expected[idx]):04x}"
        actual_text = f"0x{int(actual[idx]):04x}"
    else:
        idx = -1
        expected_text = "n/a"
        actual_text = "n/a"
    print(
        f"[ERROR] bf16 compare failed idx={idx} "
        f"golden={expected_text} output={actual_text}"
    )
    sys.exit(2)


if __name__ == "__main__":
    main()
# Copyright (c) 2026 Huawei Technologies Co., Ltd.
# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
# CANN Open Software License Agreement Version 2.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.

import argparse
from pathlib import Path

import numpy as np

ROWS = 16
COLS = 512
ELEMS = ROWS * COLS
# One scale byte per group of 32 columns.
MXSCALE_BYTES = ROWS * (COLS // 32)
# VALUES[i] is the float value encoded by F8E5M2_BYTES[i].
VALUES = np.array(
    [0.0, 1.0, -1.0, 0.5, -0.5, 2.0, -2.0, 4.0, -4.0, 57344.0],
    dtype=np.float32,
)
F8E5M2_BYTES = np.array(
    [0x00, 0x3C, 0xBC, 0x38, 0xB8, 0x40, 0xC0, 0x44, 0xC4, 0x7B],
    dtype=np.uint8,
)
# Biased power-of-two exponents: a byte e stands for the scale 2**(e - 127).
E8M0_BYTES = np.array(
    [0x7E, 0x7F, 0x80, 0x81, 0x7D, 0x82, 0x7C, 0x83], dtype=np.uint8
)
# Initial content of the output buffer.
SENTINEL_BF16 = np.uint16(0x7FC0)


def f32_to_bf16_bits(values: np.ndarray) -> np.ndarray:
    """Return the bf16 bit patterns of ``values`` as ``uint16``.

    The float32 bits are rounded to the upper 16 bits with ties going to the
    even result. No special handling is applied to NaN inputs.
    """
    raw = values.astype(np.float32).view(np.uint32)
    # 0x7FFF plus the lowest kept bit implements round-half-to-even.
    bias = np.uint32(0x7FFF) + ((raw >> 16) & 1)
    return ((raw + bias) >> 16).astype(np.uint16)


def generate(output_dir: Path) -> None:
    """Write the input and golden files of this case into ``output_dir``.

    Files written:
      * ``v1.bin``        - ELEMS source bytes, the F8E5M2 pattern repeated
      * ``v2.bin``        - MXSCALE_BYTES scale bytes, one per 32-column
                            group, row-major
      * ``v3.bin``        - ELEMS uint16 sentinels (initial output buffer)
      * ``golden_v3.bin`` - expected output as bf16 bit patterns: decoded
                            value times the scale of its group

    The directory is created if it does not exist.
    """
    group_width = 32
    groups_per_row = COLS // group_width

    cycles = (ELEMS + len(VALUES) - 1) // len(VALUES)
    packed = np.tile(F8E5M2_BYTES, cycles)[:ELEMS].astype(np.uint8).reshape(ROWS, COLS)
    values = np.tile(VALUES, cycles)[:ELEMS].astype(np.float32).reshape(ROWS, COLS)

    # One row of scale codes, cycled to the group count and reused per row.
    code_cycles = (groups_per_row + len(E8M0_BYTES) - 1) // len(E8M0_BYTES)
    row_codes = np.tile(E8M0_BYTES, code_cycles)[:groups_per_row].astype(np.uint8)
    scale_codes = np.tile(row_codes, (ROWS, 1)).astype(np.uint8)
    scale_buffer = np.full(MXSCALE_BYTES, np.uint8(0x7F), dtype=np.uint8)
    scale_buffer[:] = scale_codes.reshape(-1)

    factors = np.ldexp(
        np.ones_like(scale_codes, dtype=np.float32),
        scale_codes.astype(np.int32) - 127,
    )
    # Stretch each group factor over its 32 columns and apply it in one step.
    expected = values * np.repeat(factors, group_width, axis=1)

    initial_output = np.full(ELEMS, SENTINEL_BF16, dtype=np.uint16)
    golden = f32_to_bf16_bits(expected.reshape(-1))

    output_dir.mkdir(parents=True, exist_ok=True)
    packed.reshape(-1).tofile(output_dir / "v1.bin")
    scale_buffer.tofile(output_dir / "v2.bin")
    initial_output.tofile(output_dir / "v3.bin")
    golden.tofile(output_dir / "golden_v3.bin")


def main() -> None:
    """Command-line entry: ``--output-dir`` selects the target directory."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--output-dir", type=Path, default=Path("."))
    options = cli.parse_args()
    generate(options.output_dir)


if __name__ == "__main__":
    main()
// VMI test kernel: reads a 16x512 block of f8E5M2 elements plus per-group
// scale bytes, multiplies every 32-element group by 2^(scale - 127) and
// writes the result as bf16.
//
// NOTE(review): `!pto.ptr`, `#pto.kernel_kind` and `#pto.pipe` carry no
// parameters in this copy; they look truncated. Confirm against the original.
module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} {
  func.func @vmi_anti_mx_f8e5m2_bf16_scaled_16x512_kernel(
      %src_gm: !pto.ptr, %mxscale_gm: !pto.ptr,
      %dst_gm: !pto.ptr) attributes {pto.kernel} {
    // Index constants used for loop bounds and offsets.
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c8 = arith.constant 8 : index
    %c16 = arith.constant 16 : index
    %c32 = arith.constant 32 : index
    %c64 = arith.constant 64 : index
    %c256 = arith.constant 256 : index
    %c512 = arith.constant 512 : index
    // 23 is the shift that places an 8-bit exponent in the f32 exponent field.
    %c23_i32 = arith.constant 23 : i32
    // i64 constants used as addresses, lengths and strides of the copies.
    %c0_i64 = arith.constant 0 : i64
    %c1_i64 = arith.constant 1 : i64
    %c2_i64 = arith.constant 2 : i64
    %c8_i64 = arith.constant 8 : i64
    %c32_i64 = arith.constant 32 : i64
    %c256_i64 = arith.constant 256 : i64
    %c512_i64 = arith.constant 512 : i64
    %c1024_i64 = arith.constant 1024 : i64
    %c8192_i64 = arith.constant 8192 : i64
    %c16384_i64 = arith.constant 16384 : i64
    %c32768_i64 = arith.constant 32768 : i64

    // Local buffers at fixed addresses: source at 0 (two pointers to the same
    // address, one for the copy and one for the f8 vector load), scales at
    // 16384, destination at 32768.
    %ub_src_u8 = pto.castptr %c0_i64 : i64 -> !pto.ptr
    %ub_src_f8 = pto.castptr %c0_i64 : i64 -> !pto.ptr
    %ub_mxscale = pto.castptr %c16384_i64 : i64 -> !pto.ptr
    %ub_dst = pto.castptr %c32768_i64 : i64 -> !pto.ptr

    // Copy the 8192 source bytes in a single burst.
    pto.mte_gm_ub %src_gm, %ub_src_u8, %c0_i64, %c8192_i64
      nburst(%c1_i64, %c8192_i64, %c8192_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
    // Copy the scale bytes as 32 bursts of 8 bytes (two per row).
    // NOTE(review): presumably the nburst operands are (count, source stride,
    // destination stride), which would put the bursts 32 bytes apart locally
    // and match the %scale_off computation below - confirm.
    pto.mte_gm_ub %mxscale_gm, %ub_mxscale, %c0_i64, %c8_i64
      nburst(%c32_i64, %c8_i64, %c32_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64

    // Make the vector pipe wait for the inbound copies.
    pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]
    pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]

    pto.vecscope {
      // Outer loop: the 16 rows.
      scf.for %row = %c0 to %c16 step %c1 {
        // Element offset of the row (512 per row) and scale offset of the
        // row (64 per row, i.e. two 32-byte slots).
        %row_elem_off = arith.muli %row, %c512 : index
        %row_scale_off = arith.muli %row, %c64 : index
        // Inner loop: the two 256-element tiles of the row.
        scf.for %col = %c0 to %c512 step %c256 {
          %offset = arith.addi %row_elem_off, %col : index
          // Load 256 f8 values and widen them to f32.
          %packed = pto.vmi.load %ub_src_f8[%offset]
            : !pto.ptr -> !pto.vmi.vreg<256xf8E5M2>
          %wide = pto.vmi.extf %packed
            : !pto.vmi.vreg<256xf8E5M2> -> !pto.vmi.vreg<256xf32>
          // Scale offset for this tile: row * 64 + (col / 256) * 32.
          %scale_tile = arith.divui %col, %c256 : index
          %scale_tile_off = arith.muli %scale_tile, %c32 : index
          %scale_off = arith.addi %row_scale_off, %scale_tile_off : index
          // One scale byte for each of the 8 groups of this tile.
          // NOTE(review): the meaning of the %c1 operand is not visible here;
          // presumably a slot stride - confirm.
          %scale_u8 = pto.vmi.group_slot_load %ub_mxscale[%scale_off], %c1 {num_groups = 8}
            : !pto.ptr -> !pto.vmi.vreg<8xui8>
          // Turn each byte e into the f32 bit pattern e << 23, i.e. an
          // exponent field of e with zero sign and mantissa (2^(e - 127) for
          // normal e).
          %scale_u32 = pto.vmi.extui %scale_u8
            : !pto.vmi.vreg<8xui8> -> !pto.vmi.vreg<8xui32>
          %scale_i32 = pto.vmi.bitcast %scale_u32
            : !pto.vmi.vreg<8xui32> -> !pto.vmi.vreg<8xi32>
          %shift = pto.vmi.broadcast %c23_i32
            : i32 -> !pto.vmi.vreg<8xi32>
          %scale_bits = pto.vmi.shli %scale_i32, %shift
            : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32>
            -> !pto.vmi.vreg<8xi32>
          %scale = pto.vmi.bitcast %scale_bits
            : !pto.vmi.vreg<8xi32> -> !pto.vmi.vreg<8xf32>
          // Expand the 8 group scales to 256 lanes and apply them.
          %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 8}
            : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<256xf32>
          %scaled = pto.vmi.mulf %wide, %scale_vec
            : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32>
            -> !pto.vmi.vreg<256xf32>
          // Narrow to bf16 and store at the same element offset.
          %out = pto.vmi.truncf %scaled
            : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xbf16>
          pto.vmi.store %out, %ub_dst[%offset]
            : !pto.vmi.vreg<256xbf16>, !pto.ptr
        }
      }
    }

    // Make the outbound copy wait for the vector pipe, then write back
    // 16384 bytes (8192 bf16 values) in a single burst.
    pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
    pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
    pto.mte_ub_gm %ub_dst, %dst_gm, %c16384_i64
      nburst(%c1_i64, %c16384_i64, %c16384_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64
    pto.barrier #pto.pipe
    return
  }
}
// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
// CANN Open Software License Agreement Version 2.0 (the "License").
// Please refer to the License for details. You may not use this file except in compliance with the License.
// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
// See LICENSE in the root of the software repository for the full text of the License.

// Ensure __VEC_SCOPE__ is defined (as empty) before the headers below.
#ifndef __VEC_SCOPE__
#define __VEC_SCOPE__
#endif
// Single-byte placeholder types for the 8-bit and packed 4-bit float names,
// provided only when compiling device code for NPU arch 2201.
#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201)
typedef struct { unsigned char v; } hifloat8_t;
typedef struct { unsigned char v; } float8_e4m3_t;
typedef struct { unsigned char v; } float8_e5m2_t;
typedef struct { unsigned char v; } float8_e8m0_t;
typedef struct { unsigned char v; } float4_e1m2x2_t;
typedef struct { unsigned char v; } float4_e2m1x2_t;
#endif
// NOTE(review): the header name of the following #include is missing in this
// copy of the file; restore it from the original before building.
#include
// Host-side fallback definition, used when TMRGSORT_HPP has not been defined.
#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP)
struct MrgSortExecutedNumList {
  uint16_t mrgSortList0;
  uint16_t mrgSortList1;
  uint16_t mrgSortList2;
  uint16_t mrgSortList3;
};
#endif
#ifndef __CPU_SIM
#include "acl/acl.h"
#endif

// Device kernel entry point; defined outside this file (presumably generated
// from the sibling kernel.pto - confirm).
extern "C" __global__ [aicore] void
vmi_anti_mx_f8e5m2_bf16_scaled_16x512_kernel(__gm__ uint8_t *src,
                                             __gm__ uint8_t *mxscale,
                                             __gm__ bfloat16_t *dst);

// Host wrapper: launches the kernel with a block count of 1 on `stream`.
//   src     - device buffer with the packed f8 input bytes
//   mxscale - device buffer with the per-group scale bytes
//   dst     - device buffer for the bf16 output, passed as raw 16-bit words
//   stream  - stream handle forwarded to the launch
// The call only enqueues the launch; it does not synchronize the stream here.
void LaunchVmi_anti_mx_f8e5m2_bf16_scaled_16x512_kernel(uint8_t *src,
                                                        uint8_t *mxscale,
                                                        uint16_t *dst,
                                                        void *stream) {
  vmi_anti_mx_f8e5m2_bf16_scaled_16x512_kernel<<<1, nullptr, stream>>>(
      (__gm__ uint8_t *)src, (__gm__ uint8_t *)mxscale,
      (__gm__ bfloat16_t *)dst);
}
// Evaluates an ACL call once; on any result other than ACL_SUCCESS it logs the
// failing expression with file/line to stderr, sets the enclosing function's
// `rc` to 1 and jumps to its `cleanup` label. It can therefore only be used
// inside a function that declares both `rc` and `cleanup:`.
#define ACL_CHECK(expr)                                                        \
  do {                                                                         \
    const aclError _ret = (expr);                                              \
    if (_ret != ACL_SUCCESS) {                                                 \
      std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr,           \
                   (int)_ret, __FILE__, __LINE__);                             \
      rc = 1;                                                                  \
      goto cleanup;                                                            \
    }                                                                          \
  } while (0)

// Host launcher for the 16x512 kernel; defined in another translation unit.
void LaunchVmi_anti_mx_f8e5m2_bf16_scaled_16x512_kernel(uint8_t *src,
                                                        uint8_t *mxscale,
                                                        uint16_t *dst,
                                                        void *stream);

/// Test driver for the 16x512 anti-MX case.
///
/// Reads ./v1.bin (source bytes), ./v2.bin (scale bytes) and ./v3.bin
/// (initial destination contents), copies them to the device, runs the kernel,
/// and overwrites ./v3.bin with the device result.
///
/// @return 0 when every ACL_CHECK-wrapped call succeeded, 1 otherwise.
int main() {
  constexpr size_t kRows = 16;
  constexpr size_t kCols = 512;
  constexpr size_t kElems = kRows * kCols;
  // One scale byte per 32 columns of every row.
  constexpr size_t kMxScaleBytes = kRows * (kCols / 32);
  // Sizes are non-const locals because they are also passed as the second
  // argument of ReadFile — presumably an in/out size parameter; TODO confirm
  // against test_common.h.
  size_t srcBytes = kElems * sizeof(uint8_t);
  size_t mxscaleBytes = kMxScaleBytes;
  size_t dstBytes = kElems * sizeof(uint16_t);
  // Every pointer starts as nullptr, so the cleanup path sees well-defined
  // values even when it is reached before the allocations.
  uint8_t *srcHost = nullptr;
  uint8_t *mxscaleHost = nullptr;
  uint16_t *dstHost = nullptr;
  uint8_t *srcDevice = nullptr;
  uint8_t *mxscaleDevice = nullptr;
  uint16_t *dstDevice = nullptr;
  int rc = 0;
  // Track how far initialisation got so cleanup only undoes what succeeded.
  bool aclInited = false;
  bool deviceSet = false;
  int deviceId = 0;
  aclrtStream stream = nullptr;

  ACL_CHECK(aclInit(nullptr));
  aclInited = true;
  // Optional device override through the ACL_DEVICE_ID environment variable.
  if (const char *envDevice = std::getenv("ACL_DEVICE_ID"))
    deviceId = std::atoi(envDevice);
  ACL_CHECK(aclrtSetDevice(deviceId));
  deviceSet = true;
  ACL_CHECK(aclrtCreateStream(&stream));
  ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes));
  ACL_CHECK(aclrtMallocHost((void **)(&mxscaleHost), mxscaleBytes));
  ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes));
  ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST));
  ACL_CHECK(aclrtMalloc((void **)&mxscaleDevice, mxscaleBytes, ACL_MEM_MALLOC_HUGE_FIRST));
  ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST));

  // NOTE(review): the results of ReadFile/WriteFile are not inspected here, so
  // a missing input file is not reported by this driver — confirm intended.
  ReadFile("./v1.bin", srcBytes, srcHost, srcBytes);
  ReadFile("./v2.bin", mxscaleBytes, mxscaleHost, mxscaleBytes);
  ReadFile("./v3.bin", dstBytes, dstHost, dstBytes);
  ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE));
  ACL_CHECK(aclrtMemcpy(mxscaleDevice, mxscaleBytes, mxscaleHost, mxscaleBytes, ACL_MEMCPY_HOST_TO_DEVICE));
  // The destination is uploaded too, so elements the kernel does not write
  // keep the values read from ./v3.bin.
  ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE));
  LaunchVmi_anti_mx_f8e5m2_bf16_scaled_16x512_kernel(srcDevice, mxscaleDevice,
                                                     dstDevice, stream);
  // Wait for the stream before copying the result back to the host.
  ACL_CHECK(aclrtSynchronizeStream(stream));
  ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST));
  WriteFile("./v3.bin", dstHost, dstBytes);

cleanup:
  // Reached on success and on every ACL_CHECK failure. The free calls run
  // unconditionally, including on pointers that are still nullptr.
  aclrtFree(srcDevice);
  aclrtFree(mxscaleDevice);
  aclrtFree(dstDevice);
  aclrtFreeHost(srcHost);
  aclrtFreeHost(mxscaleHost);
  aclrtFreeHost(dstHost);
  if (stream)
    aclrtDestroyStream(stream);
  if (deviceSet)
    aclrtResetDevice(deviceId);
  if (aclInited)
    aclFinalize();
  return rc;
}
def main() -> None:
    """Compare ``v3.bin`` against ``golden_v3.bin`` as raw bf16 bit patterns.

    Both files are read from the current directory as ``uint16`` arrays. On an
    exact match an info line is printed. Otherwise the first differing index
    (or ``-1`` when the element counts differ) is reported and the process
    exits with status 2.
    """
    expected = np.fromfile("golden_v3.bin", dtype=np.uint16)
    actual = np.fromfile("v3.bin", dtype=np.uint16)

    same_shape = expected.shape == actual.shape
    if same_shape and np.array_equal(expected, actual):
        print("[INFO] compare passed")
        return

    # Element-wise comparison is only meaningful for equally sized arrays.
    mismatches = np.nonzero(expected != actual)[0] if same_shape else []
    idx = int(mismatches[0]) if len(mismatches) else -1
    if idx >= 0:
        golden_value = f"0x{int(expected[idx]):04x}"
        out_value = f"0x{int(actual[idx]):04x}"
    else:
        golden_value = "n/a"
        out_value = "n/a"
    print(
        f"[ERROR] bf16 compare failed idx={idx} "
        f"golden={golden_value} output={out_value}"
    )
    sys.exit(2)
ROWS = 4
COLS = 128
ELEMS = ROWS * COLS
# Number of consecutive columns that share one E8M0 scale byte.
GROUP_SIZE = 32
GROUPS_PER_ROW = COLS // GROUP_SIZE
# Size of the scale file; only the first ROWS * GROUPS_PER_ROW bytes carry
# per-group scales, the remainder is padding.
MXSCALE_BYTES = 32
# Decoded float values and their f8E5M2 encodings, index-aligned.
VALUES = np.array(
    [0.0, 1.0, -1.0, 0.5, -0.5, 2.0, -2.0, 4.0, -4.0, 57344.0],
    dtype=np.float32,
)
F8E5M2_BYTES = np.array(
    [0x00, 0x3C, 0xBC, 0x38, 0xB8, 0x40, 0xC0, 0x44, 0xC4, 0x7B],
    dtype=np.uint8,
)
# One biased exponent per column group of a row (bias 127, see generate()).
E8M0_BYTES = np.array([0x7E, 0x7F, 0x80, 0x81], dtype=np.uint8)
# Value pre-filled into the destination buffer before the kernel runs.
SENTINEL_BF16 = np.uint16(0x7FC0)


def f32_to_bf16_bits(values: np.ndarray) -> np.ndarray:
    """Convert float32 values to bf16 bit patterns (round-to-nearest-even).

    Args:
        values: Array-like of values; converted to ``float32`` first.

    Returns:
        ``uint16`` array of the same shape holding the bf16 encodings.

    NaN inputs are mapped to a quiet NaN with the same sign. Applying the
    rounding increment to a NaN bit pattern could otherwise carry out of the
    mantissa (yielding +/-0) or drop a low-bits-only payload (yielding Inf).
    """
    f32 = np.asarray(values).astype(np.float32)
    bits = f32.view(np.uint32)
    lsb = (bits >> 16) & 1
    rounded = bits + np.uint32(0x7FFF) + lsb
    nearest = (rounded >> 16).astype(np.uint16)
    # Keep sign/exponent and force the top mantissa bit so the result is NaN.
    quiet_nan = ((bits >> 16) | np.uint32(0x0040)).astype(np.uint16)
    return np.where(np.isnan(f32), quiet_nan, nearest)


def generate(output_dir: Path) -> None:
    """Write the input and golden files for the 4x128 anti-MX case.

    Files written into ``output_dir`` (created if missing):
        v1.bin         f8E5M2 source bytes, ROWS x COLS.
        v2.bin         E8M0 scale bytes, padded to MXSCALE_BYTES with 0x7F.
        v3.bin         destination pre-filled with SENTINEL_BF16.
        golden_v3.bin  expected bf16 bit patterns of value * 2**(scale - 127).
    """
    repeats = (ELEMS + len(VALUES) - 1) // len(VALUES)
    src = np.tile(F8E5M2_BYTES, repeats)[:ELEMS].astype(np.uint8).reshape(ROWS, COLS)
    decoded = np.tile(VALUES, repeats)[:ELEMS].astype(np.float32).reshape(ROWS, COLS)

    mxscale = np.full(MXSCALE_BYTES, np.uint8(0x7F), dtype=np.uint8)
    mxscale_matrix = np.tile(E8M0_BYTES, (ROWS, 1)).astype(np.uint8)
    mxscale[: mxscale_matrix.size] = mxscale_matrix.reshape(-1)

    # E8M0 stores a biased power-of-two exponent: scale = 2 ** (byte - 127).
    scale_values = np.ldexp(
        np.ones_like(mxscale_matrix, dtype=np.float32),
        mxscale_matrix.astype(np.int32) - 127,
    )
    # Expand each per-group scale across its GROUP_SIZE columns, then apply it.
    scaled = decoded * np.repeat(scale_values, GROUP_SIZE, axis=1)

    dst = np.full(ELEMS, SENTINEL_BF16, dtype=np.uint16)
    golden = f32_to_bf16_bits(scaled.reshape(-1))

    output_dir.mkdir(parents=True, exist_ok=True)
    src.reshape(-1).tofile(output_dir / "v1.bin")
    mxscale.tofile(output_dir / "v2.bin")
    dst.tofile(output_dir / "v3.bin")
    golden.tofile(output_dir / "golden_v3.bin")


def main() -> None:
    """Command-line entry point: ``--output-dir`` selects where files go."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--output-dir", type=Path, default=Path("."))
    args = parser.parse_args()
    generate(args.output_dir)
// NOTE(review): `!pto.ptr` and `#pto.kernel_kind` / `#pto.pipe` appear below
// without any `<...>` parameters — confirm against the checked-in file.
module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} {
  // Anti-MX dequantisation over 512 elements: loads f8E5M2 data and one scale
  // byte per group of 32 elements, multiplies each element by
  // bitcast<f32>(scale_byte << 23), and stores the result as bf16.
  func.func @vmi_anti_mx_f8e5m2_bf16_scaled_4x128_kernel(
      %src_gm: !pto.ptr, %mxscale_gm: !pto.ptr,
      %dst_gm: !pto.ptr) attributes {pto.kernel} {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c32 = arith.constant 32 : index
    %c256 = arith.constant 256 : index
    %c512 = arith.constant 512 : index
    %c23_i32 = arith.constant 23 : i32
    %c0_i64 = arith.constant 0 : i64
    %c1_i64 = arith.constant 1 : i64
    %c2_i64 = arith.constant 2 : i64
    %c8_i64 = arith.constant 8 : i64
    %c32_i64 = arith.constant 32 : i64
    %c512_i64 = arith.constant 512 : i64
    %c1024_i64 = arith.constant 1024 : i64
    %c2048_i64 = arith.constant 2048 : i64

    // Local buffer layout by address: source at 0 (two pointer views of the
    // same address), scale bytes at 1024, bf16 output at 2048.
    %ub_src_u8 = pto.castptr %c0_i64 : i64 -> !pto.ptr
    %ub_src_f8 = pto.castptr %c0_i64 : i64 -> !pto.ptr
    %ub_mxscale = pto.castptr %c1024_i64 : i64 -> !pto.ptr
    %ub_dst = pto.castptr %c2048_i64 : i64 -> !pto.ptr

    // Copy the source bytes in with length operand 512 (4 x 128 one-byte
    // elements).
    pto.mte_gm_ub %src_gm, %ub_src_u8, %c0_i64, %c512_i64
      nburst(%c1_i64, %c512_i64, %c512_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
    // Scale bytes use nburst(2, 8, 32) with length 8; presumably two 8-byte
    // bursts placed 32 bytes apart locally, which would match the
    // `tile * 32` scale offset used in the loop — TODO confirm operand order.
    pto.mte_gm_ub %mxscale_gm, %ub_mxscale, %c0_i64, %c8_i64
      nburst(%c2_i64, %c8_i64, %c32_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64

    // Make the MTE2 copies visible to the vector pipe before computing.
    pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]
    pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]

    pto.vecscope {
      // Two iterations of 256 elements each. %dummy is yielded back unchanged
      // on every iteration and the loop result is discarded.
      %_:1 = scf.for %offset = %c0 to %c512 step %c256 iter_args(%dummy = %c0) -> (index) {
        %packed = pto.vmi.load %ub_src_f8[%offset]
          : !pto.ptr -> !pto.vmi.vreg<256xf8E5M2>
        %wide = pto.vmi.extf %packed
          : !pto.vmi.vreg<256xf8E5M2> -> !pto.vmi.vreg<256xf32>
        // Scale offset for this tile: (offset / 256) * 32.
        %scale_tile = arith.divui %offset, %c256 : index
        %scale_off = arith.muli %scale_tile, %c32 : index
        %scale_u8 = pto.vmi.group_slot_load %ub_mxscale[%scale_off], %c1 {num_groups = 8}
          : !pto.ptr -> !pto.vmi.vreg<8xui8>
        // Zero-extend the scale byte and shift it left by 23 so it occupies
        // the f32 exponent field; the bitcast then yields 2^(byte - 127).
        %scale_u32 = pto.vmi.extui %scale_u8
          : !pto.vmi.vreg<8xui8> -> !pto.vmi.vreg<8xui32>
        %scale_i32 = pto.vmi.bitcast %scale_u32
          : !pto.vmi.vreg<8xui32> -> !pto.vmi.vreg<8xi32>
        %shift = pto.vmi.broadcast %c23_i32
          : i32 -> !pto.vmi.vreg<8xi32>
        %scale_bits = pto.vmi.shli %scale_i32, %shift
          : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32>
          -> !pto.vmi.vreg<8xi32>
        %scale = pto.vmi.bitcast %scale_bits
          : !pto.vmi.vreg<8xi32> -> !pto.vmi.vreg<8xf32>
        // Expand the 8 per-group scales to 256 lanes (8 groups).
        %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 8}
          : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<256xf32>
        %scaled = pto.vmi.mulf %wide, %scale_vec
          : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32>
          -> !pto.vmi.vreg<256xf32>
        %out = pto.vmi.truncf %scaled
          : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xbf16>
        pto.vmi.store %out, %ub_dst[%offset]
          : !pto.vmi.vreg<256xbf16>, !pto.ptr
        scf.yield %dummy : index
      }
    }

    // Order the vector stores before the MTE3 copy-out, then copy the result
    // out with length operand 1024 (512 two-byte elements).
    pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
    pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
    pto.mte_ub_gm %ub_dst, %dst_gm, %c1024_i64
      nburst(%c1_i64, %c1024_i64, %c1024_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64
    pto.barrier #pto.pipe
    return
  }
}
// Device-side entry point; its definition is not in this translation unit.
// It takes the packed source bytes, the per-group scale bytes and the bf16
// destination, all as __gm__-qualified pointers.
extern "C" __global__ [aicore] void
vmi_anti_mx_f8e5m2_bf16_scaled_4x128_kernel(__gm__ uint8_t *src,
                                            __gm__ uint8_t *mxscale,
                                            __gm__ bfloat16_t *dst);

/// Host-side wrapper that enqueues the 4x128 anti-MX kernel on `stream`.
///
/// @param src     Source buffer, forwarded to the kernel as `__gm__ uint8_t *`.
/// @param mxscale Scale buffer, forwarded to the kernel as `__gm__ uint8_t *`.
/// @param dst     Destination buffer; typed `uint16_t *` on the host side and
///                reinterpreted as `__gm__ bfloat16_t *` for the kernel.
/// @param stream  Opaque stream handle passed as the third launch parameter.
///
/// The function only issues the launch; it performs no synchronisation itself.
void LaunchVmi_anti_mx_f8e5m2_bf16_scaled_4x128_kernel(uint8_t *src,
                                                       uint8_t *mxscale,
                                                       uint16_t *dst,
                                                       void *stream) {
  // Launch configuration is <<<1, nullptr, stream>>>; presumably a single
  // block with no extra launch argument — TODO confirm against the toolchain.
  vmi_anti_mx_f8e5m2_bf16_scaled_4x128_kernel<<<1, nullptr, stream>>>(
      (__gm__ uint8_t *)src, (__gm__ uint8_t *)mxscale,
      (__gm__ bfloat16_t *)dst);
}
// Evaluates an ACL call once; on any result other than ACL_SUCCESS it logs the
// failing expression with file/line to stderr, sets the enclosing function's
// `rc` to 1 and jumps to its `cleanup` label. It can therefore only be used
// inside a function that declares both `rc` and `cleanup:`.
#define ACL_CHECK(expr)                                                        \
  do {                                                                         \
    const aclError _ret = (expr);                                              \
    if (_ret != ACL_SUCCESS) {                                                 \
      std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr,           \
                   (int)_ret, __FILE__, __LINE__);                             \
      rc = 1;                                                                  \
      goto cleanup;                                                            \
    }                                                                          \
  } while (0)

// Host launcher for the 4x128 kernel; defined in another translation unit.
void LaunchVmi_anti_mx_f8e5m2_bf16_scaled_4x128_kernel(uint8_t *src,
                                                       uint8_t *mxscale,
                                                       uint16_t *dst,
                                                       void *stream);

/// Test driver for the 4x128 anti-MX case.
///
/// Reads ./v1.bin (source bytes), ./v2.bin (scale bytes) and ./v3.bin
/// (initial destination contents), copies them to the device, runs the kernel,
/// and overwrites ./v3.bin with the device result.
///
/// @return 0 when every ACL_CHECK-wrapped call succeeded, 1 otherwise.
int main() {
  constexpr size_t kRows = 4;
  constexpr size_t kCols = 128;
  constexpr size_t kElems = kRows * kCols;
  // Fixed size of the scale file for this case.
  constexpr size_t kMxScaleBytes = 32;
  // Sizes are non-const locals because they are also passed as the second
  // argument of ReadFile — presumably an in/out size parameter; TODO confirm
  // against test_common.h.
  size_t srcBytes = kElems * sizeof(uint8_t);
  size_t mxscaleBytes = kMxScaleBytes;
  size_t dstBytes = kElems * sizeof(uint16_t);
  // Every pointer starts as nullptr, so the cleanup path sees well-defined
  // values even when it is reached before the allocations.
  uint8_t *srcHost = nullptr;
  uint8_t *mxscaleHost = nullptr;
  uint16_t *dstHost = nullptr;
  uint8_t *srcDevice = nullptr;
  uint8_t *mxscaleDevice = nullptr;
  uint16_t *dstDevice = nullptr;
  int rc = 0;
  // Track how far initialisation got so cleanup only undoes what succeeded.
  bool aclInited = false;
  bool deviceSet = false;
  int deviceId = 0;
  aclrtStream stream = nullptr;

  ACL_CHECK(aclInit(nullptr));
  aclInited = true;
  // Optional device override through the ACL_DEVICE_ID environment variable.
  if (const char *envDevice = std::getenv("ACL_DEVICE_ID"))
    deviceId = std::atoi(envDevice);
  ACL_CHECK(aclrtSetDevice(deviceId));
  deviceSet = true;
  ACL_CHECK(aclrtCreateStream(&stream));
  ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes));
  ACL_CHECK(aclrtMallocHost((void **)(&mxscaleHost), mxscaleBytes));
  ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes));
  ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST));
  ACL_CHECK(aclrtMalloc((void **)&mxscaleDevice, mxscaleBytes, ACL_MEM_MALLOC_HUGE_FIRST));
  ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST));

  // NOTE(review): the results of ReadFile/WriteFile are not inspected here, so
  // a missing input file is not reported by this driver — confirm intended.
  ReadFile("./v1.bin", srcBytes, srcHost, srcBytes);
  ReadFile("./v2.bin", mxscaleBytes, mxscaleHost, mxscaleBytes);
  ReadFile("./v3.bin", dstBytes, dstHost, dstBytes);
  ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE));
  ACL_CHECK(aclrtMemcpy(mxscaleDevice, mxscaleBytes, mxscaleHost, mxscaleBytes, ACL_MEMCPY_HOST_TO_DEVICE));
  // The destination is uploaded too, so elements the kernel does not write
  // keep the values read from ./v3.bin.
  ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE));
  LaunchVmi_anti_mx_f8e5m2_bf16_scaled_4x128_kernel(srcDevice, mxscaleDevice,
                                                    dstDevice, stream);
  // Wait for the stream before copying the result back to the host.
  ACL_CHECK(aclrtSynchronizeStream(stream));
  ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST));
  WriteFile("./v3.bin", dstHost, dstBytes);

cleanup:
  // Reached on success and on every ACL_CHECK failure. The free calls run
  // unconditionally, including on pointers that are still nullptr.
  aclrtFree(srcDevice);
  aclrtFree(mxscaleDevice);
  aclrtFree(dstDevice);
  aclrtFreeHost(srcHost);
  aclrtFreeHost(mxscaleHost);
  aclrtFreeHost(dstHost);
  if (stream)
    aclrtDestroyStream(stream);
  if (deviceSet)
    aclrtResetDevice(deviceId);
  if (aclInited)
    aclFinalize();
  return rc;
}
def check_u8(name: str) -> bool:
    """Byte-compare ``{name}.bin`` with ``golden_{name}.bin``.

    Both files are read from the current directory as ``uint8`` arrays.

    Args:
        name: Base name of the buffer, e.g. ``"v2"``.

    Returns:
        ``True`` when the files hold identical bytes. ``False`` otherwise,
        after printing either the size mismatch or the first differing index
        together with both byte values.
    """
    golden = np.fromfile(f"golden_{name}.bin", dtype=np.uint8)
    output = np.fromfile(f"{name}.bin", dtype=np.uint8)

    # Differently sized buffers cannot be compared element-wise: `!=` would
    # raise or broadcast, and there may be no valid index to report.
    if golden.shape != output.shape:
        print(
            f"[ERROR] compare failed {name} size mismatch "
            f"golden={golden.size} output={output.size}"
        )
        return False

    diff = np.nonzero(golden != output)[0]
    if diff.size == 0:
        return True

    idx = int(diff[0])
    print(
        f"[ERROR] compare failed {name} idx={idx} "
        f"golden=0x{int(golden[idx]):02x} output=0x{int(output[idx]):02x}"
    )
    return False


def main() -> None:
    """Check v2, v3 and v4 in order; exit with status 2 on the first failure."""
    if not all(check_u8(name) for name in ("v2", "v3", "v4")):
        sys.exit(2)
    print("[INFO] compare passed")
ROWS = 4
COLS = 128
# Sizes of the two scale output buffers, in bytes.
SCALE1_BYTES = 16
SCALE2_BYTES = 256
# Value pre-filled into every output buffer before the kernel runs.
SENTINEL_U8 = np.uint8(0xA5)

# Expected quantised values and their f8E4M3FN encodings, index-aligned.
Q_VALUES = np.array(
    [0.0, 1.0, -1.0, 0.5, 2.0, -2.0, 4.0, -4.0, 448.0], dtype=np.float32
)
F8E4M3FN_BYTES = np.array(
    [0x00, 0x38, 0xB8, 0x30, 0x40, 0xC0, 0x48, 0xC8, 0x7E], dtype=np.uint8
)


def f32_to_bf16_bits(values: np.ndarray) -> np.ndarray:
    """Convert float32 values to bf16 bit patterns (round-to-nearest-even).

    Args:
        values: Array-like of values; converted to ``float32`` first.

    Returns:
        ``uint16`` array of the same shape holding the bf16 encodings.

    NaN inputs are mapped to a quiet NaN with the same sign. Applying the
    rounding increment to a NaN bit pattern could otherwise carry out of the
    mantissa (yielding +/-0) or drop a low-bits-only payload (yielding Inf).
    """
    f32 = np.asarray(values).astype(np.float32)
    bits = f32.view(np.uint32)
    lsb = (bits >> 16) & 1
    rounded = bits + np.uint32(0x7FFF) + lsb
    nearest = (rounded >> 16).astype(np.uint16)
    # Keep sign/exponent and force the top mantissa bit so the result is NaN.
    quiet_nan = ((bits >> 16) | np.uint32(0x0040)).astype(np.uint16)
    return np.where(np.isnan(f32), quiet_nan, nearest)


def generate(output_dir: Path) -> None:
    """Write the input and golden files for the 4x128 block-MX quant case.

    Files written into ``output_dir`` (created if missing):
        v1.bin          bf16 source: the Q_VALUES pattern divided by 256,
                        identical in every row.
        v2.bin/v3.bin/v4.bin
                        output, scale1 and scale2 buffers pre-filled with
                        SENTINEL_U8.
        golden_v2.bin   expected f8E4M3FN bytes (the F8E4M3FN_BYTES pattern).
        golden_v3.bin   expected scale1 bytes, all 0x77.
        golden_v4.bin   expected scale2 bytes: 0x77 at even offsets, else 0.
    """
    repeats = (COLS + len(Q_VALUES) - 1) // len(Q_VALUES)
    q_row = np.tile(Q_VALUES, repeats)[:COLS].astype(np.float32)
    f8_row = np.tile(F8E4M3FN_BYTES, repeats)[:COLS].astype(np.uint8)

    # The source is the expected quantised value scaled down by 256.
    src = np.tile(f32_to_bf16_bits(q_row / np.float32(256.0)), (ROWS, 1))
    golden_out = np.tile(f8_row, (ROWS, 1)).astype(np.uint8)
    golden_scale1 = np.full(SCALE1_BYTES, np.uint8(0x77), dtype=np.uint8)
    golden_scale2 = np.zeros(SCALE2_BYTES, dtype=np.uint8)
    golden_scale2[0::2] = np.uint8(0x77)

    out = np.full((ROWS, COLS), SENTINEL_U8, dtype=np.uint8)
    scale1 = np.full(SCALE1_BYTES, SENTINEL_U8, dtype=np.uint8)
    scale2 = np.full(SCALE2_BYTES, SENTINEL_U8, dtype=np.uint8)

    output_dir.mkdir(parents=True, exist_ok=True)
    src.tofile(output_dir / "v1.bin")
    out.tofile(output_dir / "v2.bin")
    scale1.tofile(output_dir / "v3.bin")
    scale2.tofile(output_dir / "v4.bin")
    golden_out.tofile(output_dir / "golden_v2.bin")
    golden_scale1.tofile(output_dir / "golden_v3.bin")
    golden_scale2.tofile(output_dir / "golden_v4.bin")


def main() -> None:
    """Command-line entry point: ``--output-dir`` selects where files go."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--output-dir", type=Path, default=Path("."))
    args = parser.parse_args()
    generate(args.output_dir)
// NOTE(review): `!pto.ptr` and `#pto.kernel_kind` / `#pto.pipe` appear below
// without any `<...>` parameters — confirm against the checked-in file.
module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} {
  // Block quantisation of bf16 input to f8E4M3FN. For every group (8 groups
  // per 256-element tile) the kernel derives a shared exponent from the
  // group's absolute maximum, stores it to the scale1 buffer, and multiplies
  // the group by the matching power of two before truncating to f8. A fixed
  // pattern is also written to the scale2 buffer.
  func.func @vmi_block_mx_quant_bf16_e4m3_4x128_kernel(%src_gm: !pto.ptr,
                                                        %out_gm: !pto.ptr,
                                                        %scale1_gm: !pto.ptr,
                                                        %scale2_gm: !pto.ptr) attributes {pto.kernel} {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c2 = arith.constant 2 : index
    %c32 = arith.constant 32 : index
    %c4 = arith.constant 4 : index
    %c128 = arith.constant 128 : index
    %c256 = arith.constant 256 : index
    %c0_i64 = arith.constant 0 : i64
    %c1_i64 = arith.constant 1 : i64
    %c2_i64 = arith.constant 2 : i64
    %c8_i64 = arith.constant 8 : i64
    %c32_i64 = arith.constant 32 : i64
    %c256_i64 = arith.constant 256 : i64
    %c512_i64 = arith.constant 512 : i64
    %c1024_i64 = arith.constant 1024 : i64
    // %c8192_i64 and %c16384_i64 are not referenced anywhere in this function.
    %c8192_i64 = arith.constant 8192 : i64
    %c16384_i64 = arith.constant 16384 : i64
    %c32768_i64 = arith.constant 32768 : i64
    %c49152_i64 = arith.constant 49152 : i64
    %c53248_i64 = arith.constant 53248 : i64
    // 2139095040 == 0x7F800000, the f32 exponent-field mask.
    %c2139095040_i32 = arith.constant 2139095040 : i32
    %c23_i32 = arith.constant 23 : i32
    %c8_i32 = arith.constant 8 : i32
    %c119_i32 = arith.constant 119 : i32
    %c254_i32 = arith.constant 254 : i32
    %c0_i32 = arith.constant 0 : i32
    %c1_i32 = arith.constant 1 : i32

    // Local buffer layout by address: source at 0, quantised output at 32768
    // (two pointer views of the same address), scale1 at 49152, scale2 at
    // 53248.
    %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr
    %ub_out_u8 = pto.castptr %c32768_i64 : i64 -> !pto.ptr
    %ub_out_f8 = pto.castptr %c32768_i64 : i64 -> !pto.ptr
    %ub_scale1 = pto.castptr %c49152_i64 : i64 -> !pto.ptr
    %ub_scale2 = pto.castptr %c53248_i64 : i64 -> !pto.ptr

    // Copy the source in with length operand 1024 (4 x 128 two-byte elements).
    pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c1024_i64
      nburst(%c1_i64, %c1024_i64, %c1024_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64

    // Make the MTE2 copy visible to the vector pipe before computing.
    pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]
    pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]

    pto.vecscope {
      // Each iteration handles two rows (2 x 128 = 256 elements).
      scf.for %row = %c0 to %c4 step %c2 {
        %elem_off = arith.muli %row, %c128 : index
        %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred>
        %x_bf16 = pto.vmi.load %ub_src[%elem_off]
          : !pto.ptr -> !pto.vmi.vreg<256xbf16>
        %x = pto.vmi.extf %x_bf16
          : !pto.vmi.vreg<256xbf16> -> !pto.vmi.vreg<256xf32>
        %abs = pto.vmi.absf %x
          : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32>
        // Per-group absolute maximum (8 groups).
        %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 8}
          : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred>
          -> !pto.vmi.vreg<8xf32>
        %amax_bits = pto.vmi.bitcast %amax
          : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<8xi32>
        %exp_mask = pto.vmi.broadcast %c2139095040_i32
          : i32 -> !pto.vmi.vreg<8xi32>
        %shift = pto.vmi.broadcast %c23_i32
          : i32 -> !pto.vmi.vreg<8xi32>
        %emax = pto.vmi.broadcast %c8_i32
          : i32 -> !pto.vmi.vreg<8xi32>
        %scale_exp_bias = pto.vmi.broadcast %c254_i32
          : i32 -> !pto.vmi.vreg<8xi32>
        // Extract the biased f32 exponent of the group maximum.
        %exp_bits = pto.vmi.andi %amax_bits, %exp_mask
          : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32>
          -> !pto.vmi.vreg<8xi32>
        %exp = pto.vmi.shrui %exp_bits, %shift
          : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32>
          -> !pto.vmi.vreg<8xi32>
        // Shared exponent = biased exponent - 8.
        %e8m0_i32 = pto.vmi.subi %exp, %emax
          : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32>
          -> !pto.vmi.vreg<8xi32>
        // Scale offset for this row pair: (row / 2) * 32.
        %scale_slot = arith.divui %row, %c2 : index
        %scale_ub_off = arith.muli %scale_slot, %c32 : index
        pto.vmi.group_store %e8m0_i32, %ub_scale1[%scale_ub_off], %c1 {num_groups = 8}
          : !pto.vmi.vreg<8xi32>, !pto.ptr
        // Exponent field (254 - shared exponent), shifted into f32 position.
        // The bitcast value is 2^(127 - shared exponent), i.e. the reciprocal
        // of 2^(shared exponent - 127).
        %scale_exp = pto.vmi.subi %scale_exp_bias, %e8m0_i32
          : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32>
          -> !pto.vmi.vreg<8xi32>
        %scale_bits = pto.vmi.shli %scale_exp, %shift
          : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32>
          -> !pto.vmi.vreg<8xi32>
        %scale = pto.vmi.bitcast %scale_bits
          : !pto.vmi.vreg<8xi32> -> !pto.vmi.vreg<8xf32>
        // Expand the 8 per-group scales to 256 lanes and apply them.
        %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 8}
          : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<256xf32>
        %scaled = pto.vmi.mulf %x, %scale_vec
          : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32>
          -> !pto.vmi.vreg<256xf32>
        %q8 = pto.vmi.truncf %scaled
          : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E4M3FN>
        pto.vmi.store %q8, %ub_out_f8[%elem_off]
          : !pto.vmi.vreg<256xf8E4M3FN>, !pto.ptr
      }

      // scale2 pattern: lanes whose index has bit 0 clear get 119, the rest
      // get 0 (assuming iota yields consecutive lane indices starting at
      // %c0_i32 — TODO confirm). Values are truncated to ui8 and stored.
      %scale2_lane = pto.vmi.iota %c0_i32
        : i32 -> !pto.vmi.vreg<256xi32>
      %scale2_one = pto.vmi.broadcast %c1_i32
        : i32 -> !pto.vmi.vreg<256xi32>
      %scale2_parity = pto.vmi.andi %scale2_lane, %scale2_one
        : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32>
        -> !pto.vmi.vreg<256xi32>
      %scale2_zero = pto.vmi.broadcast %c0_i32
        : i32 -> !pto.vmi.vreg<256xi32>
      %scale2_even = pto.vmi.cmpi "eq", %scale2_parity, %scale2_zero
        : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32>
        -> !pto.vmi.mask<256xpred>
      %scale2_valid = pto.vmi.broadcast %c119_i32
        : i32 -> !pto.vmi.vreg<256xi32>
      %scale2_i32 = pto.vmi.select %scale2_even, %scale2_valid, %scale2_zero
        : !pto.vmi.mask<256xpred>, !pto.vmi.vreg<256xi32>,
          !pto.vmi.vreg<256xi32>
        -> !pto.vmi.vreg<256xi32>
      %scale2_u8 = pto.vmi.trunci %scale2_i32
        : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xui8>
      pto.vmi.store %scale2_u8, %ub_scale2[%c0]
        : !pto.vmi.vreg<256xui8>, !pto.ptr
    }

    // Order the vector stores before the MTE3 copy-outs.
    pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
    pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
    // Quantised output, length operand 512.
    pto.mte_ub_gm %ub_out_u8, %out_gm, %c512_i64
      nburst(%c1_i64, %c512_i64, %c512_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64
    // scale1 uses nburst(2, 32, 8) with length 8; presumably two 8-byte
    // bursts read 32 bytes apart locally and written 8 bytes apart, which
    // would match the `slot * 32` store offset above — TODO confirm operand
    // order.
    pto.mte_ub_gm %ub_scale1, %scale1_gm, %c8_i64
      nburst(%c2_i64, %c32_i64, %c8_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64
    // scale2, length operand 256.
    pto.mte_ub_gm %ub_scale2, %scale2_gm, %c256_i64
      nburst(%c1_i64, %c256_i64, %c256_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64
    pto.barrier #pto.pipe
    return
  }
}
// Device-side entry point; its definition is not in this translation unit.
// It takes the bf16 source and three byte-typed output buffers, all as
// __gm__-qualified pointers.
extern "C" __global__ [aicore] void
vmi_block_mx_quant_bf16_e4m3_4x128_kernel(__gm__ bfloat16_t *src,
                                          __gm__ uint8_t *out,
                                          __gm__ uint8_t *scale1,
                                          __gm__ uint8_t *scale2);

/// Host-side wrapper that enqueues the 4x128 block-MX quant kernel on `stream`.
///
/// @param src    Source buffer; typed `uint16_t *` on the host side and
///               reinterpreted as `__gm__ bfloat16_t *` for the kernel.
/// @param out    Output buffer, forwarded as `__gm__ uint8_t *`.
/// @param scale1 First scale buffer, forwarded as `__gm__ uint8_t *`.
/// @param scale2 Second scale buffer, forwarded as `__gm__ uint8_t *`.
/// @param stream Opaque stream handle passed as the third launch parameter.
///
/// The function only issues the launch; it performs no synchronisation itself.
void LaunchVmi_block_mx_quant_bf16_e4m3_4x128_kernel(
    uint16_t *src, uint8_t *out, uint8_t *scale1, uint8_t *scale2,
    void *stream) {
  // Launch configuration is <<<1, nullptr, stream>>>; presumably a single
  // block with no extra launch argument — TODO confirm against the toolchain.
  vmi_block_mx_quant_bf16_e4m3_4x128_kernel<<<1, nullptr, stream>>>(
      (__gm__ bfloat16_t *)src, (__gm__ uint8_t *)out, (__gm__ uint8_t *)scale1,
      (__gm__ uint8_t *)scale2);
}
0000000000..760e76a1c9 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e4m3-4x128/main.cpp @@ -0,0 +1,106 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_block_mx_quant_bf16_e4m3_4x128_kernel( + uint16_t *src, uint8_t *out, uint8_t *scale1, uint8_t *scale2, + void *stream); + +int main() { + constexpr size_t kRows = 4; + constexpr size_t kCols = 128; + constexpr size_t kElems = kRows * kCols; + constexpr size_t kScale1Bytes = 16; + constexpr size_t kScale2Bytes = 256; + size_t srcBytes = kElems * sizeof(uint16_t); + size_t outBytes = kElems * sizeof(uint8_t); + size_t scale1Bytes = kScale1Bytes; + size_t scale2Bytes = kScale2Bytes; + uint16_t *srcHost = nullptr; + uint8_t *outHost = nullptr; + uint8_t *scale1Host = nullptr; + uint8_t *scale2Host = nullptr; + uint16_t *srcDevice = nullptr; + uint8_t *outDevice = nullptr; + uint8_t *scale1Device = nullptr; + uint8_t *scale2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + 
aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&outHost), outBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&scale1Host), scale1Bytes)); + ACL_CHECK(aclrtMallocHost((void **)(&scale2Host), scale2Bytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outDevice, outBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&scale1Device, scale1Bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&scale2Device, scale2Bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", outBytes, outHost, outBytes); + ReadFile("./v3.bin", scale1Bytes, scale1Host, scale1Bytes); + ReadFile("./v4.bin", scale2Bytes, scale2Host, scale2Bytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outDevice, outBytes, outHost, outBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(scale1Device, scale1Bytes, scale1Host, scale1Bytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(scale2Device, scale2Bytes, scale2Host, scale2Bytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_block_mx_quant_bf16_e4m3_4x128_kernel( + srcDevice, outDevice, scale1Device, scale2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(outHost, outBytes, outDevice, outBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(scale1Host, scale1Bytes, scale1Device, scale1Bytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(scale2Host, scale2Bytes, scale2Device, scale2Bytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", 
outHost, outBytes); + WriteFile("./v3.bin", scale1Host, scale1Bytes); + WriteFile("./v4.bin", scale2Host, scale2Bytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(outDevice); + aclrtFree(scale1Device); + aclrtFree(scale2Device); + aclrtFreeHost(srcHost); + aclrtFreeHost(outHost); + aclrtFreeHost(scale1Host); + aclrtFreeHost(scale2Host); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e4m3-4x128/ptoas.flags b/test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e4m3-4x128/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e4m3-4x128/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e5m2-4x128/compare.py b/test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e5m2-4x128/compare.py new file mode 100644 index 0000000000..98eebe4477 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e5m2-4x128/compare.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+ +import sys + +import numpy as np + + +def check_u8(name: str) -> bool: + golden = np.fromfile(f"golden_{name}.bin", dtype=np.uint8) + output = np.fromfile(f"{name}.bin", dtype=np.uint8) + if golden.shape == output.shape and np.array_equal(golden, output): + return True + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed {name} idx={idx} " + f"golden=0x{int(golden[idx]):02x} output=0x{int(output[idx]):02x}" + ) + return False + + +def main() -> None: + if not check_u8("v2") or not check_u8("v3") or not check_u8("v4"): + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e5m2-4x128/golden.py b/test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e5m2-4x128/golden.py new file mode 100644 index 0000000000..e8c20fb5fe --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e5m2-4x128/golden.py @@ -0,0 +1,71 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+ +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 4 +COLS = 128 +SCALE1_BYTES = 16 +SCALE2_BYTES = 256 +SENTINEL_U8 = np.uint8(0xA5) + +Q_VALUES = np.array( + [0.0, 1.0, -1.0, 0.5, -0.5, 2.0, -2.0, 4.0, -4.0, 57344.0], + dtype=np.float32, +) +F8E5M2_BYTES = np.array( + [0x00, 0x3C, 0xBC, 0x38, 0xB8, 0x40, 0xC0, 0x44, 0xC4, 0x7B], + dtype=np.uint8, +) + + +def f32_to_bf16_bits(values: np.ndarray) -> np.ndarray: + bits = values.astype(np.float32).view(np.uint32) + lsb = (bits >> 16) & 1 + rounded = bits + np.uint32(0x7FFF) + lsb + return (rounded >> 16).astype(np.uint16) + + +def generate(output_dir: Path) -> None: + repeats = (COLS + len(Q_VALUES) - 1) // len(Q_VALUES) + q_row = np.tile(Q_VALUES, repeats)[:COLS].astype(np.float32) + f8_row = np.tile(F8E5M2_BYTES, repeats)[:COLS].astype(np.uint8) + + src = np.tile(f32_to_bf16_bits(q_row), (ROWS, 1)) + golden_out = np.tile(f8_row, (ROWS, 1)).astype(np.uint8) + golden_scale1 = np.full(SCALE1_BYTES, np.uint8(0x7F), dtype=np.uint8) + golden_scale2 = np.zeros(SCALE2_BYTES, dtype=np.uint8) + golden_scale2[0::2] = np.uint8(0x7F) + + out = np.full((ROWS, COLS), SENTINEL_U8, dtype=np.uint8) + scale1 = np.full(SCALE1_BYTES, SENTINEL_U8, dtype=np.uint8) + scale2 = np.full(SCALE2_BYTES, SENTINEL_U8, dtype=np.uint8) + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + out.tofile(output_dir / "v2.bin") + scale1.tofile(output_dir / "v3.bin") + scale2.tofile(output_dir / "v4.bin") + golden_out.tofile(output_dir / "golden_v2.bin") + golden_scale1.tofile(output_dir / "golden_v3.bin") + golden_scale2.tofile(output_dir / "golden_v4.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e5m2-4x128/kernel.pto 
b/test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e5m2-4x128/kernel.pto new file mode 100644 index 0000000000..4d42e59c69 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e5m2-4x128/kernel.pto @@ -0,0 +1,148 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_block_mx_quant_bf16_e5m2_4x128_kernel(%src_gm: !pto.ptr, + %out_gm: !pto.ptr, + %scale1_gm: !pto.ptr, + %scale2_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c32 = arith.constant 32 : index + %c4 = arith.constant 4 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c2_i64 = arith.constant 2 : i64 + %c8_i64 = arith.constant 8 : i64 + %c32_i64 = arith.constant 32 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c16384_i64 = arith.constant 16384 : i64 + %c32768_i64 = arith.constant 32768 : i64 + %c49152_i64 = arith.constant 49152 : i64 + %c53248_i64 = arith.constant 53248 : i64 + %c2139095040_i32 = arith.constant 2139095040 : i32 + %c23_i32 = arith.constant 23 : i32 + %c15_i32 = arith.constant 
15 : i32 + %c127_i32 = arith.constant 127 : i32 + %c254_i32 = arith.constant 254 : i32 + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out_u8 = pto.castptr %c32768_i64 : i64 -> !pto.ptr + %ub_out_f8 = pto.castptr %c32768_i64 : i64 -> !pto.ptr + %ub_scale1 = pto.castptr %c49152_i64 : i64 -> !pto.ptr + %ub_scale2 = pto.castptr %c53248_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + scf.for %row = %c0 to %c4 step %c2 { + %elem_off = arith.muli %row, %c128 : index + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %x_bf16 = pto.vmi.load %ub_src[%elem_off] + : !pto.ptr -> !pto.vmi.vreg<256xbf16> + %x = pto.vmi.extf %x_bf16 + : !pto.vmi.vreg<256xbf16> -> !pto.vmi.vreg<256xf32> + %abs = pto.vmi.absf %x + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<8xf32> + %amax_bits = pto.vmi.bitcast %amax + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<8xi32> + %exp_mask = pto.vmi.broadcast %c2139095040_i32 + : i32 -> !pto.vmi.vreg<8xi32> + %shift = pto.vmi.broadcast %c23_i32 + : i32 -> !pto.vmi.vreg<8xi32> + %emax = pto.vmi.broadcast %c15_i32 + : i32 -> !pto.vmi.vreg<8xi32> + %scale_exp_bias = pto.vmi.broadcast %c254_i32 + : i32 -> !pto.vmi.vreg<8xi32> + %exp_bits = pto.vmi.andi %amax_bits, %exp_mask + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> + %exp = pto.vmi.shrui %exp_bits, %shift + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> + %e8m0_i32 = pto.vmi.subi %exp, %emax + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> + %scale_slot = 
arith.divui %row, %c2 : index + %scale_ub_off = arith.muli %scale_slot, %c32 : index + pto.vmi.group_store %e8m0_i32, %ub_scale1[%scale_ub_off], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xi32>, !pto.ptr + %scale_exp = pto.vmi.subi %scale_exp_bias, %e8m0_i32 + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> + %scale_bits = pto.vmi.shli %scale_exp, %shift + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> + %scale = pto.vmi.bitcast %scale_bits + : !pto.vmi.vreg<8xi32> -> !pto.vmi.vreg<8xf32> + %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 8} + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<256xf32> + %scaled = pto.vmi.mulf %x, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %q8 = pto.vmi.truncf %scaled + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E5M2> + pto.vmi.store %q8, %ub_out_f8[%elem_off] + : !pto.vmi.vreg<256xf8E5M2>, !pto.ptr + } + + %scale2_lane = pto.vmi.iota %c0_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %scale2_one = pto.vmi.broadcast %c1_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %scale2_parity = pto.vmi.andi %scale2_lane, %scale2_one + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %scale2_zero = pto.vmi.broadcast %c0_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %scale2_even = pto.vmi.cmpi "eq", %scale2_parity, %scale2_zero + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.mask<256xpred> + %scale2_valid = pto.vmi.broadcast %c127_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %scale2_i32 = pto.vmi.select %scale2_even, %scale2_valid, %scale2_zero + : !pto.vmi.mask<256xpred>, !pto.vmi.vreg<256xi32>, + !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %scale2_u8 = pto.vmi.trunci %scale2_i32 + : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xui8> + pto.vmi.store %scale2_u8, %ub_scale2[%c0] + : !pto.vmi.vreg<256xui8>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + 
pto.mte_ub_gm %ub_out_u8, %out_gm, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_scale1, %scale1_gm, %c8_i64 + nburst(%c2_i64, %c32_i64, %c8_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_scale2, %scale2_gm, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e5m2-4x128/launch.cpp b/test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e5m2-4x128/launch.cpp new file mode 100644 index 0000000000..35bd3e233c --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e5m2-4x128/launch.cpp @@ -0,0 +1,45 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_block_mx_quant_bf16_e5m2_4x128_kernel(__gm__ bfloat16_t *src, + __gm__ uint8_t *out, + __gm__ uint8_t *scale1, + __gm__ uint8_t *scale2); + +void LaunchVmi_block_mx_quant_bf16_e5m2_4x128_kernel( + uint16_t *src, uint8_t *out, uint8_t *scale1, uint8_t *scale2, + void *stream) { + vmi_block_mx_quant_bf16_e5m2_4x128_kernel<<<1, nullptr, stream>>>( + (__gm__ bfloat16_t *)src, (__gm__ uint8_t *)out, (__gm__ uint8_t *)scale1, + (__gm__ uint8_t *)scale2); +} diff --git a/test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e5m2-4x128/main.cpp b/test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e5m2-4x128/main.cpp new file mode 100644 index 0000000000..8fc855e6b7 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e5m2-4x128/main.cpp @@ -0,0 +1,106 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_block_mx_quant_bf16_e5m2_4x128_kernel( + uint16_t *src, uint8_t *out, uint8_t *scale1, uint8_t *scale2, + void *stream); + +int main() { + constexpr size_t kRows = 4; + constexpr size_t kCols = 128; + constexpr size_t kElems = kRows * kCols; + constexpr size_t kScale1Bytes = 16; + constexpr size_t kScale2Bytes = 256; + size_t srcBytes = kElems * sizeof(uint16_t); + size_t outBytes = kElems * sizeof(uint8_t); + size_t scale1Bytes = kScale1Bytes; + size_t scale2Bytes = kScale2Bytes; + uint16_t *srcHost = nullptr; + uint8_t *outHost = nullptr; + uint8_t *scale1Host = nullptr; + uint8_t *scale2Host = nullptr; + uint16_t *srcDevice = nullptr; + uint8_t *outDevice = nullptr; + uint8_t *scale1Device = nullptr; + uint8_t *scale2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&outHost), outBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&scale1Host), scale1Bytes)); + 
ACL_CHECK(aclrtMallocHost((void **)(&scale2Host), scale2Bytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outDevice, outBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&scale1Device, scale1Bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&scale2Device, scale2Bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", outBytes, outHost, outBytes); + ReadFile("./v3.bin", scale1Bytes, scale1Host, scale1Bytes); + ReadFile("./v4.bin", scale2Bytes, scale2Host, scale2Bytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outDevice, outBytes, outHost, outBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(scale1Device, scale1Bytes, scale1Host, scale1Bytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(scale2Device, scale2Bytes, scale2Host, scale2Bytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_block_mx_quant_bf16_e5m2_4x128_kernel( + srcDevice, outDevice, scale1Device, scale2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(outHost, outBytes, outDevice, outBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(scale1Host, scale1Bytes, scale1Device, scale1Bytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(scale2Host, scale2Bytes, scale2Device, scale2Bytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", outHost, outBytes); + WriteFile("./v3.bin", scale1Host, scale1Bytes); + WriteFile("./v4.bin", scale2Host, scale2Bytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(outDevice); + aclrtFree(scale1Device); + aclrtFree(scale2Device); + aclrtFreeHost(srcHost); + aclrtFreeHost(outHost); + aclrtFreeHost(scale1Host); + aclrtFreeHost(scale2Host); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + 
return rc; +} diff --git a/test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e5m2-4x128/ptoas.flags b/test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e5m2-4x128/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e5m2-4x128/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e4m3-64x256/compare.py b/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e4m3-64x256/compare.py new file mode 100644 index 0000000000..98eebe4477 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e4m3-64x256/compare.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+ +import sys + +import numpy as np + + +def check_u8(name: str) -> bool: + golden = np.fromfile(f"golden_{name}.bin", dtype=np.uint8) + output = np.fromfile(f"{name}.bin", dtype=np.uint8) + if golden.shape == output.shape and np.array_equal(golden, output): + return True + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed {name} idx={idx} " + f"golden=0x{int(golden[idx]):02x} output=0x{int(output[idx]):02x}" + ) + return False + + +def main() -> None: + if not check_u8("v2") or not check_u8("v3") or not check_u8("v4"): + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e4m3-64x256/golden.py b/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e4m3-64x256/golden.py new file mode 100644 index 0000000000..5dc8e5bcce --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e4m3-64x256/golden.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+ +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 64 +COLS = 256 +SCALE1_BYTES = 512 +SCALE2_BYTES = 512 +SENTINEL_U8 = np.uint8(0xA5) + +Q_VALUES = np.array( + [0.0, 1.0, -1.0, 0.5, 2.0, -2.0, 4.0, -4.0, 448.0], dtype=np.float32 +) +F8E4M3FN_BYTES = np.array( + [0x00, 0x38, 0xB8, 0x30, 0x40, 0xC0, 0x48, 0xC8, 0x7E], dtype=np.uint8 +) + + +def generate(output_dir: Path) -> None: + repeats = (COLS + len(Q_VALUES) - 1) // len(Q_VALUES) + q_row = np.tile(Q_VALUES, repeats)[:COLS].astype(np.float32) + f8_row = np.tile(F8E4M3FN_BYTES, repeats)[:COLS].astype(np.uint8) + + src = np.tile((q_row / np.float32(256.0)).astype(np.float16), (ROWS, 1)) + golden_out = np.tile(f8_row, (ROWS, 1)).astype(np.uint8) + golden_scale1 = np.full(SCALE1_BYTES, np.uint8(0x77), dtype=np.uint8) + golden_scale2 = np.full(SCALE2_BYTES, np.uint8(0x77), dtype=np.uint8) + + out = np.full((ROWS, COLS), SENTINEL_U8, dtype=np.uint8) + scale1 = np.full(SCALE1_BYTES, SENTINEL_U8, dtype=np.uint8) + scale2 = np.full(SCALE2_BYTES, SENTINEL_U8, dtype=np.uint8) + + output_dir.mkdir(parents=True, exist_ok=True) + src.view(np.uint16).tofile(output_dir / "v1.bin") + out.tofile(output_dir / "v2.bin") + scale1.tofile(output_dir / "v3.bin") + scale2.tofile(output_dir / "v4.bin") + golden_out.tofile(output_dir / "golden_v2.bin") + golden_scale1.tofile(output_dir / "golden_v3.bin") + golden_scale2.tofile(output_dir / "golden_v4.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e4m3-64x256/kernel.pto b/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e4m3-64x256/kernel.pto new file mode 100644 index 0000000000..a71ca57db9 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e4m3-64x256/kernel.pto @@ -0,0 +1,135 @@ +// Copyright (c) 2026 
Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_block_mx_quant_f16_e4m3_64x256_kernel(%src_gm: !pto.ptr, + %out_gm: !pto.ptr, + %scale1_gm: !pto.ptr, + %scale2_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c32 = arith.constant 32 : index + %c8 = arith.constant 8 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c2_i64 = arith.constant 2 : i64 + %c8_i64 = arith.constant 8 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c16384_i64 = arith.constant 16384 : i64 + %c32768_i64 = arith.constant 32768 : i64 + %c49152_i64 = arith.constant 49152 : i64 + %c53248_i64 = arith.constant 53248 : i64 + %c2139095040_i32 = arith.constant 2139095040 : i32 + %c23_i32 = arith.constant 23 : i32 + %c8_i32 = arith.constant 8 : i32 + %c119_i32 = arith.constant 119 : i32 + %c254_i32 = arith.constant 254 : i32 + + %ub_src = 
pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out_u8 = pto.castptr %c32768_i64 : i64 -> !pto.ptr + %ub_out_f8 = pto.castptr %c32768_i64 : i64 -> !pto.ptr + %ub_scale1 = pto.castptr %c49152_i64 : i64 -> !pto.ptr + %ub_scale2 = pto.castptr %c53248_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c32768_i64 + nburst(%c1_i64, %c32768_i64, %c32768_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + scf.for %row = %c0 to %c64 step %c1 { + %elem_off = arith.muli %row, %c256 : index + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %x_f16 = pto.vmi.load %ub_src[%elem_off] + : !pto.ptr -> !pto.vmi.vreg<256xf16> + %x = pto.vmi.extf %x_f16 + : !pto.vmi.vreg<256xf16> -> !pto.vmi.vreg<256xf32> + %abs = pto.vmi.absf %x + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<8xf32> + %amax_bits = pto.vmi.bitcast %amax + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<8xi32> + %exp_mask = pto.vmi.broadcast %c2139095040_i32 + : i32 -> !pto.vmi.vreg<8xi32> + %shift = pto.vmi.broadcast %c23_i32 + : i32 -> !pto.vmi.vreg<8xi32> + %emax = pto.vmi.broadcast %c8_i32 + : i32 -> !pto.vmi.vreg<8xi32> + %scale_exp_bias = pto.vmi.broadcast %c254_i32 + : i32 -> !pto.vmi.vreg<8xi32> + %exp_bits = pto.vmi.andi %amax_bits, %exp_mask + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> + %exp = pto.vmi.shrui %exp_bits, %shift + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> + %e8m0_i32 = pto.vmi.subi %exp, %emax + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> + %scale_ub_off = arith.muli %row, %c32 : index + pto.vmi.group_store %e8m0_i32, %ub_scale1[%scale_ub_off], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xi32>, !pto.ptr + %scale_exp = 
pto.vmi.subi %scale_exp_bias, %e8m0_i32 + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> + %scale_bits = pto.vmi.shli %scale_exp, %shift + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> + %scale = pto.vmi.bitcast %scale_bits + : !pto.vmi.vreg<8xi32> -> !pto.vmi.vreg<8xf32> + %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 8} + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<256xf32> + %scaled = pto.vmi.mulf %x, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %q8 = pto.vmi.truncf %scaled + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E4M3FN> + pto.vmi.store %q8, %ub_out_f8[%elem_off] + : !pto.vmi.vreg<256xf8E4M3FN>, !pto.ptr + } + + %scale2_i32 = pto.vmi.broadcast %c119_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %scale2_u8 = pto.vmi.trunci %scale2_i32 + : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xui8> + pto.vmi.store %scale2_u8, %ub_scale2[%c0] + : !pto.vmi.vreg<256xui8>, !pto.ptr + pto.vmi.store %scale2_u8, %ub_scale2[%c256] + : !pto.vmi.vreg<256xui8>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out_u8, %out_gm, %c16384_i64 + nburst(%c1_i64, %c16384_i64, %c16384_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_scale1, %scale1_gm, %c8_i64 + nburst(%c64_i64, %c32_i64, %c8_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_scale2, %scale2_gm, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e4m3-64x256/launch.cpp b/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e4m3-64x256/launch.cpp new file mode 100644 index 0000000000..642c2b33d9 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e4m3-64x256/launch.cpp @@ -0,0 +1,45 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_block_mx_quant_f16_e4m3_64x256_kernel(__gm__ half *src, + __gm__ uint8_t *out, + __gm__ uint8_t *scale1, + __gm__ uint8_t *scale2); + +void LaunchVmi_block_mx_quant_f16_e4m3_64x256_kernel( + uint16_t *src, uint8_t *out, uint8_t *scale1, uint8_t *scale2, + void *stream) { + vmi_block_mx_quant_f16_e4m3_64x256_kernel<<<1, nullptr, stream>>>( + (__gm__ half *)src, (__gm__ uint8_t *)out, (__gm__ uint8_t *)scale1, + (__gm__ uint8_t *)scale2); +} diff --git a/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e4m3-64x256/main.cpp b/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e4m3-64x256/main.cpp new file mode 100644 index 0000000000..ec205e289a 
--- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e4m3-64x256/main.cpp @@ -0,0 +1,106 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_block_mx_quant_f16_e4m3_64x256_kernel( + uint16_t *src, uint8_t *out, uint8_t *scale1, uint8_t *scale2, + void *stream); + +int main() { + constexpr size_t kRows = 64; + constexpr size_t kCols = 256; + constexpr size_t kElems = kRows * kCols; + constexpr size_t kScale1Bytes = 512; + constexpr size_t kScale2Bytes = 512; + size_t srcBytes = kElems * sizeof(uint16_t); + size_t outBytes = kElems * sizeof(uint8_t); + size_t scale1Bytes = kScale1Bytes; + size_t scale2Bytes = kScale2Bytes; + uint16_t *srcHost = nullptr; + uint8_t *outHost = nullptr; + uint8_t *scale1Host = nullptr; + uint8_t *scale2Host = nullptr; + uint16_t *srcDevice = nullptr; + uint8_t *outDevice = nullptr; + uint8_t *scale1Device = nullptr; + uint8_t *scale2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = 
nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&outHost), outBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&scale1Host), scale1Bytes)); + ACL_CHECK(aclrtMallocHost((void **)(&scale2Host), scale2Bytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outDevice, outBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&scale1Device, scale1Bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&scale2Device, scale2Bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", outBytes, outHost, outBytes); + ReadFile("./v3.bin", scale1Bytes, scale1Host, scale1Bytes); + ReadFile("./v4.bin", scale2Bytes, scale2Host, scale2Bytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outDevice, outBytes, outHost, outBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(scale1Device, scale1Bytes, scale1Host, scale1Bytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(scale2Device, scale2Bytes, scale2Host, scale2Bytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_block_mx_quant_f16_e4m3_64x256_kernel( + srcDevice, outDevice, scale1Device, scale2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(outHost, outBytes, outDevice, outBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(scale1Host, scale1Bytes, scale1Device, scale1Bytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(scale2Host, scale2Bytes, scale2Device, scale2Bytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", outHost, outBytes); + 
WriteFile("./v3.bin", scale1Host, scale1Bytes); + WriteFile("./v4.bin", scale2Host, scale2Bytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(outDevice); + aclrtFree(scale1Device); + aclrtFree(scale2Device); + aclrtFreeHost(srcHost); + aclrtFreeHost(outHost); + aclrtFreeHost(scale1Host); + aclrtFreeHost(scale2Host); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e4m3-64x256/ptoas.flags b/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e4m3-64x256/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e4m3-64x256/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e5m2-8x256/compare.py b/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e5m2-8x256/compare.py new file mode 100644 index 0000000000..98eebe4477 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e5m2-8x256/compare.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+ +import sys + +import numpy as np + + +def check_u8(name: str) -> bool: + golden = np.fromfile(f"golden_{name}.bin", dtype=np.uint8) + output = np.fromfile(f"{name}.bin", dtype=np.uint8) + if golden.shape == output.shape and np.array_equal(golden, output): + return True + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed {name} idx={idx} " + f"golden=0x{int(golden[idx]):02x} output=0x{int(output[idx]):02x}" + ) + return False + + +def main() -> None: + if not check_u8("v2") or not check_u8("v3") or not check_u8("v4"): + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e5m2-8x256/golden.py b/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e5m2-8x256/golden.py new file mode 100644 index 0000000000..27efbace62 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e5m2-8x256/golden.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+ +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 8 +COLS = 256 +SCALE1_BYTES = 64 +SCALE2_BYTES = 512 +SENTINEL_U8 = np.uint8(0xA5) + +Q_VALUES = np.array( + [0.0, 1.0, -1.0, 0.5, -0.5, 2.0, -2.0, 4.0, -4.0, 57344.0], + dtype=np.float32, +) +F8E5M2_BYTES = np.array( + [0x00, 0x3C, 0xBC, 0x38, 0xB8, 0x40, 0xC0, 0x44, 0xC4, 0x7B], + dtype=np.uint8, +) + + +def generate(output_dir: Path) -> None: + repeats = (COLS + len(Q_VALUES) - 1) // len(Q_VALUES) + q_row = np.tile(Q_VALUES, repeats)[:COLS].astype(np.float32) + f8_row = np.tile(F8E5M2_BYTES, repeats)[:COLS].astype(np.uint8) + + src = np.tile(q_row.astype(np.float16), (ROWS, 1)) + golden_out = np.tile(f8_row, (ROWS, 1)).astype(np.uint8) + golden_scale1 = np.full(SCALE1_BYTES, np.uint8(0x7F), dtype=np.uint8) + golden_scale2 = np.zeros(SCALE2_BYTES, dtype=np.uint8) + golden_scale2[0::2] = np.uint8(0x7F) + + out = np.full((ROWS, COLS), SENTINEL_U8, dtype=np.uint8) + scale1 = np.full(SCALE1_BYTES, SENTINEL_U8, dtype=np.uint8) + scale2 = np.full(SCALE2_BYTES, SENTINEL_U8, dtype=np.uint8) + + output_dir.mkdir(parents=True, exist_ok=True) + src.view(np.uint16).tofile(output_dir / "v1.bin") + out.tofile(output_dir / "v2.bin") + scale1.tofile(output_dir / "v3.bin") + scale2.tofile(output_dir / "v4.bin") + golden_out.tofile(output_dir / "golden_v2.bin") + golden_scale1.tofile(output_dir / "golden_v3.bin") + golden_scale2.tofile(output_dir / "golden_v4.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e5m2-8x256/kernel.pto b/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e5m2-8x256/kernel.pto new file mode 100644 index 0000000000..d8c2377acc --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e5m2-8x256/kernel.pto @@ -0,0 +1,151 @@ +// 
Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_block_mx_quant_f16_e5m2_8x256_kernel(%src_gm: !pto.ptr, + %out_gm: !pto.ptr, + %scale1_gm: !pto.ptr, + %scale2_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c32 = arith.constant 32 : index + %c8 = arith.constant 8 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c2_i64 = arith.constant 2 : i64 + %c8_i64 = arith.constant 8 : i64 + %c32_i64 = arith.constant 32 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c16384_i64 = arith.constant 16384 : i64 + %c32768_i64 = arith.constant 32768 : i64 + %c49152_i64 = arith.constant 49152 : i64 + %c53248_i64 = arith.constant 53248 : i64 + %c2139095040_i32 = arith.constant 2139095040 : i32 + %c23_i32 = arith.constant 23 : i32 + %c15_i32 = arith.constant 15 : i32 + %c127_i32 = arith.constant 127 : i32 + %c254_i32 = arith.constant 254 : i32 + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : 
i32 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out_u8 = pto.castptr %c32768_i64 : i64 -> !pto.ptr + %ub_out_f8 = pto.castptr %c32768_i64 : i64 -> !pto.ptr + %ub_scale1 = pto.castptr %c49152_i64 : i64 -> !pto.ptr + %ub_scale2 = pto.castptr %c53248_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c4096_i64 + nburst(%c1_i64, %c4096_i64, %c4096_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + scf.for %row = %c0 to %c8 step %c1 { + %elem_off = arith.muli %row, %c256 : index + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %x_f16 = pto.vmi.load %ub_src[%elem_off] + : !pto.ptr -> !pto.vmi.vreg<256xf16> + %x = pto.vmi.extf %x_f16 + : !pto.vmi.vreg<256xf16> -> !pto.vmi.vreg<256xf32> + %abs = pto.vmi.absf %x + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<8xf32> + %amax_bits = pto.vmi.bitcast %amax + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<8xi32> + %exp_mask = pto.vmi.broadcast %c2139095040_i32 + : i32 -> !pto.vmi.vreg<8xi32> + %shift = pto.vmi.broadcast %c23_i32 + : i32 -> !pto.vmi.vreg<8xi32> + %emax = pto.vmi.broadcast %c15_i32 + : i32 -> !pto.vmi.vreg<8xi32> + %scale_exp_bias = pto.vmi.broadcast %c254_i32 + : i32 -> !pto.vmi.vreg<8xi32> + %exp_bits = pto.vmi.andi %amax_bits, %exp_mask + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> + %exp = pto.vmi.shrui %exp_bits, %shift + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> + %e8m0_i32 = pto.vmi.subi %exp, %emax + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> + %scale_ub_off = arith.muli %row, %c32 : index + pto.vmi.group_store %e8m0_i32, %ub_scale1[%scale_ub_off], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xi32>, !pto.ptr + 
%scale_exp = pto.vmi.subi %scale_exp_bias, %e8m0_i32 + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> + %scale_bits = pto.vmi.shli %scale_exp, %shift + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> + %scale = pto.vmi.bitcast %scale_bits + : !pto.vmi.vreg<8xi32> -> !pto.vmi.vreg<8xf32> + %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 8} + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<256xf32> + %scaled = pto.vmi.mulf %x, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %q8 = pto.vmi.truncf %scaled + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E5M2> + pto.vmi.store %q8, %ub_out_f8[%elem_off] + : !pto.vmi.vreg<256xf8E5M2>, !pto.ptr + } + + %scale2_lane = pto.vmi.iota %c0_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %scale2_one = pto.vmi.broadcast %c1_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %scale2_parity = pto.vmi.andi %scale2_lane, %scale2_one + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %scale2_zero = pto.vmi.broadcast %c0_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %scale2_even = pto.vmi.cmpi "eq", %scale2_parity, %scale2_zero + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.mask<256xpred> + %scale2_valid = pto.vmi.broadcast %c127_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %scale2_i32 = pto.vmi.select %scale2_even, %scale2_valid, %scale2_zero + : !pto.vmi.mask<256xpred>, !pto.vmi.vreg<256xi32>, + !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %scale2_u8 = pto.vmi.trunci %scale2_i32 + : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xui8> + pto.vmi.store %scale2_u8, %ub_scale2[%c0] + : !pto.vmi.vreg<256xui8>, !pto.ptr + pto.vmi.store %scale2_u8, %ub_scale2[%c256] + : !pto.vmi.vreg<256xui8>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out_u8, %out_gm, %c2048_i64 + nburst(%c1_i64, %c2048_i64, %c2048_i64) + : !pto.ptr, !pto.ptr, i64, 
i64, i64, i64 + pto.mte_ub_gm %ub_scale1, %scale1_gm, %c8_i64 + nburst(%c8_i64, %c32_i64, %c8_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_scale2, %scale2_gm, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e5m2-8x256/launch.cpp b/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e5m2-8x256/launch.cpp new file mode 100644 index 0000000000..1b139bdee6 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e5m2-8x256/launch.cpp @@ -0,0 +1,45 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_block_mx_quant_f16_e5m2_8x256_kernel(__gm__ half *src, + __gm__ uint8_t *out, + __gm__ uint8_t *scale1, + __gm__ uint8_t *scale2); + +void LaunchVmi_block_mx_quant_f16_e5m2_8x256_kernel( + uint16_t *src, uint8_t *out, uint8_t *scale1, uint8_t *scale2, + void *stream) { + vmi_block_mx_quant_f16_e5m2_8x256_kernel<<<1, nullptr, stream>>>( + (__gm__ half *)src, (__gm__ uint8_t *)out, (__gm__ uint8_t *)scale1, + (__gm__ uint8_t *)scale2); +} diff --git a/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e5m2-8x256/main.cpp b/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e5m2-8x256/main.cpp new file mode 100644 index 0000000000..f5932ec784 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e5m2-8x256/main.cpp @@ -0,0 +1,106 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_block_mx_quant_f16_e5m2_8x256_kernel( + uint16_t *src, uint8_t *out, uint8_t *scale1, uint8_t *scale2, + void *stream); + +int main() { + constexpr size_t kRows = 8; + constexpr size_t kCols = 256; + constexpr size_t kElems = kRows * kCols; + constexpr size_t kScale1Bytes = 64; + constexpr size_t kScale2Bytes = 512; + size_t srcBytes = kElems * sizeof(uint16_t); + size_t outBytes = kElems * sizeof(uint8_t); + size_t scale1Bytes = kScale1Bytes; + size_t scale2Bytes = kScale2Bytes; + uint16_t *srcHost = nullptr; + uint8_t *outHost = nullptr; + uint8_t *scale1Host = nullptr; + uint8_t *scale2Host = nullptr; + uint16_t *srcDevice = nullptr; + uint8_t *outDevice = nullptr; + uint8_t *scale1Device = nullptr; + uint8_t *scale2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&outHost), outBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&scale1Host), scale1Bytes)); + 
ACL_CHECK(aclrtMallocHost((void **)(&scale2Host), scale2Bytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outDevice, outBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&scale1Device, scale1Bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&scale2Device, scale2Bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", outBytes, outHost, outBytes); + ReadFile("./v3.bin", scale1Bytes, scale1Host, scale1Bytes); + ReadFile("./v4.bin", scale2Bytes, scale2Host, scale2Bytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outDevice, outBytes, outHost, outBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(scale1Device, scale1Bytes, scale1Host, scale1Bytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(scale2Device, scale2Bytes, scale2Host, scale2Bytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_block_mx_quant_f16_e5m2_8x256_kernel( + srcDevice, outDevice, scale1Device, scale2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(outHost, outBytes, outDevice, outBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(scale1Host, scale1Bytes, scale1Device, scale1Bytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(scale2Host, scale2Bytes, scale2Device, scale2Bytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", outHost, outBytes); + WriteFile("./v3.bin", scale1Host, scale1Bytes); + WriteFile("./v4.bin", scale2Host, scale2Bytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(outDevice); + aclrtFree(scale1Device); + aclrtFree(scale2Device); + aclrtFreeHost(srcHost); + aclrtFreeHost(outHost); + aclrtFreeHost(scale1Host); + aclrtFreeHost(scale2Host); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + 
return rc; +} diff --git a/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e5m2-8x256/ptoas.flags b/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e5m2-8x256/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e5m2-8x256/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-2x128/compare.py b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-2x128/compare.py new file mode 100644 index 0000000000..c40bdb3270 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-2x128/compare.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+ +import sys + +import numpy as np + + +def main() -> None: + golden_scale = np.fromfile("golden_v2.bin", dtype=np.float32) + scale = np.fromfile("v2.bin", dtype=np.float32) + golden_out = np.fromfile("golden_v3.bin", dtype=np.uint8) + out = np.fromfile("v3.bin", dtype=np.uint8) + + if golden_scale.shape != scale.shape or not np.allclose( + golden_scale, scale, rtol=0, atol=0 + ): + diff = np.nonzero(golden_scale != scale)[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] scale compare failed idx={idx} " + f"golden={golden_scale[idx] if idx >= 0 else 'n/a'} " + f"output={scale[idx] if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + if golden_out.shape != out.shape or not np.array_equal(golden_out, out): + diff = np.nonzero(golden_out != out)[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] fp8 compare failed idx={idx} " + f"golden={int(golden_out[idx]) if idx >= 0 else 'n/a'} " + f"output={int(out[idx]) if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-2x128/golden.py b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-2x128/golden.py new file mode 100644 index 0000000000..e139f87144 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-2x128/golden.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 2 +COLS = 128 +SCALE_SLOTS = ROWS +FP8_MAX = np.float32(448.0) +SCALES = np.array([0.25, 0.5], dtype=np.float32) +SENTINEL_F32 = np.float32(-777.0) +SENTINEL_U8 = np.uint8(0xA5) + +Q_VALUES = np.array( + [0.0, 1.0, -1.0, 0.5, 2.0, -2.0, 4.0, -4.0, 448.0], dtype=np.float32 +) +F8E4M3FN_BYTES = np.array( + [0x00, 0x38, 0xB8, 0x30, 0x40, 0xC0, 0x48, 0xC8, 0x7E], dtype=np.uint8 +) + + +def f32_to_bf16_bits(values: np.ndarray) -> np.ndarray: + bits = values.astype(np.float32).view(np.uint32) + lsb = (bits >> 16) & 1 + rounded = bits + np.uint32(0x7FFF) + lsb + return (rounded >> 16).astype(np.uint16) + + +def bf16_bits_to_f32(values: np.ndarray) -> np.ndarray: + return (values.astype(np.uint32) << 16).view(np.float32) + + +def generate(output_dir: Path) -> None: + repeats = (COLS + len(Q_VALUES) - 1) // len(Q_VALUES) + q_row = np.tile(Q_VALUES, repeats)[:COLS].astype(np.float32) + f8_row = np.tile(F8E4M3FN_BYTES, repeats)[:COLS].astype(np.uint8) + + src = np.empty((ROWS, COLS), dtype=np.uint16) + golden_scale = np.full(SCALE_SLOTS, SENTINEL_F32, dtype=np.float32) + golden_out = np.empty((ROWS, COLS), dtype=np.uint8) + for row in range(ROWS): + src[row] = f32_to_bf16_bits(q_row * SCALES[row]) + x_f32 = bf16_bits_to_f32(src[row]) + golden_scale[row] = np.max(np.abs(x_f32)).astype(np.float32) / FP8_MAX + golden_out[row] = f8_row + + scale = np.full(SCALE_SLOTS, SENTINEL_F32, dtype=np.float32) + out = np.full((ROWS, COLS), SENTINEL_U8, dtype=np.uint8) + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + scale.tofile(output_dir / "v2.bin") + out.tofile(output_dir / "v3.bin") + golden_scale.tofile(output_dir / "golden_v2.bin") + golden_out.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", 
type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-2x128/kernel.pto b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-2x128/kernel.pto new file mode 100644 index 0000000000..21664de067 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-2x128/kernel.pto @@ -0,0 +1,80 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_block_quant_bf16_fp8_2x128_kernel(%src_gm: !pto.ptr, + %scale_gm: !pto.ptr, + %out_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c8_i64 = arith.constant 8 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %fp8_max = arith.constant 4.480000e+02 : f32 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_scale = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out_u8 = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_out_f8 = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %out_gm, %ub_out_u8, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %x_bf16 = pto.vmi.load %ub_src[%c0] + : !pto.ptr -> !pto.vmi.vreg<256xbf16> + %x = pto.vmi.extf %x_bf16 + : !pto.vmi.vreg<256xbf16> -> !pto.vmi.vreg<256xf32> + %abs = pto.vmi.absf %x + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 2} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<2xf32> + %fp8_max_v = pto.vmi.broadcast %fp8_max + : f32 -> !pto.vmi.vreg<2xf32> + %scale = pto.vmi.divf %amax, %fp8_max_v + : !pto.vmi.vreg<2xf32>, !pto.vmi.vreg<2xf32> + -> !pto.vmi.vreg<2xf32> + pto.vmi.group_store %scale, %ub_scale[%c0], %c1 {num_groups = 2} + : 
!pto.vmi.vreg<2xf32>, !pto.ptr + %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 2} + : !pto.vmi.vreg<2xf32> -> !pto.vmi.vreg<256xf32> + %q = pto.vmi.divf %x, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %q8 = pto.vmi.truncf %q + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E4M3FN> + pto.vmi.store %q8, %ub_out_f8[%c0] + : !pto.vmi.vreg<256xf8E4M3FN>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_scale, %scale_gm, %c8_i64 + nburst(%c1_i64, %c8_i64, %c8_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_out_u8, %out_gm, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-2x128/launch.cpp b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-2x128/launch.cpp new file mode 100644 index 0000000000..a1bb355958 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-2x128/launch.cpp @@ -0,0 +1,42 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_block_quant_bf16_fp8_2x128_kernel(__gm__ bfloat16_t *src, + __gm__ float *scale, + __gm__ uint8_t *out); + +void LaunchVmi_block_quant_bf16_fp8_2x128_kernel(uint16_t *src, float *scale, + uint8_t *out, void *stream) { + vmi_block_quant_bf16_fp8_2x128_kernel<<<1, nullptr, stream>>>( + (__gm__ bfloat16_t *)src, (__gm__ float *)scale, (__gm__ uint8_t *)out); +} diff --git a/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-2x128/main.cpp b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-2x128/main.cpp new file mode 100644 index 0000000000..632018a6ff --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-2x128/main.cpp @@ -0,0 +1,93 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_block_quant_bf16_fp8_2x128_kernel(uint16_t *src, float *scale, + uint8_t *out, void *stream); + +int main() { + constexpr size_t kRows = 2; + constexpr size_t kCols = 128; + constexpr size_t kElems = kRows * kCols; + constexpr size_t kScaleElems = kRows; + size_t srcBytes = kElems * sizeof(uint16_t); + size_t scaleBytes = kScaleElems * sizeof(float); + size_t outBytes = kElems * sizeof(uint8_t); + uint16_t *srcHost = nullptr; + float *scaleHost = nullptr; + uint8_t *outHost = nullptr; + uint16_t *srcDevice = nullptr; + float *scaleDevice = nullptr; + uint8_t *outDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&scaleHost), scaleBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&outHost), outBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&scaleDevice, scaleBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outDevice, outBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", scaleBytes, scaleHost, scaleBytes); + 
ReadFile("./v3.bin", outBytes, outHost, outBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(scaleDevice, scaleBytes, scaleHost, scaleBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outDevice, outBytes, outHost, outBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_block_quant_bf16_fp8_2x128_kernel(srcDevice, scaleDevice, + outDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(scaleHost, scaleBytes, scaleDevice, scaleBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outHost, outBytes, outDevice, outBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", scaleHost, scaleBytes); + WriteFile("./v3.bin", outHost, outBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(scaleDevice); + aclrtFree(outDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(scaleHost); + aclrtFreeHost(outHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-2x128/ptoas.flags b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-2x128/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-2x128/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-32x128/compare.py b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-32x128/compare.py new file mode 100644 index 0000000000..c40bdb3270 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-32x128/compare.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). 
+# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden_scale = np.fromfile("golden_v2.bin", dtype=np.float32) + scale = np.fromfile("v2.bin", dtype=np.float32) + golden_out = np.fromfile("golden_v3.bin", dtype=np.uint8) + out = np.fromfile("v3.bin", dtype=np.uint8) + + if golden_scale.shape != scale.shape or not np.allclose( + golden_scale, scale, rtol=0, atol=0 + ): + diff = np.nonzero(golden_scale != scale)[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] scale compare failed idx={idx} " + f"golden={golden_scale[idx] if idx >= 0 else 'n/a'} " + f"output={scale[idx] if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + if golden_out.shape != out.shape or not np.array_equal(golden_out, out): + diff = np.nonzero(golden_out != out)[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] fp8 compare failed idx={idx} " + f"golden={int(golden_out[idx]) if idx >= 0 else 'n/a'} " + f"output={int(out[idx]) if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-32x128/golden.py b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-32x128/golden.py new file mode 100644 index 0000000000..a166b6aa4a --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-32x128/golden.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 32 +COLS = 128 +SCALE_SLOTS = ROWS +FP8_MAX = np.float32(448.0) +SCALES = np.tile(np.array([0.25, 0.5, 1.0, 2.0], dtype=np.float32), 8) +SENTINEL_F32 = np.float32(-777.0) +SENTINEL_U8 = np.uint8(0xA5) + +Q_VALUES = np.array( + [0.0, 1.0, -1.0, 0.5, 2.0, -2.0, 4.0, -4.0, 448.0], dtype=np.float32 +) +F8E4M3FN_BYTES = np.array( + [0x00, 0x38, 0xB8, 0x30, 0x40, 0xC0, 0x48, 0xC8, 0x7E], dtype=np.uint8 +) + + +def f32_to_bf16_bits(values: np.ndarray) -> np.ndarray: + bits = values.astype(np.float32).view(np.uint32) + lsb = (bits >> 16) & 1 + rounded = bits + np.uint32(0x7FFF) + lsb + return (rounded >> 16).astype(np.uint16) + + +def bf16_bits_to_f32(values: np.ndarray) -> np.ndarray: + return (values.astype(np.uint32) << 16).view(np.float32) + + +def generate(output_dir: Path) -> None: + repeats = (COLS + len(Q_VALUES) - 1) // len(Q_VALUES) + q_row = np.tile(Q_VALUES, repeats)[:COLS].astype(np.float32) + f8_row = np.tile(F8E4M3FN_BYTES, repeats)[:COLS].astype(np.uint8) + + src = np.empty((ROWS, COLS), dtype=np.uint16) + golden_scale = np.full(SCALE_SLOTS, SENTINEL_F32, dtype=np.float32) + golden_out = np.empty((ROWS, COLS), dtype=np.uint8) + for row in range(ROWS): + src[row] = f32_to_bf16_bits(q_row * SCALES[row]) + x_f32 = bf16_bits_to_f32(src[row]) + golden_scale[row] = np.max(np.abs(x_f32)).astype(np.float32) / 
FP8_MAX + golden_out[row] = f8_row + + scale = np.full(SCALE_SLOTS, SENTINEL_F32, dtype=np.float32) + out = np.full((ROWS, COLS), SENTINEL_U8, dtype=np.uint8) + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + scale.tofile(output_dir / "v2.bin") + out.tofile(output_dir / "v3.bin") + golden_scale.tofile(output_dir / "golden_v2.bin") + golden_out.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-32x128/kernel.pto b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-32x128/kernel.pto new file mode 100644 index 0000000000..551f06b94d --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-32x128/kernel.pto @@ -0,0 +1,83 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_block_quant_bf16_fp8_32x128_kernel(%src_gm: !pto.ptr, + %scale_gm: !pto.ptr, + %out_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c16384_i64 = arith.constant 16384 : i64 + %c32768_i64 = arith.constant 32768 : i64 + %fp8_max = arith.constant 4.480000e+02 : f32 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_scale = pto.castptr %c16384_i64 : i64 -> !pto.ptr + %ub_out_u8 = pto.castptr %c32768_i64 : i64 -> !pto.ptr + %ub_out_f8 = pto.castptr %c32768_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c8192_i64 + nburst(%c1_i64, %c8192_i64, %c8192_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + scf.for %row = %c0 to %c32 step %c2 { + %elem_off = arith.muli %row, %c128 : index + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %x_bf16 = pto.vmi.load %ub_src[%elem_off] + : !pto.ptr -> !pto.vmi.vreg<256xbf16> + %x = pto.vmi.extf %x_bf16 + : !pto.vmi.vreg<256xbf16> -> !pto.vmi.vreg<256xf32> + %abs = pto.vmi.absf %x + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 2} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<2xf32> + %fp8_max_v = pto.vmi.broadcast %fp8_max + : f32 -> !pto.vmi.vreg<2xf32> + %scale = pto.vmi.divf %amax, %fp8_max_v + : !pto.vmi.vreg<2xf32>, !pto.vmi.vreg<2xf32> + -> !pto.vmi.vreg<2xf32> + 
pto.vmi.group_store %scale, %ub_scale[%row], %c1 {num_groups = 2} + : !pto.vmi.vreg<2xf32>, !pto.ptr + %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 2} + : !pto.vmi.vreg<2xf32> -> !pto.vmi.vreg<256xf32> + %q = pto.vmi.divf %x, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %q8 = pto.vmi.truncf %q + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E4M3FN> + pto.vmi.store %q8, %ub_out_f8[%elem_off] + : !pto.vmi.vreg<256xf8E4M3FN>, !pto.ptr + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_scale, %scale_gm, %c128_i64 + nburst(%c1_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_out_u8, %out_gm, %c4096_i64 + nburst(%c1_i64, %c4096_i64, %c4096_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-32x128/launch.cpp b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-32x128/launch.cpp new file mode 100644 index 0000000000..24599cfec9 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-32x128/launch.cpp @@ -0,0 +1,42 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_block_quant_bf16_fp8_32x128_kernel(__gm__ bfloat16_t *src, + __gm__ float *scale, + __gm__ uint8_t *out); + +void LaunchVmi_block_quant_bf16_fp8_32x128_kernel(uint16_t *src, float *scale, + uint8_t *out, void *stream) { + vmi_block_quant_bf16_fp8_32x128_kernel<<<1, nullptr, stream>>>( + (__gm__ bfloat16_t *)src, (__gm__ float *)scale, (__gm__ uint8_t *)out); +} diff --git a/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-32x128/main.cpp b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-32x128/main.cpp new file mode 100644 index 0000000000..cd5211e167 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-32x128/main.cpp @@ -0,0 +1,93 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_block_quant_bf16_fp8_32x128_kernel(uint16_t *src, float *scale, + uint8_t *out, void *stream); + +int main() { + constexpr size_t kRows = 32; + constexpr size_t kCols = 128; + constexpr size_t kElems = kRows * kCols; + constexpr size_t kScaleElems = kRows; + size_t srcBytes = kElems * sizeof(uint16_t); + size_t scaleBytes = kScaleElems * sizeof(float); + size_t outBytes = kElems * sizeof(uint8_t); + uint16_t *srcHost = nullptr; + float *scaleHost = nullptr; + uint8_t *outHost = nullptr; + uint16_t *srcDevice = nullptr; + float *scaleDevice = nullptr; + uint8_t *outDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&scaleHost), scaleBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&outHost), outBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&scaleDevice, scaleBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outDevice, outBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", scaleBytes, scaleHost, scaleBytes); + 
ReadFile("./v3.bin", outBytes, outHost, outBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(scaleDevice, scaleBytes, scaleHost, scaleBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outDevice, outBytes, outHost, outBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_block_quant_bf16_fp8_32x128_kernel(srcDevice, scaleDevice, + outDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(scaleHost, scaleBytes, scaleDevice, scaleBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outHost, outBytes, outDevice, outBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", scaleHost, scaleBytes); + WriteFile("./v3.bin", outHost, outBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(scaleDevice); + aclrtFree(outDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(scaleHost); + aclrtFreeHost(outHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-32x128/ptoas.flags b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-32x128/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-32x128/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128-min-scale/compare.py b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128-min-scale/compare.py new file mode 100644 index 0000000000..c40bdb3270 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128-min-scale/compare.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden_scale = np.fromfile("golden_v2.bin", dtype=np.float32) + scale = np.fromfile("v2.bin", dtype=np.float32) + golden_out = np.fromfile("golden_v3.bin", dtype=np.uint8) + out = np.fromfile("v3.bin", dtype=np.uint8) + + if golden_scale.shape != scale.shape or not np.allclose( + golden_scale, scale, rtol=0, atol=0 + ): + diff = np.nonzero(golden_scale != scale)[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] scale compare failed idx={idx} " + f"golden={golden_scale[idx] if idx >= 0 else 'n/a'} " + f"output={scale[idx] if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + if golden_out.shape != out.shape or not np.array_equal(golden_out, out): + diff = np.nonzero(golden_out != out)[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] fp8 compare failed idx={idx} " + f"golden={int(golden_out[idx]) if idx >= 0 else 'n/a'} " + f"output={int(out[idx]) if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128-min-scale/golden.py b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128-min-scale/golden.py new file mode 100644 index 0000000000..e4e1f95824 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128-min-scale/golden.py @@ -0,0 +1,122 @@ +#!/usr/bin/env 
python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 4 +COLS = 128 +SCALE_SLOTS = 128 +FP8_MAX = np.float32(448.0) +SCALE_LIMIT = np.float32(0.25) +SCALES = np.array([0.25, 0.5, 1.0, 2.0], dtype=np.float32) +SENTINEL_F32 = np.float32(-777.0) +SENTINEL_U8 = np.uint8(0xA5) + +Q_VALUES = np.array( + [ + 0.0, + 224.0, + -224.0, + 112.0, + -112.0, + 64.0, + -64.0, + 32.0, + -32.0, + 16.0, + -16.0, + 8.0, + -8.0, + 4.0, + -4.0, + 2.0, + -2.0, + 1.0, + -1.0, + 0.5, + -0.5, + ], + dtype=np.float32, +) + + +def f32_to_bf16_bits(values: np.ndarray) -> np.ndarray: + bits = values.astype(np.float32).view(np.uint32) + lsb = (bits >> 16) & 1 + rounded = bits + np.uint32(0x7FFF) + lsb + return (rounded >> 16).astype(np.uint16) + + +def bf16_bits_to_f32(values: np.ndarray) -> np.ndarray: + return (values.astype(np.uint32) << 16).view(np.float32) + + +def decode_f8e4m3fn(byte: int) -> np.float32: + sign = -1.0 if byte & 0x80 else 1.0 + exp = (byte >> 3) & 0x0F + mant = byte & 0x07 + if byte in (0x7F, 0xFF): + return np.float32(np.nan) + if exp == 0: + return np.float32(sign * (mant / 8.0) * (2.0**-6)) + return np.float32(sign * (1.0 + mant / 8.0) * (2.0 ** (exp - 7))) + + +def f8e4m3fn_exact_bytes(values: np.ndarray) -> np.ndarray: + exact = {} + for byte in range(0x100): + decoded = decode_f8e4m3fn(byte) + if not 
np.isnan(decoded): + exact.setdefault(np.float32(decoded).item(), byte) + return np.array([exact[np.float32(value).item()] for value in values], dtype=np.uint8) + + +def f8e4m3fn_saturating_bytes(values: np.ndarray) -> np.ndarray: + clipped = np.clip(values.astype(np.float32), -FP8_MAX, FP8_MAX) + return f8e4m3fn_exact_bytes(clipped) + + +def generate(output_dir: Path) -> None: + repeats = (COLS + len(Q_VALUES) - 1) // len(Q_VALUES) + q_row = np.tile(Q_VALUES, repeats)[:COLS].astype(np.float32) + + src = np.empty((ROWS, COLS), dtype=np.uint16) + golden_scale = np.full(SCALE_SLOTS, SENTINEL_F32, dtype=np.float32) + golden_out = np.empty((ROWS, COLS), dtype=np.uint8) + for row in range(ROWS): + src[row] = f32_to_bf16_bits(q_row * SCALES[row]) + x_f32 = bf16_bits_to_f32(src[row]) + raw_scale = np.max(np.abs(x_f32)).astype(np.float32) / FP8_MAX + scale = np.minimum(raw_scale, SCALE_LIMIT).astype(np.float32) + golden_scale[row] = scale + golden_out[row] = f8e4m3fn_saturating_bytes(x_f32 / scale) + + scale = np.full(SCALE_SLOTS, SENTINEL_F32, dtype=np.float32) + out = np.full((ROWS, COLS), SENTINEL_U8, dtype=np.uint8) + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + scale.tofile(output_dir / "v2.bin") + out.tofile(output_dir / "v3.bin") + golden_scale.tofile(output_dir / "golden_v2.bin") + golden_out.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128-min-scale/kernel.pto b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128-min-scale/kernel.pto new file mode 100644 index 0000000000..cc6456ea14 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128-min-scale/kernel.pto @@ -0,0 +1,123 @@ +// Copyright (c) 2026 Huawei Technologies Co., 
// VMI block-quant test kernel: loads bf16 input, derives one f32 scale per
// group (group max of |x| divided by %fp8_max, then capped by %scale_limit via
// pto.vmi.minf), stores the scales, and stores x / scale truncated to
// f8E4M3FN.
// NOTE(review): in this copy `!pto.ptr`, `#pto.kernel_kind` and `#pto.pipe`
// carry no `<...>` parameters; they look stripped by extraction — restore them
// from the original file before using this text.
module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} {
  func.func @vmi_block_quant_bf16_fp8_4x128_min_scale_kernel(%src_gm: !pto.ptr,
                                                             %scale_gm: !pto.ptr,
                                                             %out_gm: !pto.ptr) attributes {pto.kernel} {
    // Index constants: element offsets (0, 256), scale-slot offsets (0, 2) and
    // the group_store operand %c1.
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c2 = arith.constant 2 : index
    %c256 = arith.constant 256 : index
    // i64 constants: UB addresses (0, 4096, 8192) and transfer sizes
    // (presumably byte counts — confirm against the pto.mte_* op definition).
    %c0_i64 = arith.constant 0 : i64
    %c1_i64 = arith.constant 1 : i64
    %c512_i64 = arith.constant 512 : i64
    %c1024_i64 = arith.constant 1024 : i64
    %c4096_i64 = arith.constant 4096 : i64
    %c8192_i64 = arith.constant 8192 : i64
    %fp8_max = arith.constant 4.480000e+02 : f32
    // Upper bound applied to every computed scale (0.25).
    %scale_limit = arith.constant 2.500000e-01 : f32

    // Pointers built from fixed integer addresses. %ub_out_u8 and %ub_out_f8
    // are cast from the same address (8192), so they alias the same buffer.
    %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr
    %ub_scale = pto.castptr %c4096_i64 : i64 -> !pto.ptr
    %ub_out_u8 = pto.castptr %c8192_i64 : i64 -> !pto.ptr
    %ub_out_f8 = pto.castptr %c8192_i64 : i64 -> !pto.ptr

    // Copy all three GM buffers in; scale and out are copied in as well, so
    // slots the kernel does not write are copied back out unchanged.
    pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c1024_i64
      nburst(%c1_i64, %c1024_i64, %c1024_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
    pto.mte_gm_ub %scale_gm, %ub_scale, %c0_i64, %c512_i64
      nburst(%c1_i64, %c512_i64, %c512_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
    pto.mte_gm_ub %out_gm, %ub_out_u8, %c0_i64, %c512_i64
      nburst(%c1_i64, %c512_i64, %c512_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64

    // Order the vector work after the MTE2 copies above.
    pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]
    pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]

    // First 256 elements: two groups, scales written at slot offset 0.
    pto.vecscope {
      %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred>
      %x_bf16 = pto.vmi.load %ub_src[%c0]
        : !pto.ptr -> !pto.vmi.vreg<256xbf16>
      %x = pto.vmi.extf %x_bf16
        : !pto.vmi.vreg<256xbf16> -> !pto.vmi.vreg<256xf32>
      %abs = pto.vmi.absf %x
        : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32>
      %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 2}
        : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred>
          -> !pto.vmi.vreg<2xf32>
      %fp8_max_v = pto.vmi.broadcast %fp8_max
        : f32 -> !pto.vmi.vreg<2xf32>
      %scale_raw = pto.vmi.divf %amax, %fp8_max_v
        : !pto.vmi.vreg<2xf32>, !pto.vmi.vreg<2xf32>
          -> !pto.vmi.vreg<2xf32>
      %scale_limit_v = pto.vmi.broadcast %scale_limit
        : f32 -> !pto.vmi.vreg<2xf32>
      // Cap the raw scale at %scale_limit.
      %scale = pto.vmi.minf %scale_raw, %scale_limit_v
        : !pto.vmi.vreg<2xf32>, !pto.vmi.vreg<2xf32>
          -> !pto.vmi.vreg<2xf32>
      // NOTE(review): the trailing %c1 operand is presumably the slot stride —
      // confirm against the pto.vmi.group_store definition.
      pto.vmi.group_store %scale, %ub_scale[%c0], %c1 {num_groups = 2}
        : !pto.vmi.vreg<2xf32>, !pto.ptr
      %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 2}
        : !pto.vmi.vreg<2xf32> -> !pto.vmi.vreg<256xf32>
      %q = pto.vmi.divf %x, %scale_vec
        : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32>
          -> !pto.vmi.vreg<256xf32>
      %q8 = pto.vmi.truncf %q
        : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E4M3FN>
      pto.vmi.store %q8, %ub_out_f8[%c0]
        : !pto.vmi.vreg<256xf8E4M3FN>, !pto.ptr
    }

    // Second 256 elements: same computation at element offset 256, scales
    // written at slot offset 2.
    pto.vecscope {
      %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred>
      %x_bf16 = pto.vmi.load %ub_src[%c256]
        : !pto.ptr -> !pto.vmi.vreg<256xbf16>
      %x = pto.vmi.extf %x_bf16
        : !pto.vmi.vreg<256xbf16> -> !pto.vmi.vreg<256xf32>
      %abs = pto.vmi.absf %x
        : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32>
      %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 2}
        : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred>
          -> !pto.vmi.vreg<2xf32>
      %fp8_max_v = pto.vmi.broadcast %fp8_max
        : f32 -> !pto.vmi.vreg<2xf32>
      %scale_raw = pto.vmi.divf %amax, %fp8_max_v
        : !pto.vmi.vreg<2xf32>, !pto.vmi.vreg<2xf32>
          -> !pto.vmi.vreg<2xf32>
      %scale_limit_v = pto.vmi.broadcast %scale_limit
        : f32 -> !pto.vmi.vreg<2xf32>
      %scale = pto.vmi.minf %scale_raw, %scale_limit_v
        : !pto.vmi.vreg<2xf32>, !pto.vmi.vreg<2xf32>
          -> !pto.vmi.vreg<2xf32>
      pto.vmi.group_store %scale, %ub_scale[%c2], %c1 {num_groups = 2}
        : !pto.vmi.vreg<2xf32>, !pto.ptr
      %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 2}
        : !pto.vmi.vreg<2xf32> -> !pto.vmi.vreg<256xf32>
      %q = pto.vmi.divf %x, %scale_vec
        : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32>
          -> !pto.vmi.vreg<256xf32>
      %q8 = pto.vmi.truncf %q
        : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E4M3FN>
      pto.vmi.store %q8, %ub_out_f8[%c256]
        : !pto.vmi.vreg<256xf8E4M3FN>, !pto.ptr
    }

    // Order the MTE3 copy-out after the vector work, then write both result
    // buffers back to GM.
    pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
    pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
    pto.mte_ub_gm %ub_scale, %scale_gm, %c512_i64
      nburst(%c1_i64, %c512_i64, %c512_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64
    pto.mte_ub_gm %ub_out_u8, %out_gm, %c512_i64
      nburst(%c1_i64, %c512_i64, %c512_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64
    pto.barrier #pto.pipe
    return
  }
}
// Host-side launch shim for the min-scale bf16 -> fp8 block-quant kernel.

// Make sure __VEC_SCOPE__ is defined (as empty) before anything below uses it.
#ifndef __VEC_SCOPE__
#define __VEC_SCOPE__
#endif
// When compiling for the AI core with __NPU_ARCH__ == 2201, declare the fp8 /
// fp4 type names as one-byte structs.
#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201)
typedef struct { unsigned char v; } hifloat8_t;
typedef struct { unsigned char v; } float8_e4m3_t;
typedef struct { unsigned char v; } float8_e5m2_t;
typedef struct { unsigned char v; } float8_e8m0_t;
typedef struct { unsigned char v; } float4_e1m2x2_t;
typedef struct { unsigned char v; } float4_e2m1x2_t;
#endif
// NOTE(review): the header name of the next directive is missing in this copy
// (angle-bracket text was stripped); restore it from the original file.
#include
// Host builds that have not seen TMRGSORT_HPP get a local definition of
// MrgSortExecutedNumList.
#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP)
struct MrgSortExecutedNumList {
  uint16_t mrgSortList0;
  uint16_t mrgSortList1;
  uint16_t mrgSortList2;
  uint16_t mrgSortList3;
};
#endif
// The ACL header is skipped when __CPU_SIM is defined.
#ifndef __CPU_SIM
#include "acl/acl.h"
#endif

// Device kernel entry point; defined elsewhere.
extern "C" __global__ [aicore] void
vmi_block_quant_bf16_fp8_4x128_min_scale_kernel(__gm__ bfloat16_t *src,
                                                __gm__ float *scale,
                                                __gm__ uint8_t *out);

// Launches the kernel with launch configuration <<<1, nullptr, stream>>>,
// casting the host-typed pointers to the kernel's __gm__ parameter types.
// `src` is passed as uint16_t* and reinterpreted as bfloat16_t.
void LaunchVmi_block_quant_bf16_fp8_4x128_min_scale_kernel(uint16_t *src, float *scale,
                                                           uint8_t *out, void *stream) {
  vmi_block_quant_bf16_fp8_4x128_min_scale_kernel<<<1, nullptr, stream>>>(
      (__gm__ bfloat16_t *)src, (__gm__ float *)scale, (__gm__ uint8_t *)out);
}
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_block_quant_bf16_fp8_4x128_min_scale_kernel(uint16_t *src, float *scale, + uint8_t *out, void *stream); + +int main() { + constexpr size_t kRows = 4; + constexpr size_t kCols = 128; + constexpr size_t kElems = kRows * kCols; + constexpr size_t kScaleElems = 128; + size_t srcBytes = kElems * sizeof(uint16_t); + size_t scaleBytes = kScaleElems * sizeof(float); + size_t outBytes = kElems * sizeof(uint8_t); + uint16_t *srcHost = nullptr; + float *scaleHost = nullptr; + uint8_t *outHost = nullptr; + uint16_t *srcDevice = nullptr; + float *scaleDevice = nullptr; + uint8_t *outDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), 
srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&scaleHost), scaleBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&outHost), outBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&scaleDevice, scaleBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outDevice, outBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", scaleBytes, scaleHost, scaleBytes); + ReadFile("./v3.bin", outBytes, outHost, outBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(scaleDevice, scaleBytes, scaleHost, scaleBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outDevice, outBytes, outHost, outBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_block_quant_bf16_fp8_4x128_min_scale_kernel(srcDevice, scaleDevice, + outDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(scaleHost, scaleBytes, scaleDevice, scaleBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outHost, outBytes, outDevice, outBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", scaleHost, scaleBytes); + WriteFile("./v3.bin", outHost, outBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(scaleDevice); + aclrtFree(outDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(scaleHost); + aclrtFreeHost(outHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128-min-scale/ptoas.flags b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128-min-scale/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128-min-scale/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git 
import sys

import numpy as np


def main() -> None:
    """Compare the kernel outputs in the current directory with the goldens.

    Reads ``golden_v2.bin`` / ``v2.bin`` as float32 scales and
    ``golden_v3.bin`` / ``v3.bin`` as raw fp8 bytes. Both comparisons are
    exact. On the first failure an ``[ERROR]`` line is printed and the process
    exits with status 2; otherwise ``[INFO] compare passed`` is printed.
    """
    golden_scale = np.fromfile("golden_v2.bin", dtype=np.float32)
    scale = np.fromfile("v2.bin", dtype=np.float32)
    golden_out = np.fromfile("golden_v3.bin", dtype=np.uint8)
    out = np.fromfile("v3.bin", dtype=np.uint8)

    # A size mismatch must be reported on its own: an elementwise `!=` between
    # arrays of different length cannot produce a per-element diff.
    if golden_scale.shape != scale.shape:
        print(
            f"[ERROR] scale compare failed: size mismatch "
            f"golden={golden_scale.size} output={scale.size}"
        )
        sys.exit(2)

    # rtol=0 / atol=0 makes allclose an exact comparison.
    if not np.allclose(golden_scale, scale, rtol=0, atol=0):
        diff = np.nonzero(golden_scale != scale)[0]
        idx = int(diff[0]) if diff.size else -1
        print(
            f"[ERROR] scale compare failed idx={idx} "
            f"golden={golden_scale[idx] if idx >= 0 else 'n/a'} "
            f"output={scale[idx] if idx >= 0 else 'n/a'}"
        )
        sys.exit(2)

    if golden_out.shape != out.shape:
        print(
            f"[ERROR] fp8 compare failed: size mismatch "
            f"golden={golden_out.size} output={out.size}"
        )
        sys.exit(2)

    if not np.array_equal(golden_out, out):
        diff = np.nonzero(golden_out != out)[0]
        idx = int(diff[0]) if diff.size else -1
        print(
            f"[ERROR] fp8 compare failed idx={idx} "
            f"golden={int(golden_out[idx]) if idx >= 0 else 'n/a'} "
            f"output={int(out[idx]) if idx >= 0 else 'n/a'}"
        )
        sys.exit(2)

    print("[INFO] compare passed")


if __name__ == "__main__":
    main()
import argparse
from pathlib import Path

import numpy as np

# Case geometry: ROWS rows of COLS bf16 values, one scale per row, written
# into a SCALE_SLOTS-long float32 buffer.
ROWS = 4
COLS = 128
SCALE_SLOTS = 128
FP8_MAX = np.float32(448.0)
SCALES = np.array([0.25, 0.5, 1.0, 2.0], dtype=np.float32)
# Fill values for buffer regions the kernel is expected to leave untouched.
SENTINEL_F32 = np.float32(-777.0)
SENTINEL_U8 = np.uint8(0xA5)

# Quantized values used to build the input, and the byte each one is expected
# to produce in the fp8 output (same order).
Q_VALUES = np.array(
    [0.0, 1.0, -1.0, 0.5, 2.0, -2.0, 4.0, -4.0, 448.0], dtype=np.float32
)
F8E4M3FN_BYTES = np.array(
    [0x00, 0x38, 0xB8, 0x30, 0x40, 0xC0, 0x48, 0xC8, 0x7E], dtype=np.uint8
)


def f32_to_bf16_bits(values: np.ndarray) -> np.ndarray:
    """Return the bf16 bit patterns (uint16) of float32 ``values``.

    Adds 0x7FFF plus the lowest kept bit before dropping the low 16 bits.
    """
    raw = values.astype(np.float32).view(np.uint32)
    kept_lsb = (raw >> 16) & 1
    biased = raw + np.uint32(0x7FFF) + kept_lsb
    return (biased >> 16).astype(np.uint16)


def bf16_bits_to_f32(values: np.ndarray) -> np.ndarray:
    """Widen bf16 bit patterns to float32 by placing them in the high 16 bits."""
    widened = values.astype(np.uint32)
    return (widened << 16).view(np.float32)


def generate(output_dir: Path) -> None:
    """Write the input, pre-filled output and golden files into ``output_dir``."""
    tiles = -(-COLS // len(Q_VALUES))  # ceil division
    q_row = np.tile(Q_VALUES, tiles)[:COLS].astype(np.float32)
    f8_row = np.tile(F8E4M3FN_BYTES, tiles)[:COLS].astype(np.uint8)

    # Row r of the input is the quantized pattern multiplied by SCALES[r].
    src = f32_to_bf16_bits(SCALES[:ROWS, None] * q_row[None, :])
    row_amax = np.abs(bf16_bits_to_f32(src)).max(axis=1)

    golden_scale = np.full(SCALE_SLOTS, SENTINEL_F32, dtype=np.float32)
    golden_scale[:ROWS] = row_amax / FP8_MAX
    golden_out = np.tile(f8_row, (ROWS, 1))

    # Output buffers start as sentinels; golden_scale keeps the sentinel in
    # every slot past ROWS.
    scale = np.full(SCALE_SLOTS, SENTINEL_F32, dtype=np.float32)
    out = np.full((ROWS, COLS), SENTINEL_U8, dtype=np.uint8)

    output_dir.mkdir(parents=True, exist_ok=True)
    for file_name, payload in (
        ("v1.bin", src),
        ("v2.bin", scale),
        ("v3.bin", out),
        ("golden_v2.bin", golden_scale),
        ("golden_v3.bin", golden_out),
    ):
        payload.tofile(output_dir / file_name)


def main() -> None:
    """Parse ``--output-dir`` (default: current directory) and generate the files."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--output-dir", type=Path, default=Path("."))
    generate(cli.parse_args().output_dir)


if __name__ == "__main__":
    main()
// VMI block-quant test kernel: loads bf16 input, derives one f32 scale per
// group (group max of |x| divided by %fp8_max), stores the scales, and stores
// x / scale truncated to f8E4M3FN.
// NOTE(review): in this copy `!pto.ptr`, `#pto.kernel_kind` and `#pto.pipe`
// carry no `<...>` parameters; they look stripped by extraction — restore them
// from the original file before using this text.
module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} {
  func.func @vmi_block_quant_bf16_fp8_4x128_kernel(%src_gm: !pto.ptr,
                                                   %scale_gm: !pto.ptr,
                                                   %out_gm: !pto.ptr) attributes {pto.kernel} {
    // Index constants: element offsets (0, 256), scale-slot offsets (0, 2) and
    // the group_store operand %c1.
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c2 = arith.constant 2 : index
    %c256 = arith.constant 256 : index
    // %c512 is not referenced anywhere below.
    %c512 = arith.constant 512 : index
    // i64 constants: UB addresses (0, 4096, 8192) and transfer sizes
    // (presumably byte counts — confirm against the pto.mte_* op definition).
    %c0_i64 = arith.constant 0 : i64
    %c1_i64 = arith.constant 1 : i64
    %c512_i64 = arith.constant 512 : i64
    %c1024_i64 = arith.constant 1024 : i64
    %c4096_i64 = arith.constant 4096 : i64
    %c8192_i64 = arith.constant 8192 : i64
    %fp8_max = arith.constant 4.480000e+02 : f32

    // Pointers built from fixed integer addresses. %ub_out_u8 and %ub_out_f8
    // are cast from the same address (8192), so they alias the same buffer.
    %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr
    %ub_scale = pto.castptr %c4096_i64 : i64 -> !pto.ptr
    %ub_out_u8 = pto.castptr %c8192_i64 : i64 -> !pto.ptr
    %ub_out_f8 = pto.castptr %c8192_i64 : i64 -> !pto.ptr

    // Copy all three GM buffers in; scale and out are copied in as well, so
    // slots the kernel does not write are copied back out unchanged.
    pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c1024_i64
      nburst(%c1_i64, %c1024_i64, %c1024_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
    pto.mte_gm_ub %scale_gm, %ub_scale, %c0_i64, %c512_i64
      nburst(%c1_i64, %c512_i64, %c512_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
    pto.mte_gm_ub %out_gm, %ub_out_u8, %c0_i64, %c512_i64
      nburst(%c1_i64, %c512_i64, %c512_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64

    // Order the vector work after the MTE2 copies above.
    pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]
    pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]

    // First 256 elements: two groups, scales written at slot offset 0.
    pto.vecscope {
      %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred>
      %x_bf16 = pto.vmi.load %ub_src[%c0]
        : !pto.ptr -> !pto.vmi.vreg<256xbf16>
      %x = pto.vmi.extf %x_bf16
        : !pto.vmi.vreg<256xbf16> -> !pto.vmi.vreg<256xf32>
      %abs = pto.vmi.absf %x
        : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32>
      %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 2}
        : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred>
          -> !pto.vmi.vreg<2xf32>
      %fp8_max_v = pto.vmi.broadcast %fp8_max
        : f32 -> !pto.vmi.vreg<2xf32>
      %scale = pto.vmi.divf %amax, %fp8_max_v
        : !pto.vmi.vreg<2xf32>, !pto.vmi.vreg<2xf32>
          -> !pto.vmi.vreg<2xf32>
      // NOTE(review): the trailing %c1 operand is presumably the slot stride —
      // confirm against the pto.vmi.group_store definition.
      pto.vmi.group_store %scale, %ub_scale[%c0], %c1 {num_groups = 2}
        : !pto.vmi.vreg<2xf32>, !pto.ptr
      %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 2}
        : !pto.vmi.vreg<2xf32> -> !pto.vmi.vreg<256xf32>
      %q = pto.vmi.divf %x, %scale_vec
        : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32>
          -> !pto.vmi.vreg<256xf32>
      %q8 = pto.vmi.truncf %q
        : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E4M3FN>
      pto.vmi.store %q8, %ub_out_f8[%c0]
        : !pto.vmi.vreg<256xf8E4M3FN>, !pto.ptr
    }

    // Second 256 elements: same computation at element offset 256, scales
    // written at slot offset 2.
    pto.vecscope {
      %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred>
      %x_bf16 = pto.vmi.load %ub_src[%c256]
        : !pto.ptr -> !pto.vmi.vreg<256xbf16>
      %x = pto.vmi.extf %x_bf16
        : !pto.vmi.vreg<256xbf16> -> !pto.vmi.vreg<256xf32>
      %abs = pto.vmi.absf %x
        : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32>
      %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 2}
        : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred>
          -> !pto.vmi.vreg<2xf32>
      %fp8_max_v = pto.vmi.broadcast %fp8_max
        : f32 -> !pto.vmi.vreg<2xf32>
      %scale = pto.vmi.divf %amax, %fp8_max_v
        : !pto.vmi.vreg<2xf32>, !pto.vmi.vreg<2xf32>
          -> !pto.vmi.vreg<2xf32>
      pto.vmi.group_store %scale, %ub_scale[%c2], %c1 {num_groups = 2}
        : !pto.vmi.vreg<2xf32>, !pto.ptr
      %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 2}
        : !pto.vmi.vreg<2xf32> -> !pto.vmi.vreg<256xf32>
      %q = pto.vmi.divf %x, %scale_vec
        : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32>
          -> !pto.vmi.vreg<256xf32>
      %q8 = pto.vmi.truncf %q
        : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E4M3FN>
      pto.vmi.store %q8, %ub_out_f8[%c256]
        : !pto.vmi.vreg<256xf8E4M3FN>, !pto.ptr
    }

    // Order the MTE3 copy-out after the vector work, then write both result
    // buffers back to GM.
    pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
    pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
    pto.mte_ub_gm %ub_scale, %scale_gm, %c512_i64
      nburst(%c1_i64, %c512_i64, %c512_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64
    pto.mte_ub_gm %ub_out_u8, %out_gm, %c512_i64
      nburst(%c1_i64, %c512_i64, %c512_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64
    pto.barrier #pto.pipe
    return
  }
}
// Host-side launch shim for the bf16 -> fp8 4x128 block-quant kernel.

// Make sure __VEC_SCOPE__ is defined (as empty) before anything below uses it.
#ifndef __VEC_SCOPE__
#define __VEC_SCOPE__
#endif
// When compiling for the AI core with __NPU_ARCH__ == 2201, declare the fp8 /
// fp4 type names as one-byte structs.
#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201)
typedef struct { unsigned char v; } hifloat8_t;
typedef struct { unsigned char v; } float8_e4m3_t;
typedef struct { unsigned char v; } float8_e5m2_t;
typedef struct { unsigned char v; } float8_e8m0_t;
typedef struct { unsigned char v; } float4_e1m2x2_t;
typedef struct { unsigned char v; } float4_e2m1x2_t;
#endif
// NOTE(review): the header name of the next directive is missing in this copy
// (angle-bracket text was stripped); restore it from the original file.
#include
// Host builds that have not seen TMRGSORT_HPP get a local definition of
// MrgSortExecutedNumList.
#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP)
struct MrgSortExecutedNumList {
  uint16_t mrgSortList0;
  uint16_t mrgSortList1;
  uint16_t mrgSortList2;
  uint16_t mrgSortList3;
};
#endif
// The ACL header is skipped when __CPU_SIM is defined.
#ifndef __CPU_SIM
#include "acl/acl.h"
#endif

// Device kernel entry point; defined elsewhere.
extern "C" __global__ [aicore] void
vmi_block_quant_bf16_fp8_4x128_kernel(__gm__ bfloat16_t *src,
                                      __gm__ float *scale,
                                      __gm__ uint8_t *out);

// Launches the kernel with launch configuration <<<1, nullptr, stream>>>,
// casting the host-typed pointers to the kernel's __gm__ parameter types.
// `src` is passed as uint16_t* and reinterpreted as bfloat16_t.
void LaunchVmi_block_quant_bf16_fp8_4x128_kernel(uint16_t *src, float *scale,
                                                 uint8_t *out, void *stream) {
  vmi_block_quant_bf16_fp8_4x128_kernel<<<1, nullptr, stream>>>(
      (__gm__ bfloat16_t *)src, (__gm__ float *)scale, (__gm__ uint8_t *)out);
}
+// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_block_quant_bf16_fp8_4x128_kernel(uint16_t *src, float *scale, + uint8_t *out, void *stream); + +int main() { + constexpr size_t kRows = 4; + constexpr size_t kCols = 128; + constexpr size_t kElems = kRows * kCols; + constexpr size_t kScaleElems = 128; + size_t srcBytes = kElems * sizeof(uint16_t); + size_t scaleBytes = kScaleElems * sizeof(float); + size_t outBytes = kElems * sizeof(uint8_t); + uint16_t *srcHost = nullptr; + float *scaleHost = nullptr; + uint8_t *outHost = nullptr; + uint16_t *srcDevice = nullptr; + float *scaleDevice = nullptr; + uint8_t *outDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&scaleHost), scaleBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&outHost), outBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&scaleDevice, scaleBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outDevice, outBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", scaleBytes, scaleHost, scaleBytes); + 
ReadFile("./v3.bin", outBytes, outHost, outBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(scaleDevice, scaleBytes, scaleHost, scaleBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outDevice, outBytes, outHost, outBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_block_quant_bf16_fp8_4x128_kernel(srcDevice, scaleDevice, + outDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(scaleHost, scaleBytes, scaleDevice, scaleBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outHost, outBytes, outDevice, outBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", scaleHost, scaleBytes); + WriteFile("./v3.bin", outHost, outBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(scaleDevice); + aclrtFree(outDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(scaleHost); + aclrtFreeHost(outHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128/ptoas.flags b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-16x256/compare.py b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-16x256/compare.py new file mode 100644 index 0000000000..c40bdb3270 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-16x256/compare.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). 
import sys

import numpy as np


def main() -> None:
    """Compare the kernel outputs in the current directory with the goldens.

    Reads ``golden_v2.bin`` / ``v2.bin`` as float32 scales and
    ``golden_v3.bin`` / ``v3.bin`` as raw fp8 bytes. Both comparisons are
    exact. On the first failure an ``[ERROR]`` line is printed and the process
    exits with status 2; otherwise ``[INFO] compare passed`` is printed.
    """
    golden_scale = np.fromfile("golden_v2.bin", dtype=np.float32)
    scale = np.fromfile("v2.bin", dtype=np.float32)
    golden_out = np.fromfile("golden_v3.bin", dtype=np.uint8)
    out = np.fromfile("v3.bin", dtype=np.uint8)

    # A size mismatch must be reported on its own: an elementwise `!=` between
    # arrays of different length cannot produce a per-element diff.
    if golden_scale.shape != scale.shape:
        print(
            f"[ERROR] scale compare failed: size mismatch "
            f"golden={golden_scale.size} output={scale.size}"
        )
        sys.exit(2)

    # rtol=0 / atol=0 makes allclose an exact comparison.
    if not np.allclose(golden_scale, scale, rtol=0, atol=0):
        diff = np.nonzero(golden_scale != scale)[0]
        idx = int(diff[0]) if diff.size else -1
        print(
            f"[ERROR] scale compare failed idx={idx} "
            f"golden={golden_scale[idx] if idx >= 0 else 'n/a'} "
            f"output={scale[idx] if idx >= 0 else 'n/a'}"
        )
        sys.exit(2)

    if golden_out.shape != out.shape:
        print(
            f"[ERROR] fp8 compare failed: size mismatch "
            f"golden={golden_out.size} output={out.size}"
        )
        sys.exit(2)

    if not np.array_equal(golden_out, out):
        diff = np.nonzero(golden_out != out)[0]
        idx = int(diff[0]) if diff.size else -1
        print(
            f"[ERROR] fp8 compare failed idx={idx} "
            f"golden={int(golden_out[idx]) if idx >= 0 else 'n/a'} "
            f"output={int(out[idx]) if idx >= 0 else 'n/a'}"
        )
        sys.exit(2)

    print("[INFO] compare passed")


if __name__ == "__main__":
    main()
import argparse
from pathlib import Path

import numpy as np

# Case geometry: ROWS x COLS f16 values quantized in blocks of BLOCK_COLS
# columns, one scale per block, written into a SCALE_SLOTS-long float32 buffer.
ROWS = 16
COLS = 256
BLOCK_COLS = 128
SCALE_SLOTS = 128
FP8_MAX = np.float32(448.0)
SCALES = np.tile(np.array([0.25, 0.5, 1.0, 2.0], dtype=np.float32), 8)
# Fill values for buffer regions the kernel is expected to leave untouched.
SENTINEL_F32 = np.float32(-777.0)
SENTINEL_U8 = np.uint8(0xA5)

# Quantized values used to build the input, and the byte each one is expected
# to produce in the fp8 output (same order).
Q_VALUES = np.array(
    [0.0, 1.0, -1.0, 0.5, 2.0, -2.0, 4.0, -4.0, 448.0], dtype=np.float32
)
F8E4M3FN_BYTES = np.array(
    [0x00, 0x38, 0xB8, 0x30, 0x40, 0xC0, 0x48, 0xC8, 0x7E], dtype=np.uint8
)


def generate(output_dir: Path) -> None:
    """Write the input, pre-filled output and golden files into ``output_dir``."""
    tiles = -(-BLOCK_COLS // len(Q_VALUES))  # ceil division
    q_block = np.tile(Q_VALUES, tiles)[:BLOCK_COLS].astype(np.float32)
    f8_block = np.tile(F8E4M3FN_BYTES, tiles)[:BLOCK_COLS].astype(np.uint8)

    # Groups are numbered row-major (row * blocks_per_row + block), which is
    # exactly the row order of a (num_groups, BLOCK_COLS) view of the matrix.
    blocks_per_row = COLS // BLOCK_COLS
    num_groups = ROWS * blocks_per_row
    grouped_src = (SCALES[:num_groups, None] * q_block[None, :]).astype(np.float16)
    group_amax = np.abs(grouped_src.astype(np.float32)).max(axis=1)

    src = grouped_src.reshape(ROWS, COLS)
    golden_scale = np.full(SCALE_SLOTS, SENTINEL_F32, dtype=np.float32)
    golden_scale[:num_groups] = group_amax / FP8_MAX
    golden_out = np.tile(f8_block, (num_groups, 1)).reshape(ROWS, COLS)

    # Output buffers start as sentinels; golden_scale keeps the sentinel in
    # every slot past num_groups.
    scale = np.full(SCALE_SLOTS, SENTINEL_F32, dtype=np.float32)
    out = np.full((ROWS, COLS), SENTINEL_U8, dtype=np.uint8)

    output_dir.mkdir(parents=True, exist_ok=True)
    for file_name, payload in (
        ("v1.bin", src.view(np.uint16)),
        ("v2.bin", scale),
        ("v3.bin", out),
        ("golden_v2.bin", golden_scale),
        ("golden_v3.bin", golden_out),
    ):
        payload.tofile(output_dir / file_name)


def main() -> None:
    """Parse ``--output-dir`` (default: current directory) and generate the files."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--output-dir", type=Path, default=Path("."))
    generate(cli.parse_args().output_dir)


if __name__ == "__main__":
    main()
// VMI block-quant test kernel: loads f16 input in 16 chunks of 256 elements,
// derives one f32 scale per group (group max of |x| divided by %fp8_max),
// stores the scales, and stores x / scale truncated to f8E4M3FN.
// NOTE(review): in this copy `!pto.ptr`, `#pto.kernel_kind` and `#pto.pipe`
// carry no `<...>` parameters; they look stripped by extraction — restore them
// from the original file before using this text.
module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} {
  func.func @vmi_block_quant_f16_fp8_16x256_kernel(%src_gm: !pto.ptr,
                                                   %scale_gm: !pto.ptr,
                                                   %out_gm: !pto.ptr) attributes {pto.kernel} {
    // Index constants: loop bounds (0, 16, step 1), per-chunk element count
    // (256) and per-chunk scale-slot count (2).
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c2 = arith.constant 2 : index
    %c16 = arith.constant 16 : index
    %c256 = arith.constant 256 : index
    // i64 constants: UB addresses (0, 16384, 32768) and transfer sizes
    // (presumably byte counts — confirm against the pto.mte_* op definition).
    %c0_i64 = arith.constant 0 : i64
    %c1_i64 = arith.constant 1 : i64
    %c512_i64 = arith.constant 512 : i64
    %c4096_i64 = arith.constant 4096 : i64
    %c8192_i64 = arith.constant 8192 : i64
    %c16384_i64 = arith.constant 16384 : i64
    %c32768_i64 = arith.constant 32768 : i64
    %fp8_max = arith.constant 4.480000e+02 : f32

    // Pointers built from fixed integer addresses. %ub_out_u8 and %ub_out_f8
    // are cast from the same address (32768), so they alias the same buffer.
    %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr
    %ub_scale = pto.castptr %c16384_i64 : i64 -> !pto.ptr
    %ub_out_u8 = pto.castptr %c32768_i64 : i64 -> !pto.ptr
    %ub_out_f8 = pto.castptr %c32768_i64 : i64 -> !pto.ptr

    // Copy all three GM buffers in; scale and out are copied in as well, so
    // slots the kernel does not write are copied back out unchanged.
    pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c8192_i64
      nburst(%c1_i64, %c8192_i64, %c8192_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
    pto.mte_gm_ub %scale_gm, %ub_scale, %c0_i64, %c512_i64
      nburst(%c1_i64, %c512_i64, %c512_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
    pto.mte_gm_ub %out_gm, %ub_out_u8, %c0_i64, %c4096_i64
      nburst(%c1_i64, %c4096_i64, %c4096_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64

    // Order the vector work after the MTE2 copies above.
    pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]
    pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]

    pto.vecscope {
      // One iteration per 256-element chunk; each chunk holds two groups.
      scf.for %chunk = %c0 to %c16 step %c1 {
        // Element offset of this chunk and the slot offset of its two scales.
        %elem_off = arith.muli %chunk, %c256 : index
        %scale_off = arith.muli %chunk, %c2 : index
        %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred>
        %x_f16 = pto.vmi.load %ub_src[%elem_off]
          : !pto.ptr -> !pto.vmi.vreg<256xf16>
        %x = pto.vmi.extf %x_f16
          : !pto.vmi.vreg<256xf16> -> !pto.vmi.vreg<256xf32>
        %abs = pto.vmi.absf %x
          : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32>
        %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 2}
          : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred>
            -> !pto.vmi.vreg<2xf32>
        %fp8_max_v = pto.vmi.broadcast %fp8_max
          : f32 -> !pto.vmi.vreg<2xf32>
        %scale = pto.vmi.divf %amax, %fp8_max_v
          : !pto.vmi.vreg<2xf32>, !pto.vmi.vreg<2xf32>
            -> !pto.vmi.vreg<2xf32>
        // NOTE(review): the trailing %c1 operand is presumably the slot
        // stride — confirm against the pto.vmi.group_store definition.
        pto.vmi.group_store %scale, %ub_scale[%scale_off], %c1 {num_groups = 2}
          : !pto.vmi.vreg<2xf32>, !pto.ptr
        %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 2}
          : !pto.vmi.vreg<2xf32> -> !pto.vmi.vreg<256xf32>
        %q = pto.vmi.divf %x, %scale_vec
          : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32>
            -> !pto.vmi.vreg<256xf32>
        %q8 = pto.vmi.truncf %q
          : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E4M3FN>
        pto.vmi.store %q8, %ub_out_f8[%elem_off]
          : !pto.vmi.vreg<256xf8E4M3FN>, !pto.ptr
      }
    }

    // Order the MTE3 copy-out after the vector work, then write both result
    // buffers back to GM.
    pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
    pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
    pto.mte_ub_gm %ub_scale, %scale_gm, %c512_i64
      nburst(%c1_i64, %c512_i64, %c512_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64
    pto.mte_ub_gm %ub_out_u8, %out_gm, %c4096_i64
      nburst(%c1_i64, %c4096_i64, %c4096_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64
    pto.barrier #pto.pipe
    return
  }
}
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+// See LICENSE in the root of the software repository for the full text of the License.
+
+#ifndef __VEC_SCOPE__
+#define __VEC_SCOPE__
+#endif
+#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201)
+typedef struct { unsigned char v; } hifloat8_t;
+typedef struct { unsigned char v; } float8_e4m3_t;
+typedef struct { unsigned char v; } float8_e5m2_t;
+typedef struct { unsigned char v; } float8_e8m0_t;
+typedef struct { unsigned char v; } float4_e1m2x2_t;
+typedef struct { unsigned char v; } float4_e2m1x2_t;
+#endif
+#include <stdint.h>  // uint8_t / uint16_t used below; NOTE(review): header name restored, confirm against the template
+#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP)
+struct MrgSortExecutedNumList {
+    uint16_t mrgSortList0;
+    uint16_t mrgSortList1;
+    uint16_t mrgSortList2;
+    uint16_t mrgSortList3;
+};
+#endif
+#ifndef __CPU_SIM
+#include "acl/acl.h"
+#endif
+
+extern "C" __global__ [aicore] void
+vmi_block_quant_f16_fp8_16x256_kernel(__gm__ half *src, __gm__ float *scale,
+                                      __gm__ uint8_t *out);
+// Host-side wrapper: casts the buffers to __gm__ pointers and launches the kernel on `stream`.
+void LaunchVmi_block_quant_f16_fp8_16x256_kernel(uint16_t *src, float *scale,
+                                                 uint8_t *out, void *stream) {
+  vmi_block_quant_f16_fp8_16x256_kernel<<<1, nullptr, stream>>>(
+      (__gm__ half *)src, (__gm__ float *)scale, (__gm__ uint8_t *)out);
+}
diff --git a/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-16x256/main.cpp b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-16x256/main.cpp
new file mode 100644
index 0000000000..677cf3c033
--- /dev/null
+++ b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-16x256/main.cpp
@@ -0,0 +1,93 @@
+// Copyright (c) 2026 Huawei Technologies Co., Ltd.
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+// CANN Open Software License Agreement Version 2.0 (the "License").
+// Please refer to the License for details.
You may not use this file except in compliance with the License.
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+// See LICENSE in the root of the software repository for the full text of the License.
+
+#include "acl/acl.h"
+#include "test_common.h"
+#include <cstdint>
+#include <cstdio>
+#include <cstdlib>
+
+using namespace PtoTestCommon;
+
+// On failure: log the expression, set rc = 1 and jump to cleanup (both must be in scope).
+#define ACL_CHECK(expr) \
+  do { \
+    const aclError _ret = (expr); \
+    if (_ret != ACL_SUCCESS) { \
+      std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \
+                   (int)_ret, __FILE__, __LINE__); \
+      rc = 1; \
+      goto cleanup; \
+    } \
+  } while (0)
+
+void LaunchVmi_block_quant_f16_fp8_16x256_kernel(uint16_t *src, float *scale,
+                                                 uint8_t *out, void *stream);
+
+// Stages v1/v2/v3.bin on the device, runs the kernel once, then writes the
+// scale (v2.bin) and fp8 (v3.bin) results back. Returns 1 on any ACL failure.
+int main() {
+  constexpr size_t kRows = 16;
+  constexpr size_t kCols = 256;
+  constexpr size_t kElems = kRows * kCols;
+  constexpr size_t kScaleElems = 128;
+  size_t srcBytes = kElems * sizeof(uint16_t);
+  size_t scaleBytes = kScaleElems * sizeof(float);
+  size_t outBytes = kElems * sizeof(uint8_t);
+  uint16_t *srcHost = nullptr;
+  float *scaleHost = nullptr;
+  uint8_t *outHost = nullptr;
+  uint16_t *srcDevice = nullptr;
+  float *scaleDevice = nullptr;
+  uint8_t *outDevice = nullptr;
+  int rc = 0;
+  bool aclInited = false;
+  bool deviceSet = false;
+  int deviceId = 0;
+  aclrtStream stream = nullptr;
+
+  ACL_CHECK(aclInit(nullptr));
+  aclInited = true;
+  if (const char *envDevice = std::getenv("ACL_DEVICE_ID"))
+    deviceId = std::atoi(envDevice);
+  ACL_CHECK(aclrtSetDevice(deviceId));
+  deviceSet = true;
+  ACL_CHECK(aclrtCreateStream(&stream));
+  ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes));
+  ACL_CHECK(aclrtMallocHost((void **)(&scaleHost), scaleBytes));
+  ACL_CHECK(aclrtMallocHost((void **)(&outHost), outBytes));
+  ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST));
+  ACL_CHECK(aclrtMalloc((void **)&scaleDevice, scaleBytes, ACL_MEM_MALLOC_HUGE_FIRST));
+  ACL_CHECK(aclrtMalloc((void **)&outDevice, outBytes, ACL_MEM_MALLOC_HUGE_FIRST));
+
+  // NOTE(review): ReadFile/WriteFile results are ignored; confirm how test_common.h reports I/O failures.
+  ReadFile("./v1.bin", srcBytes, srcHost, srcBytes);
+  ReadFile("./v2.bin", scaleBytes, scaleHost, scaleBytes);
+  ReadFile("./v3.bin", outBytes, outHost, outBytes);
+  ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE));
+  ACL_CHECK(aclrtMemcpy(scaleDevice, scaleBytes, scaleHost, scaleBytes, ACL_MEMCPY_HOST_TO_DEVICE));
+  ACL_CHECK(aclrtMemcpy(outDevice, outBytes, outHost, outBytes, ACL_MEMCPY_HOST_TO_DEVICE));
+  LaunchVmi_block_quant_f16_fp8_16x256_kernel(srcDevice, scaleDevice, outDevice, stream);
+  ACL_CHECK(aclrtSynchronizeStream(stream));
+  ACL_CHECK(aclrtMemcpy(scaleHost, scaleBytes, scaleDevice, scaleBytes, ACL_MEMCPY_DEVICE_TO_HOST));
+  ACL_CHECK(aclrtMemcpy(outHost, outBytes, outDevice, outBytes, ACL_MEMCPY_DEVICE_TO_HOST));
+  WriteFile("./v2.bin", scaleHost, scaleBytes);
+  WriteFile("./v3.bin", outHost, outBytes);
+
+cleanup:  // Free only what was actually acquired; an early failure leaves the rest null.
+  if (srcDevice) aclrtFree(srcDevice);
+  if (scaleDevice) aclrtFree(scaleDevice);
+  if (outDevice) aclrtFree(outDevice);
+  if (srcHost) aclrtFreeHost(srcHost);
+  if (scaleHost) aclrtFreeHost(scaleHost);
+  if (outHost) aclrtFreeHost(outHost);
+  if (stream) aclrtDestroyStream(stream);
+  if (deviceSet) aclrtResetDevice(deviceId);
+  if (aclInited) aclFinalize();
+  return rc;
+}
diff --git a/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-16x256/ptoas.flags b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-16x256/ptoas.flags
new file mode 100644
index 0000000000..a79aede1ca
--- /dev/null
+++ b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-16x256/ptoas.flags
@@ -0,0 +1 @@
+--pto-arch a5 --pto-backend=vpto --enable-vmi
diff --git a/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-4x256/compare.py b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-4x256/compare.py
new file mode 100644
index 0000000000..c40bdb3270
--- /dev/null
+++ b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-4x256/compare.py
@@ -0,0 +1,47 @@
+#!/usr/bin/env python3
+# Copyright (c) 2026 Huawei Technologies Co., Ltd.
+# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+# CANN Open Software License Agreement Version 2.0 (the "License").
+# Please refer to the License for details. You may not use this file except in compliance with the License.
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+# See LICENSE in the root of the software repository for the full text of the License.
+
+import sys
+
+import numpy as np
+
+
+def main() -> None:
+    """Compare v2.bin (f32 scales) and v3.bin (fp8 bytes) with their goldens.
+
+    Prints the first mismatch and exits with status 2; sizes are checked first.
+    """
+    cases = [
+        ("scale", "v2.bin", np.float32, str),
+        ("fp8", "v3.bin", np.uint8, int),
+    ]
+    loaded = [
+        (tag, np.fromfile(f"golden_{name}", dtype=dt), np.fromfile(name, dtype=dt), show)
+        for tag, name, dt, show in cases
+    ]
+    for tag, golden, actual, show in loaded:
+        if golden.shape != actual.shape:
+            # `golden != actual` cannot be evaluated elementwise for different sizes.
+            sizes = f"(size golden={golden.size} output={actual.size})"
+            print(f"[ERROR] {tag} compare failed idx=-1 golden=n/a output=n/a {sizes}")
+            sys.exit(2)
+        diff = np.nonzero(golden != actual)[0]
+        if diff.size:
+            idx = int(diff[0])
+            print(
+                f"[ERROR] {tag} compare failed idx={idx} "
+                f"golden={show(golden[idx])} output={show(actual[idx])}"
+            )
+            sys.exit(2)
+
+    print("[INFO] compare passed")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-4x256/golden.py b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-4x256/golden.py
new file mode 100644
index 0000000000..14e3870858
--- /dev/null
+++
b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-4x256/golden.py
@@ -0,0 +1,69 @@
+#!/usr/bin/env python3
+# Copyright (c) 2026 Huawei Technologies Co., Ltd.
+# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+# CANN Open Software License Agreement Version 2.0 (the "License").
+# Please refer to the License for details. You may not use this file except in compliance with the License.
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+# See LICENSE in the root of the software repository for the full text of the License.
+
+import argparse
+from pathlib import Path
+
+import numpy as np
+
+ROWS = 4
+COLS = 256
+BLOCK_COLS = 128
+SCALE_SLOTS = 128
+FP8_MAX = np.float32(448.0)
+SCALES = np.tile(np.array([0.25, 0.5, 1.0, 2.0], dtype=np.float32), 2)
+SENTINEL_F32 = np.float32(-777.0)
+SENTINEL_U8 = np.uint8(0xA5)
+
+Q_VALUES = np.array(
+    [0.0, 1.0, -1.0, 0.5, 2.0, -2.0, 4.0, -4.0, 448.0], dtype=np.float32
+)
+F8E4M3FN_BYTES = np.array(
+    [0x00, 0x38, 0xB8, 0x30, 0x40, 0xC0, 0x48, 0xC8, 0x7E], dtype=np.uint8
+)
+
+
+def generate(output_dir: Path) -> None:
+    """Write kernel inputs, sentinel-filled outputs and goldens into output_dir."""
+    blocks_per_row = COLS // BLOCK_COLS
+    tiles = (BLOCK_COLS + len(Q_VALUES) - 1) // len(Q_VALUES)
+    q_pattern = np.tile(Q_VALUES, tiles)[:BLOCK_COLS].astype(np.float32)
+    f8_pattern = np.tile(F8E4M3FN_BYTES, tiles)[:BLOCK_COLS].astype(np.uint8)
+
+    src = np.empty((ROWS, COLS), dtype=np.float16)
+    golden_out = np.empty((ROWS, COLS), dtype=np.uint8)
+    golden_scale = np.full(SCALE_SLOTS, SENTINEL_F32, dtype=np.float32)
+    for group in range(ROWS * blocks_per_row):
+        row, block = divmod(group, blocks_per_row)
+        cols = slice(block * BLOCK_COLS, (block + 1) * BLOCK_COLS)
+        src[row, cols] = (q_pattern * SCALES[group]).astype(np.float16)
+        golden_scale[group] = np.max(np.abs(src[row, cols].astype(np.float32))) / FP8_MAX
+        golden_out[row, cols] = f8_pattern
+
+    output_dir.mkdir(parents=True, exist_ok=True)
+    outputs = {
+        "v1.bin": src.view(np.uint16),
+        "v2.bin": np.full(SCALE_SLOTS, SENTINEL_F32, dtype=np.float32),
+        "v3.bin": np.full((ROWS, COLS), SENTINEL_U8, dtype=np.uint8),
+        "golden_v2.bin": golden_scale,
+        "golden_v3.bin": golden_out,
+    }
+    for name, data in outputs.items():
+        data.tofile(output_dir / name)
+
+
+def main() -> None:
+    """Parse --output-dir (default: current directory) and generate the files."""
+    cli = argparse.ArgumentParser()
+    cli.add_argument("--output-dir", type=Path, default=Path("."))
+    generate(cli.parse_args().output_dir)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-4x256/kernel.pto b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-4x256/kernel.pto
new file mode 100644
index 0000000000..8362807227
--- /dev/null
+++ b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-4x256/kernel.pto
@@ -0,0 +1,89 @@
+// Copyright (c) 2026 Huawei Technologies Co., Ltd.
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+// CANN Open Software License Agreement Version 2.0 (the "License").
+// Please refer to the License for details. You may not use this file except in compliance with the License.
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+// See LICENSE in the root of the software repository for the full text of the License.
+
+module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} {
+  func.func @vmi_block_quant_f16_fp8_4x256_kernel(%src_gm: !pto.ptr, // quantizes 4 chunks of 256 values, one scale per 128-element group
+                                                  %scale_gm: !pto.ptr,
+                                                  %out_gm: !pto.ptr) attributes {pto.kernel} {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c2 = arith.constant 2 : index
+    %c4 = arith.constant 4 : index
+    %c256 = arith.constant 256 : index
+    %c0_i64 = arith.constant 0 : i64
+    %c1_i64 = arith.constant 1 : i64
+    %c512_i64 = arith.constant 512 : i64
+    %c1024_i64 = arith.constant 1024 : i64
+    %c2048_i64 = arith.constant 2048 : i64
+    %c4096_i64 = arith.constant 4096 : i64
+    %c8192_i64 = arith.constant 8192 : i64
+    %fp8_max = arith.constant 4.480000e+02 : f32 // each group scale is max|x| / 448
+
+    %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr
+    %ub_scale = pto.castptr %c4096_i64 : i64 -> !pto.ptr
+    %ub_out_u8 = pto.castptr %c8192_i64 : i64 -> !pto.ptr
+    %ub_out_f8 = pto.castptr %c8192_i64 : i64 -> !pto.ptr // same address as %ub_out_u8; used by the f8 store below
+
+    pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c2048_i64 // 2048 = 4 * 256 * 2; presumably the f16 input size in bytes
+      nburst(%c1_i64, %c2048_i64, %c2048_i64)
+      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
+    pto.mte_gm_ub %scale_gm, %ub_scale, %c0_i64, %c512_i64 // 512 = 128 * 4; presumably the whole f32 scale buffer, so unwritten slots keep their incoming contents
+      nburst(%c1_i64, %c512_i64, %c512_i64)
+      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
+    pto.mte_gm_ub %out_gm, %ub_out_u8, %c0_i64, %c1024_i64 // 1024 = 4 * 256; presumably one byte per output element
+      nburst(%c1_i64, %c1024_i64, %c1024_i64)
+      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
+
+    pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]
+    pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]
+
+    pto.vecscope {
+      scf.for %chunk = %c0 to %c4 step %c1 { // one 256-element chunk per iteration
+        %elem_off = arith.muli %chunk, %c256 : index
+        %scale_off = arith.muli %chunk, %c2 : index // two scale slots per chunk
+        %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred>
+        %x_f16 = pto.vmi.load %ub_src[%elem_off]
+          : !pto.ptr -> !pto.vmi.vreg<256xf16>
+        %x = pto.vmi.extf %x_f16
+          : !pto.vmi.vreg<256xf16> -> !pto.vmi.vreg<256xf32>
+        %abs = pto.vmi.absf %x
+          : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32>
+        %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 2} // 256 values reduced to 2 per-group maxima
+          : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred>
+          -> !pto.vmi.vreg<2xf32>
+        %fp8_max_v = pto.vmi.broadcast %fp8_max
+          : f32 -> !pto.vmi.vreg<2xf32>
+        %scale = pto.vmi.divf %amax, %fp8_max_v
+          : !pto.vmi.vreg<2xf32>, !pto.vmi.vreg<2xf32>
+          -> !pto.vmi.vreg<2xf32>
+        pto.vmi.group_store %scale, %ub_scale[%scale_off], %c1 {num_groups = 2} // stores the 2 scales at %scale_off; %c1 is presumably the slot stride
+          : !pto.vmi.vreg<2xf32>, !pto.ptr
+        %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 2} // 2 scales expanded to 256 lanes, presumably one scale per 128-lane group
+          : !pto.vmi.vreg<2xf32> -> !pto.vmi.vreg<256xf32>
+        %q = pto.vmi.divf %x, %scale_vec
+          : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32>
+          -> !pto.vmi.vreg<256xf32>
+        %q8 = pto.vmi.truncf %q // f32 -> f8E4M3FN
+          : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E4M3FN>
+        pto.vmi.store %q8, %ub_out_f8[%elem_off]
+          : !pto.vmi.vreg<256xf8E4M3FN>, !pto.ptr
+      }
+    }
+
+    pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
+    pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
+    pto.mte_ub_gm %ub_scale, %scale_gm, %c512_i64 // write the scale buffer back
+      nburst(%c1_i64, %c512_i64, %c512_i64)
+      : !pto.ptr, !pto.ptr, i64, i64, i64, i64
+    pto.mte_ub_gm %ub_out_u8, %out_gm, %c1024_i64 // write the quantized bytes back
+      nburst(%c1_i64, %c1024_i64, %c1024_i64)
+      : !pto.ptr, !pto.ptr, i64, i64, i64, i64
+    pto.barrier #pto.pipe
+    return
+  }
+}
diff --git a/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-4x256/launch.cpp b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-4x256/launch.cpp
new file mode 100644
index 0000000000..123cfafdc5
--- /dev/null
+++ b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-4x256/launch.cpp
@@ -0,0 +1,41 @@
+// Copyright (c) 2026 Huawei Technologies Co., Ltd.
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+// CANN Open Software License Agreement Version 2.0 (the "License").
+// Please refer to the License for details. You may not use this file except in compliance with the License.
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+// See LICENSE in the root of the software repository for the full text of the License.
+
+#ifndef __VEC_SCOPE__
+#define __VEC_SCOPE__
+#endif
+#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201)
+typedef struct { unsigned char v; } hifloat8_t;
+typedef struct { unsigned char v; } float8_e4m3_t;
+typedef struct { unsigned char v; } float8_e5m2_t;
+typedef struct { unsigned char v; } float8_e8m0_t;
+typedef struct { unsigned char v; } float4_e1m2x2_t;
+typedef struct { unsigned char v; } float4_e2m1x2_t;
+#endif
+#include <stdint.h>  // uint8_t / uint16_t used below; NOTE(review): header name restored, confirm against the template
+#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP)
+struct MrgSortExecutedNumList {
+    uint16_t mrgSortList0;
+    uint16_t mrgSortList1;
+    uint16_t mrgSortList2;
+    uint16_t mrgSortList3;
+};
+#endif
+#ifndef __CPU_SIM
+#include "acl/acl.h"
+#endif
+
+extern "C" __global__ [aicore] void
+vmi_block_quant_f16_fp8_4x256_kernel(__gm__ half *src, __gm__ float *scale,
+                                     __gm__ uint8_t *out);
+// Host-side wrapper: casts the buffers to __gm__ pointers and launches the kernel on `stream`.
+void LaunchVmi_block_quant_f16_fp8_4x256_kernel(uint16_t *src, float *scale,
+                                                uint8_t *out, void *stream) {
+  vmi_block_quant_f16_fp8_4x256_kernel<<<1, nullptr, stream>>>(
+      (__gm__ half *)src, (__gm__ float *)scale, (__gm__ uint8_t *)out);
+}
diff --git a/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-4x256/main.cpp b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-4x256/main.cpp
new file mode 100644
index 0000000000..74c90073ae
--- /dev/null
+++ b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-4x256/main.cpp
@@ -0,0 +1,93 @@
+// Copyright (c) 2026 Huawei Technologies Co., Ltd.
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+// CANN Open Software License Agreement Version 2.0 (the "License").
+// Please refer to the License for details.
You may not use this file except in compliance with the License.
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+// See LICENSE in the root of the software repository for the full text of the License.
+
+#include "acl/acl.h"
+#include "test_common.h"
+#include <cstdint>
+#include <cstdio>
+#include <cstdlib>
+
+using namespace PtoTestCommon;
+
+// On failure: log the expression, set rc = 1 and jump to cleanup (both must be in scope).
+#define ACL_CHECK(expr) \
+  do { \
+    const aclError _ret = (expr); \
+    if (_ret != ACL_SUCCESS) { \
+      std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \
+                   (int)_ret, __FILE__, __LINE__); \
+      rc = 1; \
+      goto cleanup; \
+    } \
+  } while (0)
+
+void LaunchVmi_block_quant_f16_fp8_4x256_kernel(uint16_t *src, float *scale,
+                                                uint8_t *out, void *stream);
+
+// Stages v1/v2/v3.bin on the device, runs the kernel once, then writes the
+// scale (v2.bin) and fp8 (v3.bin) results back. Returns 1 on any ACL failure.
+int main() {
+  constexpr size_t kRows = 4;
+  constexpr size_t kCols = 256;
+  constexpr size_t kElems = kRows * kCols;
+  constexpr size_t kScaleElems = 128;
+  size_t srcBytes = kElems * sizeof(uint16_t);
+  size_t scaleBytes = kScaleElems * sizeof(float);
+  size_t outBytes = kElems * sizeof(uint8_t);
+  uint16_t *srcHost = nullptr;
+  float *scaleHost = nullptr;
+  uint8_t *outHost = nullptr;
+  uint16_t *srcDevice = nullptr;
+  float *scaleDevice = nullptr;
+  uint8_t *outDevice = nullptr;
+  int rc = 0;
+  bool aclInited = false;
+  bool deviceSet = false;
+  int deviceId = 0;
+  aclrtStream stream = nullptr;
+
+  ACL_CHECK(aclInit(nullptr));
+  aclInited = true;
+  if (const char *envDevice = std::getenv("ACL_DEVICE_ID"))
+    deviceId = std::atoi(envDevice);
+  ACL_CHECK(aclrtSetDevice(deviceId));
+  deviceSet = true;
+  ACL_CHECK(aclrtCreateStream(&stream));
+  ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes));
+  ACL_CHECK(aclrtMallocHost((void **)(&scaleHost), scaleBytes));
+  ACL_CHECK(aclrtMallocHost((void **)(&outHost), outBytes));
+  ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST));
+  ACL_CHECK(aclrtMalloc((void **)&scaleDevice, scaleBytes, ACL_MEM_MALLOC_HUGE_FIRST));
+  ACL_CHECK(aclrtMalloc((void **)&outDevice, outBytes, ACL_MEM_MALLOC_HUGE_FIRST));
+
+  // NOTE(review): ReadFile/WriteFile results are ignored; confirm how test_common.h reports I/O failures.
+  ReadFile("./v1.bin", srcBytes, srcHost, srcBytes);
+  ReadFile("./v2.bin", scaleBytes, scaleHost, scaleBytes);
+  ReadFile("./v3.bin", outBytes, outHost, outBytes);
+  ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE));
+  ACL_CHECK(aclrtMemcpy(scaleDevice, scaleBytes, scaleHost, scaleBytes, ACL_MEMCPY_HOST_TO_DEVICE));
+  ACL_CHECK(aclrtMemcpy(outDevice, outBytes, outHost, outBytes, ACL_MEMCPY_HOST_TO_DEVICE));
+  LaunchVmi_block_quant_f16_fp8_4x256_kernel(srcDevice, scaleDevice, outDevice, stream);
+  ACL_CHECK(aclrtSynchronizeStream(stream));
+  ACL_CHECK(aclrtMemcpy(scaleHost, scaleBytes, scaleDevice, scaleBytes, ACL_MEMCPY_DEVICE_TO_HOST));
+  ACL_CHECK(aclrtMemcpy(outHost, outBytes, outDevice, outBytes, ACL_MEMCPY_DEVICE_TO_HOST));
+  WriteFile("./v2.bin", scaleHost, scaleBytes);
+  WriteFile("./v3.bin", outHost, outBytes);
+
+cleanup:  // Free only what was actually acquired; an early failure leaves the rest null.
+  if (srcDevice) aclrtFree(srcDevice);
+  if (scaleDevice) aclrtFree(scaleDevice);
+  if (outDevice) aclrtFree(outDevice);
+  if (srcHost) aclrtFreeHost(srcHost);
+  if (scaleHost) aclrtFreeHost(scaleHost);
+  if (outHost) aclrtFreeHost(outHost);
+  if (stream) aclrtDestroyStream(stream);
+  if (deviceSet) aclrtResetDevice(deviceId);
+  if (aclInited) aclFinalize();
+  return rc;
+}
diff --git a/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-4x256/ptoas.flags b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-4x256/ptoas.flags
new file mode 100644
index 0000000000..a79aede1ca
--- /dev/null
+++ b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-4x256/ptoas.flags
@@ -0,0 +1 @@
+--pto-arch a5 --pto-backend=vpto --enable-vmi
diff --git a/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-8x128/compare.py b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-8x128/compare.py
new file mode 100644
index 0000000000..c40bdb3270
--- /dev/null
+++ b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-8x128/compare.py
@@ -0,0 +1,47 @@
+#!/usr/bin/env python3
+# Copyright (c) 2026 Huawei Technologies Co., Ltd.
+# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+# CANN Open Software License Agreement Version 2.0 (the "License").
+# Please refer to the License for details. You may not use this file except in compliance with the License.
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+# See LICENSE in the root of the software repository for the full text of the License.
+
+import sys
+
+import numpy as np
+
+
+def main() -> None:
+    """Compare v2.bin (f32 scales) and v3.bin (fp8 bytes) with their goldens.
+
+    Prints the first mismatch and exits with status 2; sizes are checked first.
+    """
+    cases = [
+        ("scale", "v2.bin", np.float32, str),
+        ("fp8", "v3.bin", np.uint8, int),
+    ]
+    loaded = [
+        (tag, np.fromfile(f"golden_{name}", dtype=dt), np.fromfile(name, dtype=dt), show)
+        for tag, name, dt, show in cases
+    ]
+    for tag, golden, actual, show in loaded:
+        if golden.shape != actual.shape:
+            # `golden != actual` cannot be evaluated elementwise for different sizes.
+            sizes = f"(size golden={golden.size} output={actual.size})"
+            print(f"[ERROR] {tag} compare failed idx=-1 golden=n/a output=n/a {sizes}")
+            sys.exit(2)
+        diff = np.nonzero(golden != actual)[0]
+        if diff.size:
+            idx = int(diff[0])
+            print(
+                f"[ERROR] {tag} compare failed idx={idx} "
+                f"golden={show(golden[idx])} output={show(actual[idx])}"
+            )
+            sys.exit(2)
+
+    print("[INFO] compare passed")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-8x128/golden.py b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-8x128/golden.py
new file mode 100644
index 0000000000..6345d8a069
--- /dev/null
+++
b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-8x128/golden.py
@@ -0,0 +1,64 @@
+#!/usr/bin/env python3
+# Copyright (c) 2026 Huawei Technologies Co., Ltd.
+# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+# CANN Open Software License Agreement Version 2.0 (the "License").
+# Please refer to the License for details. You may not use this file except in compliance with the License.
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+# See LICENSE in the root of the software repository for the full text of the License.
+
+import argparse
+from pathlib import Path
+
+import numpy as np
+
+ROWS = 8
+COLS = 128
+SCALE_SLOTS = 128
+FP8_MAX = np.float32(448.0)
+SCALES = np.tile(np.array([0.25, 0.5, 1.0, 2.0], dtype=np.float32), 2)
+SENTINEL_F32 = np.float32(-777.0)
+SENTINEL_U8 = np.uint8(0xA5)
+
+Q_VALUES = np.array(
+    [0.0, 1.0, -1.0, 0.5, 2.0, -2.0, 4.0, -4.0, 448.0], dtype=np.float32
+)
+F8E4M3FN_BYTES = np.array(
+    [0x00, 0x38, 0xB8, 0x30, 0x40, 0xC0, 0x48, 0xC8, 0x7E], dtype=np.uint8
+)
+
+
+def generate(output_dir: Path) -> None:
+    """Write kernel inputs, sentinel-filled outputs and goldens into output_dir."""
+    tiles = (COLS + len(Q_VALUES) - 1) // len(Q_VALUES)
+    q_pattern = np.tile(Q_VALUES, tiles)[:COLS].astype(np.float32)
+    f8_pattern = np.tile(F8E4M3FN_BYTES, tiles)[:COLS].astype(np.uint8)
+
+    src = np.empty((ROWS, COLS), dtype=np.float16)
+    golden_out = np.tile(f8_pattern, (ROWS, 1))
+    golden_scale = np.full(SCALE_SLOTS, SENTINEL_F32, dtype=np.float32)
+    for row in range(ROWS):
+        src[row] = (q_pattern * SCALES[row]).astype(np.float16)
+        golden_scale[row] = np.max(np.abs(src[row].astype(np.float32))) / FP8_MAX
+
+    output_dir.mkdir(parents=True, exist_ok=True)
+    outputs = {
+        "v1.bin": src.view(np.uint16),
+        "v2.bin": np.full(SCALE_SLOTS, SENTINEL_F32, dtype=np.float32),
+        "v3.bin": np.full((ROWS, COLS), SENTINEL_U8, dtype=np.uint8),
+        "golden_v2.bin": golden_scale,
+        "golden_v3.bin": golden_out,
+    }
+    for name, data in outputs.items():
+        data.tofile(output_dir / name)
+
+
+def main() -> None:
+    """Parse --output-dir (default: current directory) and generate the files."""
+    cli = argparse.ArgumentParser()
+    cli.add_argument("--output-dir", type=Path, default=Path("."))
+    generate(cli.parse_args().output_dir)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-8x128/kernel.pto b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-8x128/kernel.pto
new file mode 100644
index 0000000000..ad3fa32760
--- /dev/null
+++ b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-8x128/kernel.pto
@@ -0,0 +1,89 @@
+// Copyright (c) 2026 Huawei Technologies Co., Ltd.
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+// CANN Open Software License Agreement Version 2.0 (the "License").
+// Please refer to the License for details. You may not use this file except in compliance with the License.
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+// See LICENSE in the root of the software repository for the full text of the License.
+
+module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} {
+  func.func @vmi_block_quant_f16_fp8_8x128_kernel(%src_gm: !pto.ptr, // quantizes 4 chunks of 256 values, one scale per 128-element group
+                                                  %scale_gm: !pto.ptr,
+                                                  %out_gm: !pto.ptr) attributes {pto.kernel} {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c2 = arith.constant 2 : index
+    %c4 = arith.constant 4 : index
+    %c256 = arith.constant 256 : index
+    %c0_i64 = arith.constant 0 : i64
+    %c1_i64 = arith.constant 1 : i64
+    %c512_i64 = arith.constant 512 : i64
+    %c1024_i64 = arith.constant 1024 : i64
+    %c2048_i64 = arith.constant 2048 : i64
+    %c4096_i64 = arith.constant 4096 : i64
+    %c8192_i64 = arith.constant 8192 : i64
+    %fp8_max = arith.constant 4.480000e+02 : f32 // each group scale is max|x| / 448
+
+    %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr
+    %ub_scale = pto.castptr %c4096_i64 : i64 -> !pto.ptr
+    %ub_out_u8 = pto.castptr %c8192_i64 : i64 -> !pto.ptr
+    %ub_out_f8 = pto.castptr %c8192_i64 : i64 -> !pto.ptr // same address as %ub_out_u8; used by the f8 store below
+
+    pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c2048_i64 // 2048 = 8 * 128 * 2; presumably the f16 input size in bytes
+      nburst(%c1_i64, %c2048_i64, %c2048_i64)
+      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
+    pto.mte_gm_ub %scale_gm, %ub_scale, %c0_i64, %c512_i64 // 512 = 128 * 4; presumably the whole f32 scale buffer, so unwritten slots keep their incoming contents
+      nburst(%c1_i64, %c512_i64, %c512_i64)
+      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
+    pto.mte_gm_ub %out_gm, %ub_out_u8, %c0_i64, %c1024_i64 // 1024 = 8 * 128; presumably one byte per output element
+      nburst(%c1_i64, %c1024_i64, %c1024_i64)
+      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
+
+    pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]
+    pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]
+
+    pto.vecscope {
+      scf.for %chunk = %c0 to %c4 step %c1 { // one 256-element chunk per iteration, i.e. two 128-column rows of the 8x128 input
+        %elem_off = arith.muli %chunk, %c256 : index
+        %scale_off = arith.muli %chunk, %c2 : index // two scale slots per chunk
+        %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred>
+        %x_f16 = pto.vmi.load %ub_src[%elem_off]
+          : !pto.ptr -> !pto.vmi.vreg<256xf16>
+        %x = pto.vmi.extf %x_f16
+          : !pto.vmi.vreg<256xf16> -> !pto.vmi.vreg<256xf32>
+        %abs = pto.vmi.absf %x
+          : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32>
+        %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 2} // 256 values reduced to 2 per-group maxima
+          : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred>
+          -> !pto.vmi.vreg<2xf32>
+        %fp8_max_v = pto.vmi.broadcast %fp8_max
+          : f32 -> !pto.vmi.vreg<2xf32>
+        %scale = pto.vmi.divf %amax, %fp8_max_v
+          : !pto.vmi.vreg<2xf32>, !pto.vmi.vreg<2xf32>
+          -> !pto.vmi.vreg<2xf32>
+        pto.vmi.group_store %scale, %ub_scale[%scale_off], %c1 {num_groups = 2} // stores the 2 scales at %scale_off; %c1 is presumably the slot stride
+          : !pto.vmi.vreg<2xf32>, !pto.ptr
+        %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 2} // 2 scales expanded to 256 lanes, presumably one scale per 128-lane group
+          : !pto.vmi.vreg<2xf32> -> !pto.vmi.vreg<256xf32>
+        %q = pto.vmi.divf %x, %scale_vec
+          : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32>
+          -> !pto.vmi.vreg<256xf32>
+        %q8 = pto.vmi.truncf %q // f32 -> f8E4M3FN
+          : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E4M3FN>
+        pto.vmi.store %q8, %ub_out_f8[%elem_off]
+          : !pto.vmi.vreg<256xf8E4M3FN>, !pto.ptr
+      }
+    }
+
+    pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
+    pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
+    pto.mte_ub_gm %ub_scale, %scale_gm, %c512_i64 // write the scale buffer back
+      nburst(%c1_i64, %c512_i64, %c512_i64)
+      : !pto.ptr, !pto.ptr, i64, i64, i64, i64
+    pto.mte_ub_gm %ub_out_u8, %out_gm, %c1024_i64 // write the quantized bytes back
+      nburst(%c1_i64, %c1024_i64, %c1024_i64)
+      : !pto.ptr, !pto.ptr, i64, i64, i64, i64
+    pto.barrier #pto.pipe
+    return
+  }
+}
diff --git a/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-8x128/launch.cpp b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-8x128/launch.cpp
new file mode 100644
index 0000000000..75929341b7
--- /dev/null
+++ b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-8x128/launch.cpp
@@ -0,0 +1,41 @@
+// Copyright (c) 2026 Huawei Technologies Co., Ltd.
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+// CANN Open Software License Agreement Version 2.0 (the "License").
+// Please refer to the License for details. You may not use this file except in compliance with the License.
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_block_quant_f16_fp8_8x128_kernel(__gm__ half *src, __gm__ float *scale, + __gm__ uint8_t *out); + +void LaunchVmi_block_quant_f16_fp8_8x128_kernel(uint16_t *src, float *scale, + uint8_t *out, void *stream) { + vmi_block_quant_f16_fp8_8x128_kernel<<<1, nullptr, stream>>>( + (__gm__ half *)src, (__gm__ float *)scale, (__gm__ uint8_t *)out); +} diff --git a/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-8x128/main.cpp b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-8x128/main.cpp new file mode 100644 index 0000000000..433abffbd7 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-8x128/main.cpp @@ -0,0 +1,93 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. 
You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_block_quant_f16_fp8_8x128_kernel(uint16_t *src, float *scale, + uint8_t *out, void *stream); + +int main() { + constexpr size_t kRows = 8; + constexpr size_t kCols = 128; + constexpr size_t kElems = kRows * kCols; + constexpr size_t kScaleElems = 128; + size_t srcBytes = kElems * sizeof(uint16_t); + size_t scaleBytes = kScaleElems * sizeof(float); + size_t outBytes = kElems * sizeof(uint8_t); + uint16_t *srcHost = nullptr; + float *scaleHost = nullptr; + uint8_t *outHost = nullptr; + uint16_t *srcDevice = nullptr; + float *scaleDevice = nullptr; + uint8_t *outDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&scaleHost), scaleBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&outHost), outBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + 
ACL_CHECK(aclrtMalloc((void **)&scaleDevice, scaleBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outDevice, outBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", scaleBytes, scaleHost, scaleBytes); + ReadFile("./v3.bin", outBytes, outHost, outBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(scaleDevice, scaleBytes, scaleHost, scaleBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outDevice, outBytes, outHost, outBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_block_quant_f16_fp8_8x128_kernel(srcDevice, scaleDevice, + outDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(scaleHost, scaleBytes, scaleDevice, scaleBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outHost, outBytes, outDevice, outBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", scaleHost, scaleBytes); + WriteFile("./v3.bin", outHost, outBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(scaleDevice); + aclrtFree(outDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(scaleHost); + aclrtFreeHost(outHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-8x128/ptoas.flags b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-8x128/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-8x128/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-bf16-128x128/compare.py b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-bf16-128x128/compare.py new file mode 100644 index 0000000000..0142cbc20f --- /dev/null +++ 
b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-bf16-128x128/compare.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden_scale = np.fromfile("golden_v2.bin", dtype=np.float32) + scale = np.fromfile("v2.bin", dtype=np.float32) + golden_out = np.fromfile("golden_v3.bin", dtype=np.uint8) + out = np.fromfile("v3.bin", dtype=np.uint8) + + if golden_scale.shape != scale.shape or not np.allclose( + golden_scale, scale, rtol=1.0e-6, atol=1.0e-6 + ): + if golden_scale.shape != scale.shape: + idx = -1 + else: + diff = np.nonzero(~np.isclose(golden_scale, scale, rtol=1.0e-6, atol=1.0e-6))[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] scale compare failed idx={idx} " + f"golden={golden_scale[idx] if idx >= 0 else 'n/a'} " + f"output={scale[idx] if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + if golden_out.shape != out.shape or not np.array_equal(golden_out, out): + diff = np.nonzero(golden_out != out)[0] if golden_out.shape == out.shape else [] + idx = int(diff[0]) if len(diff) else -1 + print( + f"[ERROR] int8 compare failed idx={idx} " + f"golden={int(golden_out[idx]) if idx >= 0 else 'n/a'} " + f"output={int(out[idx]) if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git 
a/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-bf16-128x128/golden.py b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-bf16-128x128/golden.py new file mode 100644 index 0000000000..f9bc779e12 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-bf16-128x128/golden.py @@ -0,0 +1,86 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 128 +COLS = 128 +INT8_MAX = np.float32(127.0) +SENTINEL_F32 = np.float32(-777.0) +SENTINEL_U8 = np.uint8(0xA5) + +ROW_Q = np.array( + [ + -127, + -96, + -64, + -32, + -7, + -1, + 0, + 1, + 7, + 16, + 31, + 63, + 95, + 120, + 127, + 64, + ], + dtype=np.float32, +) +COL_SCALES = np.array([0.125, 0.25, 0.5, 1.0, 2.0], dtype=np.float32) + + +def f32_to_bf16_bits(values: np.ndarray) -> np.ndarray: + bits = values.astype(np.float32).view(np.uint32) + lsb = (bits >> 16) & 1 + rounded = bits + np.uint32(0x7FFF) + lsb + return (rounded >> 16).astype(np.uint16) + + +def bf16_bits_to_f32(values: np.ndarray) -> np.ndarray: + return (values.astype(np.uint32) << 16).view(np.float32) + + +def generate(output_dir: Path) -> None: + q = np.tile(ROW_Q, (ROWS + len(ROW_Q) - 1) // len(ROW_Q))[:ROWS] + col_scales = np.tile(COL_SCALES, (COLS + len(COL_SCALES) - 1) // len(COL_SCALES))[:COLS] + + src = f32_to_bf16_bits(q[:, None] * col_scales[None, 
:]) + x_f32 = bf16_bits_to_f32(src) + golden_scale = (np.max(np.abs(x_f32), axis=0) / INT8_MAX).astype(np.float32) + scale_safe = np.where(golden_scale > 0, golden_scale, np.ones_like(golden_scale)) + golden_out = np.round(x_f32 / scale_safe[None, :]).astype(np.float32) + golden_out = np.clip(golden_out, -128, 127).astype(np.int8) + + scale = np.full(COLS, SENTINEL_F32, dtype=np.float32) + out = np.full((ROWS, COLS), SENTINEL_U8, dtype=np.uint8) + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + scale.tofile(output_dir / "v2.bin") + out.tofile(output_dir / "v3.bin") + golden_scale.tofile(output_dir / "golden_v2.bin") + golden_out.view(np.uint8).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-bf16-128x128/kernel.pto b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-bf16-128x128/kernel.pto new file mode 100644 index 0000000000..81d857b8f6 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-bf16-128x128/kernel.pto @@ -0,0 +1,188 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_dynamic_quant_perchannel_bf16_128x128_kernel( + %src_gm: !pto.ptr, %scale_gm: !pto.ptr, + %out_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c8 = arith.constant 8 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c0_i32 = arith.constant 0 : i32 + %c7_i32 = arith.constant 7 : i32 + %c127_i32 = arith.constant 127 : i32 + %c8_i32 = arith.constant 8 : i32 + %c128_i32 = arith.constant 128 : i32 + %c256_i32 = arith.constant 256 : i32 + %c0_f32 = arith.constant 0.000000e+00 : f32 + %c-128_f32 = arith.constant -1.280000e+02 : f32 + %c127_f32 = arith.constant 1.270000e+02 : f32 + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c512_i64 = arith.constant 512 : i64 + %c16384_i64 = arith.constant 16384 : i64 + %c32768_i64 = arith.constant 32768 : i64 + %c98304_i64 = arith.constant 98304 : i64 + %c102400_i64 = arith.constant 102400 : i64 + %c106496_i64 = arith.constant 106496 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_scratch = pto.castptr %c32768_i64 : i64 -> !pto.ptr + %ub_scale_padded = pto.castptr %c98304_i64 : i64 -> !pto.ptr + %ub_scale = pto.castptr %c102400_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c106496_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c32768_i64 + nburst(%c1_i64, %c32768_i64, %c32768_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + scf.for %row = %c0 to %c128 step %c1 { + %elem_offset = arith.muli %row, %c128 : index + %x_bf16 = pto.vmi.load %ub_src[%elem_offset] + : !pto.ptr -> !pto.vmi.vreg<128xbf16> + %x_f32 = pto.vmi.extf %x_bf16 + : !pto.vmi.vreg<128xbf16> -> !pto.vmi.vreg<128xf32> + 
pto.vmi.store %x_f32, %ub_scratch[%elem_offset] + : !pto.vmi.vreg<128xf32>, !pto.ptr + } + + scf.for %col = %c0 to %c128 step %c1 { + %col_i32 = arith.index_cast %col : index to i32 + %mask = pto.vmi.create_mask %c128 : index -> !pto.vmi.mask<128xpred> + %zero = pto.vmi.broadcast %c0_f32 + : f32 -> !pto.vmi.vreg<128xf32> + %init = pto.vmi.broadcast %c0_f32 + : f32 -> !pto.vmi.vreg<1xf32> + %lane = pto.vmi.iota %c0_i32 + : i32 -> !pto.vmi.vreg<128xi32> + %stride = pto.vmi.broadcast %c128_i32 + : i32 -> !pto.vmi.vreg<128xi32> + %row_offsets = pto.vmi.muli %lane, %stride + : !pto.vmi.vreg<128xi32>, !pto.vmi.vreg<128xi32> + -> !pto.vmi.vreg<128xi32> + %col_vec = pto.vmi.broadcast %col_i32 + : i32 -> !pto.vmi.vreg<128xi32> + %indices = pto.vmi.addi %row_offsets, %col_vec + : !pto.vmi.vreg<128xi32>, !pto.vmi.vreg<128xi32> + -> !pto.vmi.vreg<128xi32> + %col_values = pto.vmi.gather %ub_scratch[%indices], %mask, %zero + : !pto.ptr, !pto.vmi.vreg<128xi32>, + !pto.vmi.mask<128xpred>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + %abs = pto.vmi.absf %col_values + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + %amax = pto.vmi.reduce_maxf %abs, %init, %mask + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<1xf32>, + !pto.vmi.mask<128xpred> -> !pto.vmi.vreg<1xf32> + %max_int8 = pto.vmi.broadcast %c127_f32 + : f32 -> !pto.vmi.vreg<1xf32> + %scale = pto.vmi.divf %amax, %max_int8 + : !pto.vmi.vreg<1xf32>, !pto.vmi.vreg<1xf32> + -> !pto.vmi.vreg<1xf32> + %padded_col = arith.muli %col, %c8 : index + pto.vmi.store %scale, %ub_scale_padded[%padded_col] + : !pto.vmi.vreg<1xf32>, !pto.ptr + } + + %scale_mask = pto.vmi.create_mask %c128 : index -> !pto.vmi.mask<128xpred> + %scale_zero = pto.vmi.broadcast %c0_f32 + : f32 -> !pto.vmi.vreg<128xf32> + %scale_lane = pto.vmi.iota %c0_i32 + : i32 -> !pto.vmi.vreg<128xi32> + %scale_stride = pto.vmi.broadcast %c8_i32 + : i32 -> !pto.vmi.vreg<128xi32> + %scale_indices = pto.vmi.muli %scale_lane, %scale_stride + : !pto.vmi.vreg<128xi32>, 
!pto.vmi.vreg<128xi32> + -> !pto.vmi.vreg<128xi32> + %scale_dense = pto.vmi.gather %ub_scale_padded[%scale_indices], %scale_mask, %scale_zero + : !pto.ptr, !pto.vmi.vreg<128xi32>, + !pto.vmi.mask<128xpred>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + pto.vmi.store %scale_dense, %ub_scale[%c0] + : !pto.vmi.vreg<128xf32>, !pto.ptr + + %scale_pair_mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %scale_pair_zero = pto.vmi.broadcast %c0_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %scale_pair_lane = pto.vmi.iota %c0_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %scale_pair_group_mask = pto.vmi.broadcast %c127_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %scale_pair_col = pto.vmi.andi %scale_pair_lane, %scale_pair_group_mask + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %scale_pair_stride = pto.vmi.broadcast %c8_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %scale_pair_indices = pto.vmi.muli %scale_pair_col, %scale_pair_stride + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %scale_vec = pto.vmi.gather %ub_scale_padded[%scale_pair_indices], %scale_pair_mask, %scale_pair_zero + : !pto.ptr, !pto.vmi.vreg<256xi32>, + !pto.vmi.mask<256xpred>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + scf.for %pair = %c0 to %c64 step %c1 { + %row = arith.muli %pair, %c2 : index + %elem_offset = arith.muli %row, %c128 : index + %x = pto.vmi.load %ub_scratch[%elem_offset] + : !pto.ptr -> !pto.vmi.vreg<256xf32> + %scaled = pto.vmi.divf %x, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %lo = pto.vmi.broadcast %c-128_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %hi = pto.vmi.broadcast %c127_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %clamped_lo = pto.vmi.maxf %scaled, %lo + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %clamped = pto.vmi.minf %clamped_lo, %hi + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + 
%i32 = pto.vmi.fptosi %clamped + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xi32> + %zero_i32 = pto.vmi.broadcast %c0_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %byte_bias = pto.vmi.broadcast %c256_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %neg = pto.vmi.cmpi "slt", %i32, %zero_i32 + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.mask<256xpred> + %wrapped = pto.vmi.addi %i32, %byte_bias + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %byte_i32 = pto.vmi.select %neg, %wrapped, %i32 + : !pto.vmi.mask<256xpred>, !pto.vmi.vreg<256xi32>, + !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %u8 = pto.vmi.trunci %byte_i32 + : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xui8> + pto.vmi.store %u8, %ub_out[%elem_offset] + : !pto.vmi.vreg<256xui8>, !pto.ptr + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_scale, %scale_gm, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_out, %out_gm, %c16384_i64 + nburst(%c1_i64, %c16384_i64, %c16384_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-bf16-128x128/launch.cpp b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-bf16-128x128/launch.cpp new file mode 100644 index 0000000000..08a9931a37 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-bf16-128x128/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_dynamic_quant_perchannel_bf16_128x128_kernel(__gm__ bfloat16_t *src, + __gm__ float *scale, + __gm__ uint8_t *out); + +void LaunchVmi_dynamic_quant_perchannel_bf16_128x128_kernel(uint16_t *src, + float *scale, + uint8_t *out, + void *stream) { + vmi_dynamic_quant_perchannel_bf16_128x128_kernel<<<1, nullptr, stream>>>( + (__gm__ bfloat16_t *)src, (__gm__ float *)scale, (__gm__ uint8_t *)out); +} diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-bf16-128x128/main.cpp b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-bf16-128x128/main.cpp new file mode 100644 index 0000000000..451bb794b6 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-bf16-128x128/main.cpp @@ -0,0 +1,95 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_dynamic_quant_perchannel_bf16_128x128_kernel(uint16_t *src, + float *scale, + uint8_t *out, + void *stream); + +int main() { + constexpr size_t kRows = 128; + constexpr size_t kCols = 128; + constexpr size_t kElems = kRows * kCols; + constexpr size_t kScaleElems = kCols; + size_t srcBytes = kElems * sizeof(uint16_t); + size_t scaleBytes = kScaleElems * sizeof(float); + size_t outBytes = kElems * sizeof(uint8_t); + uint16_t *srcHost = nullptr; + float *scaleHost = nullptr; + uint8_t *outHost = nullptr; + uint16_t *srcDevice = nullptr; + float *scaleDevice = nullptr; + uint8_t *outDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void 
**)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&scaleHost), scaleBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&outHost), outBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&scaleDevice, scaleBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outDevice, outBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", scaleBytes, scaleHost, scaleBytes); + ReadFile("./v3.bin", outBytes, outHost, outBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(scaleDevice, scaleBytes, scaleHost, scaleBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outDevice, outBytes, outHost, outBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_dynamic_quant_perchannel_bf16_128x128_kernel(srcDevice, scaleDevice, + outDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(scaleHost, scaleBytes, scaleDevice, scaleBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outHost, outBytes, outDevice, outBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", scaleHost, scaleBytes); + WriteFile("./v3.bin", outHost, outBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(scaleDevice); + aclrtFree(outDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(scaleHost); + aclrtFreeHost(outHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-bf16-128x128/ptoas.flags b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-bf16-128x128/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-bf16-128x128/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git 
a/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-128x128/compare.py b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-128x128/compare.py new file mode 100644 index 0000000000..0142cbc20f --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-128x128/compare.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden_scale = np.fromfile("golden_v2.bin", dtype=np.float32) + scale = np.fromfile("v2.bin", dtype=np.float32) + golden_out = np.fromfile("golden_v3.bin", dtype=np.uint8) + out = np.fromfile("v3.bin", dtype=np.uint8) + + if golden_scale.shape != scale.shape or not np.allclose( + golden_scale, scale, rtol=1.0e-6, atol=1.0e-6 + ): + if golden_scale.shape != scale.shape: + idx = -1 + else: + diff = np.nonzero(~np.isclose(golden_scale, scale, rtol=1.0e-6, atol=1.0e-6))[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] scale compare failed idx={idx} " + f"golden={golden_scale[idx] if idx >= 0 else 'n/a'} " + f"output={scale[idx] if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + if golden_out.shape != out.shape or not np.array_equal(golden_out, out): + diff = np.nonzero(golden_out != out)[0] if golden_out.shape == out.shape else [] + idx = int(diff[0]) if len(diff) else -1 + print( + f"[ERROR] int8 compare failed idx={idx} " + 
f"golden={int(golden_out[idx]) if idx >= 0 else 'n/a'} " + f"output={int(out[idx]) if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-128x128/golden.py b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-128x128/golden.py new file mode 100644 index 0000000000..a249eec0ee --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-128x128/golden.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+ +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 128 +COLS = 128 +INT8_MAX = np.float32(127.0) +SENTINEL_F32 = np.float32(-777.0) +SENTINEL_U8 = np.uint8(0xA5) + +ROW_Q = np.array( + [ + -127, + -96, + -64, + -32, + -7, + -1, + 0, + 1, + 7, + 16, + 31, + 63, + 95, + 120, + 127, + 64, + ], + dtype=np.float32, +) +COL_SCALES = np.array([0.125, 0.25, 0.5, 1.0, 2.0], dtype=np.float32) + + +def generate(output_dir: Path) -> None: + q = np.tile(ROW_Q, (ROWS + len(ROW_Q) - 1) // len(ROW_Q))[:ROWS] + col_scales = np.tile(COL_SCALES, (COLS + len(COL_SCALES) - 1) // len(COL_SCALES))[:COLS] + + src = (q[:, None] * col_scales[None, :]).astype(np.float16) + x_f32 = src.astype(np.float32) + golden_scale = (np.max(np.abs(x_f32), axis=0) / INT8_MAX).astype(np.float32) + scale_safe = np.where(golden_scale > 0, golden_scale, np.ones_like(golden_scale)) + golden_out = np.round(x_f32 / scale_safe[None, :]).astype(np.float32) + golden_out = np.clip(golden_out, -128, 127).astype(np.int8) + + scale = np.full(COLS, SENTINEL_F32, dtype=np.float32) + out = np.full((ROWS, COLS), SENTINEL_U8, dtype=np.uint8) + + output_dir.mkdir(parents=True, exist_ok=True) + src.view(np.uint16).tofile(output_dir / "v1.bin") + scale.tofile(output_dir / "v2.bin") + out.tofile(output_dir / "v3.bin") + golden_scale.tofile(output_dir / "golden_v2.bin") + golden_out.view(np.uint8).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-128x128/kernel.pto b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-128x128/kernel.pto new file mode 100644 index 0000000000..0bcf64283d --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-128x128/kernel.pto @@ -0,0 +1,188 @@ +// Copyright (c) 
2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_dynamic_quant_perchannel_f16_128x128_kernel( + %src_gm: !pto.ptr, %scale_gm: !pto.ptr, + %out_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c8 = arith.constant 8 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c0_i32 = arith.constant 0 : i32 + %c7_i32 = arith.constant 7 : i32 + %c127_i32 = arith.constant 127 : i32 + %c8_i32 = arith.constant 8 : i32 + %c128_i32 = arith.constant 128 : i32 + %c256_i32 = arith.constant 256 : i32 + %c0_f32 = arith.constant 0.000000e+00 : f32 + %c-128_f32 = arith.constant -1.280000e+02 : f32 + %c127_f32 = arith.constant 1.270000e+02 : f32 + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c512_i64 = arith.constant 512 : i64 + %c16384_i64 = arith.constant 16384 : i64 + %c32768_i64 = arith.constant 32768 : i64 + %c98304_i64 = arith.constant 98304 : i64 + %c102400_i64 = arith.constant 102400 : i64 + %c106496_i64 = arith.constant 106496 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_scratch = pto.castptr %c32768_i64 : i64 -> !pto.ptr + %ub_scale_padded = pto.castptr %c98304_i64 : i64 -> !pto.ptr + %ub_scale = pto.castptr 
%c102400_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c106496_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c32768_i64 + nburst(%c1_i64, %c32768_i64, %c32768_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + scf.for %row = %c0 to %c128 step %c1 { + %elem_offset = arith.muli %row, %c128 : index + %x_f16 = pto.vmi.load %ub_src[%elem_offset] + : !pto.ptr -> !pto.vmi.vreg<128xf16> + %x_f32 = pto.vmi.extf %x_f16 + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + pto.vmi.store %x_f32, %ub_scratch[%elem_offset] + : !pto.vmi.vreg<128xf32>, !pto.ptr + } + + scf.for %col = %c0 to %c128 step %c1 { + %col_i32 = arith.index_cast %col : index to i32 + %mask = pto.vmi.create_mask %c128 : index -> !pto.vmi.mask<128xpred> + %zero = pto.vmi.broadcast %c0_f32 + : f32 -> !pto.vmi.vreg<128xf32> + %init = pto.vmi.broadcast %c0_f32 + : f32 -> !pto.vmi.vreg<1xf32> + %lane = pto.vmi.iota %c0_i32 + : i32 -> !pto.vmi.vreg<128xi32> + %stride = pto.vmi.broadcast %c128_i32 + : i32 -> !pto.vmi.vreg<128xi32> + %row_offsets = pto.vmi.muli %lane, %stride + : !pto.vmi.vreg<128xi32>, !pto.vmi.vreg<128xi32> + -> !pto.vmi.vreg<128xi32> + %col_vec = pto.vmi.broadcast %col_i32 + : i32 -> !pto.vmi.vreg<128xi32> + %indices = pto.vmi.addi %row_offsets, %col_vec + : !pto.vmi.vreg<128xi32>, !pto.vmi.vreg<128xi32> + -> !pto.vmi.vreg<128xi32> + %col_values = pto.vmi.gather %ub_scratch[%indices], %mask, %zero + : !pto.ptr, !pto.vmi.vreg<128xi32>, + !pto.vmi.mask<128xpred>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + %abs = pto.vmi.absf %col_values + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + %amax = pto.vmi.reduce_maxf %abs, %init, %mask + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<1xf32>, + !pto.vmi.mask<128xpred> -> !pto.vmi.vreg<1xf32> + %max_int8 = pto.vmi.broadcast %c127_f32 + : f32 -> !pto.vmi.vreg<1xf32> + %scale = pto.vmi.divf %amax, 
%max_int8 + : !pto.vmi.vreg<1xf32>, !pto.vmi.vreg<1xf32> + -> !pto.vmi.vreg<1xf32> + %padded_col = arith.muli %col, %c8 : index + pto.vmi.store %scale, %ub_scale_padded[%padded_col] + : !pto.vmi.vreg<1xf32>, !pto.ptr + } + + %scale_mask = pto.vmi.create_mask %c128 : index -> !pto.vmi.mask<128xpred> + %scale_zero = pto.vmi.broadcast %c0_f32 + : f32 -> !pto.vmi.vreg<128xf32> + %scale_lane = pto.vmi.iota %c0_i32 + : i32 -> !pto.vmi.vreg<128xi32> + %scale_stride = pto.vmi.broadcast %c8_i32 + : i32 -> !pto.vmi.vreg<128xi32> + %scale_indices = pto.vmi.muli %scale_lane, %scale_stride + : !pto.vmi.vreg<128xi32>, !pto.vmi.vreg<128xi32> + -> !pto.vmi.vreg<128xi32> + %scale_dense = pto.vmi.gather %ub_scale_padded[%scale_indices], %scale_mask, %scale_zero + : !pto.ptr, !pto.vmi.vreg<128xi32>, + !pto.vmi.mask<128xpred>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + pto.vmi.store %scale_dense, %ub_scale[%c0] + : !pto.vmi.vreg<128xf32>, !pto.ptr + + %scale_pair_mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %scale_pair_zero = pto.vmi.broadcast %c0_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %scale_pair_lane = pto.vmi.iota %c0_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %scale_pair_group_mask = pto.vmi.broadcast %c127_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %scale_pair_col = pto.vmi.andi %scale_pair_lane, %scale_pair_group_mask + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %scale_pair_stride = pto.vmi.broadcast %c8_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %scale_pair_indices = pto.vmi.muli %scale_pair_col, %scale_pair_stride + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %scale_vec = pto.vmi.gather %ub_scale_padded[%scale_pair_indices], %scale_pair_mask, %scale_pair_zero + : !pto.ptr, !pto.vmi.vreg<256xi32>, + !pto.vmi.mask<256xpred>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + scf.for %pair = %c0 to %c64 step %c1 { + %row = arith.muli %pair, %c2 : index + %elem_offset = arith.muli 
%row, %c128 : index + %x = pto.vmi.load %ub_scratch[%elem_offset] + : !pto.ptr -> !pto.vmi.vreg<256xf32> + %scaled = pto.vmi.divf %x, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %lo = pto.vmi.broadcast %c-128_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %hi = pto.vmi.broadcast %c127_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %clamped_lo = pto.vmi.maxf %scaled, %lo + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %clamped = pto.vmi.minf %clamped_lo, %hi + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %i32 = pto.vmi.fptosi %clamped + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xi32> + %zero_i32 = pto.vmi.broadcast %c0_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %byte_bias = pto.vmi.broadcast %c256_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %neg = pto.vmi.cmpi "slt", %i32, %zero_i32 + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.mask<256xpred> + %wrapped = pto.vmi.addi %i32, %byte_bias + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %byte_i32 = pto.vmi.select %neg, %wrapped, %i32 + : !pto.vmi.mask<256xpred>, !pto.vmi.vreg<256xi32>, + !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %u8 = pto.vmi.trunci %byte_i32 + : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xui8> + pto.vmi.store %u8, %ub_out[%elem_offset] + : !pto.vmi.vreg<256xui8>, !pto.ptr + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_scale, %scale_gm, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_out, %out_gm, %c16384_i64 + nburst(%c1_i64, %c16384_i64, %c16384_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-128x128/launch.cpp b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-128x128/launch.cpp 
new file mode 100644 index 0000000000..abec34be2e --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-128x128/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_dynamic_quant_perchannel_f16_128x128_kernel(__gm__ half *src, + __gm__ float *scale, + __gm__ uint8_t *out); + +void LaunchVmi_dynamic_quant_perchannel_f16_128x128_kernel(uint16_t *src, + float *scale, + uint8_t *out, + void *stream) { + vmi_dynamic_quant_perchannel_f16_128x128_kernel<<<1, nullptr, stream>>>( + (__gm__ half *)src, (__gm__ float *)scale, (__gm__ uint8_t *)out); +} diff --git 
a/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-128x128/main.cpp b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-128x128/main.cpp new file mode 100644 index 0000000000..741d5e3d9e --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-128x128/main.cpp @@ -0,0 +1,95 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_dynamic_quant_perchannel_f16_128x128_kernel(uint16_t *src, + float *scale, + uint8_t *out, + void *stream); + +int main() { + constexpr size_t kRows = 128; + constexpr size_t kCols = 128; + constexpr size_t kElems = kRows * kCols; + constexpr size_t kScaleElems = kCols; + size_t srcBytes = kElems * sizeof(uint16_t); + size_t scaleBytes = kScaleElems * sizeof(float); + size_t outBytes = kElems * sizeof(uint8_t); + uint16_t *srcHost = nullptr; + float *scaleHost = nullptr; + uint8_t *outHost = nullptr; + uint16_t *srcDevice = nullptr; + float *scaleDevice = nullptr; + uint8_t *outDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool 
deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&scaleHost), scaleBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&outHost), outBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&scaleDevice, scaleBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outDevice, outBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", scaleBytes, scaleHost, scaleBytes); + ReadFile("./v3.bin", outBytes, outHost, outBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(scaleDevice, scaleBytes, scaleHost, scaleBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outDevice, outBytes, outHost, outBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_dynamic_quant_perchannel_f16_128x128_kernel(srcDevice, scaleDevice, + outDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(scaleHost, scaleBytes, scaleDevice, scaleBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outHost, outBytes, outDevice, outBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", scaleHost, scaleBytes); + WriteFile("./v3.bin", outHost, outBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(scaleDevice); + aclrtFree(outDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(scaleHost); + aclrtFreeHost(outHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git 
a/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-128x128/ptoas.flags b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-128x128/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-128x128/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-256x256/compare.py b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-256x256/compare.py new file mode 100644 index 0000000000..0142cbc20f --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-256x256/compare.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+ +import sys + +import numpy as np + + +def main() -> None: + golden_scale = np.fromfile("golden_v2.bin", dtype=np.float32) + scale = np.fromfile("v2.bin", dtype=np.float32) + golden_out = np.fromfile("golden_v3.bin", dtype=np.uint8) + out = np.fromfile("v3.bin", dtype=np.uint8) + + if golden_scale.shape != scale.shape or not np.allclose( + golden_scale, scale, rtol=1.0e-6, atol=1.0e-6 + ): + if golden_scale.shape != scale.shape: + idx = -1 + else: + diff = np.nonzero(~np.isclose(golden_scale, scale, rtol=1.0e-6, atol=1.0e-6))[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] scale compare failed idx={idx} " + f"golden={golden_scale[idx] if idx >= 0 else 'n/a'} " + f"output={scale[idx] if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + if golden_out.shape != out.shape or not np.array_equal(golden_out, out): + diff = np.nonzero(golden_out != out)[0] if golden_out.shape == out.shape else [] + idx = int(diff[0]) if len(diff) else -1 + print( + f"[ERROR] int8 compare failed idx={idx} " + f"golden={int(golden_out[idx]) if idx >= 0 else 'n/a'} " + f"output={int(out[idx]) if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-256x256/golden.py b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-256x256/golden.py new file mode 100644 index 0000000000..dc9933534f --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-256x256/golden.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. 
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 256 +COLS = 256 +INT8_MAX = np.float32(127.0) +SENTINEL_F32 = np.float32(-777.0) +SENTINEL_U8 = np.uint8(0xA5) + +ROW_Q = np.array( + [ + -127, + -96, + -64, + -32, + -7, + -1, + 0, + 1, + 7, + 16, + 31, + 63, + 95, + 120, + 127, + 64, + ], + dtype=np.float32, +) +COL_SCALES = np.array([0.125, 0.25, 0.5, 1.0, 2.0], dtype=np.float32) + + +def generate(output_dir: Path) -> None: + q = np.tile(ROW_Q, (ROWS + len(ROW_Q) - 1) // len(ROW_Q))[:ROWS] + col_scales = np.tile(COL_SCALES, (COLS + len(COL_SCALES) - 1) // len(COL_SCALES))[:COLS] + + src = (q[:, None] * col_scales[None, :]).astype(np.float16) + x_f32 = src.astype(np.float32) + golden_scale = (np.max(np.abs(x_f32), axis=0) / INT8_MAX).astype(np.float32) + scale_safe = np.where(golden_scale > 0, golden_scale, np.ones_like(golden_scale)) + golden_out = np.round(x_f32 / scale_safe[None, :]).astype(np.float32) + golden_out = np.clip(golden_out, -128, 127).astype(np.int8) + + scale = np.full(COLS, SENTINEL_F32, dtype=np.float32) + out = np.full((ROWS, COLS), SENTINEL_U8, dtype=np.uint8) + + output_dir.mkdir(parents=True, exist_ok=True) + src.view(np.uint16).tofile(output_dir / "v1.bin") + scale.tofile(output_dir / "v2.bin") + out.tofile(output_dir / "v3.bin") + golden_scale.tofile(output_dir / "golden_v2.bin") + golden_out.view(np.uint8).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git 
a/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-256x256/kernel.pto b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-256x256/kernel.pto new file mode 100644 index 0000000000..89507a9228 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-256x256/kernel.pto @@ -0,0 +1,117 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_dynamic_quant_perchannel_f16_256x256_kernel( + %src_gm: !pto.ptr, %scale_gm: !pto.ptr, + %out_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c0_i32 = arith.constant 0 : i32 + %c256_i32 = arith.constant 256 : i32 + %c0_f32 = arith.constant 0.000000e+00 : f32 + %c-128_f32 = arith.constant -1.280000e+02 : f32 + %c127_f32 = arith.constant 1.270000e+02 : f32 + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c65536_i64 = arith.constant 65536 : i64 + %c131072_i64 = arith.constant 131072 : i64 + %c135168_i64 = arith.constant 135168 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_scale = pto.castptr %c131072_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c135168_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c131072_i64 + 
nburst(%c1_i64, %c131072_i64, %c131072_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %zero_acc = pto.vmi.broadcast %c0_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %scale_acc = scf.for %row = %c0 to %c256 step %c1 + iter_args(%acc = %zero_acc) -> (!pto.vmi.vreg<256xf32>) { + %elem_offset = arith.muli %row, %c256 : index + %x_f16 = pto.vmi.load %ub_src[%elem_offset] + : !pto.ptr -> !pto.vmi.vreg<256xf16> + %x = pto.vmi.extf %x_f16 + : !pto.vmi.vreg<256xf16> -> !pto.vmi.vreg<256xf32> + %abs = pto.vmi.absf %x + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %next = pto.vmi.maxf %acc, %abs + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + scf.yield %next : !pto.vmi.vreg<256xf32> + } + %max_int8 = pto.vmi.broadcast %c127_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %scale = pto.vmi.divf %scale_acc, %max_int8 + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + pto.vmi.store %scale, %ub_scale[%c0] + : !pto.vmi.vreg<256xf32>, !pto.ptr + + scf.for %row = %c0 to %c256 step %c1 { + %elem_offset = arith.muli %row, %c256 : index + %x_f16 = pto.vmi.load %ub_src[%elem_offset] + : !pto.ptr -> !pto.vmi.vreg<256xf16> + %x = pto.vmi.extf %x_f16 + : !pto.vmi.vreg<256xf16> -> !pto.vmi.vreg<256xf32> + %scaled = pto.vmi.divf %x, %scale + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %lo = pto.vmi.broadcast %c-128_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %hi = pto.vmi.broadcast %c127_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %clamped_lo = pto.vmi.maxf %scaled, %lo + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %clamped = pto.vmi.minf %clamped_lo, %hi + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %i32 = pto.vmi.fptosi %clamped + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xi32> + %zero_i32 = 
pto.vmi.broadcast %c0_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %byte_bias = pto.vmi.broadcast %c256_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %neg = pto.vmi.cmpi "slt", %i32, %zero_i32 + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.mask<256xpred> + %wrapped = pto.vmi.addi %i32, %byte_bias + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %byte_i32 = pto.vmi.select %neg, %wrapped, %i32 + : !pto.vmi.mask<256xpred>, !pto.vmi.vreg<256xi32>, + !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %u8 = pto.vmi.trunci %byte_i32 + : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xui8> + pto.vmi.store %u8, %ub_out[%elem_offset] + : !pto.vmi.vreg<256xui8>, !pto.ptr + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_scale, %scale_gm, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_out, %out_gm, %c65536_i64 + nburst(%c1_i64, %c65536_i64, %c65536_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-256x256/launch.cpp b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-256x256/launch.cpp new file mode 100644 index 0000000000..a805d4f3f4 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-256x256/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_dynamic_quant_perchannel_f16_256x256_kernel(__gm__ half *src, + __gm__ float *scale, + __gm__ uint8_t *out); + +void LaunchVmi_dynamic_quant_perchannel_f16_256x256_kernel(uint16_t *src, + float *scale, + uint8_t *out, + void *stream) { + vmi_dynamic_quant_perchannel_f16_256x256_kernel<<<1, nullptr, stream>>>( + (__gm__ half *)src, (__gm__ float *)scale, (__gm__ uint8_t *)out); +} diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-256x256/main.cpp b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-256x256/main.cpp new file mode 100644 index 0000000000..0b840eff7a --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-256x256/main.cpp @@ -0,0 +1,95 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_dynamic_quant_perchannel_f16_256x256_kernel(uint16_t *src, + float *scale, + uint8_t *out, + void *stream); + +int main() { + constexpr size_t kRows = 256; + constexpr size_t kCols = 256; + constexpr size_t kElems = kRows * kCols; + constexpr size_t kScaleElems = kCols; + size_t srcBytes = kElems * sizeof(uint16_t); + size_t scaleBytes = kScaleElems * sizeof(float); + size_t outBytes = kElems * sizeof(uint8_t); + uint16_t *srcHost = nullptr; + float *scaleHost = nullptr; + uint8_t *outHost = nullptr; + uint16_t *srcDevice = nullptr; + float *scaleDevice = nullptr; + uint8_t *outDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), 
srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&scaleHost), scaleBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&outHost), outBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&scaleDevice, scaleBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outDevice, outBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", scaleBytes, scaleHost, scaleBytes); + ReadFile("./v3.bin", outBytes, outHost, outBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(scaleDevice, scaleBytes, scaleHost, scaleBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outDevice, outBytes, outHost, outBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_dynamic_quant_perchannel_f16_256x256_kernel(srcDevice, scaleDevice, + outDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(scaleHost, scaleBytes, scaleDevice, scaleBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outHost, outBytes, outDevice, outBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", scaleHost, scaleBytes); + WriteFile("./v3.bin", outHost, outBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(scaleDevice); + aclrtFree(outDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(scaleHost); + aclrtFreeHost(outHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-256x256/ptoas.flags b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-256x256/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-256x256/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git 
a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-bf16-4x32/compare.py b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-bf16-4x32/compare.py new file mode 100644 index 0000000000..0142cbc20f --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-bf16-4x32/compare.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden_scale = np.fromfile("golden_v2.bin", dtype=np.float32) + scale = np.fromfile("v2.bin", dtype=np.float32) + golden_out = np.fromfile("golden_v3.bin", dtype=np.uint8) + out = np.fromfile("v3.bin", dtype=np.uint8) + + if golden_scale.shape != scale.shape or not np.allclose( + golden_scale, scale, rtol=1.0e-6, atol=1.0e-6 + ): + if golden_scale.shape != scale.shape: + idx = -1 + else: + diff = np.nonzero(~np.isclose(golden_scale, scale, rtol=1.0e-6, atol=1.0e-6))[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] scale compare failed idx={idx} " + f"golden={golden_scale[idx] if idx >= 0 else 'n/a'} " + f"output={scale[idx] if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + if golden_out.shape != out.shape or not np.array_equal(golden_out, out): + diff = np.nonzero(golden_out != out)[0] if golden_out.shape == out.shape else [] + idx = int(diff[0]) if len(diff) else -1 + print( + f"[ERROR] int8 compare failed idx={idx} " + 
f"golden={int(golden_out[idx]) if idx >= 0 else 'n/a'} " + f"output={int(out[idx]) if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-bf16-4x32/golden.py b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-bf16-4x32/golden.py new file mode 100644 index 0000000000..dc44fd67ba --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-bf16-4x32/golden.py @@ -0,0 +1,81 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+ +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 4 +COLS = 32 +INT8_MAX = np.float32(127.0) +SENTINEL_F32 = np.float32(-777.0) +SENTINEL_U8 = np.uint8(0xA5) + +Q_VALUES = np.array( + [-127, -96, -64, -32, -7, -1, 0, 1, 7, 16, 31, 63, 95, 120, 127], + dtype=np.float32, +) +ROW_SCALES = np.array( + [ + 0.25, + 0.5, + 1.0, + 2.0, + ], + dtype=np.float32, +) + + +def f32_to_bf16_bits(values: np.ndarray) -> np.ndarray: + bits = values.astype(np.float32).view(np.uint32) + lsb = (bits >> 16) & 1 + rounded = bits + np.uint32(0x7FFF) + lsb + return (rounded >> 16).astype(np.uint16) + + +def bf16_bits_to_f32(values: np.ndarray) -> np.ndarray: + return (values.astype(np.uint32) << 16).view(np.float32) + + +def generate(output_dir: Path) -> None: + repeats = (COLS + len(Q_VALUES) - 1) // len(Q_VALUES) + q_row = np.tile(Q_VALUES, repeats)[:COLS].astype(np.float32) + + src = np.empty((ROWS, COLS), dtype=np.uint16) + golden_scale = np.empty(ROWS, dtype=np.float32) + golden_out = np.empty((ROWS, COLS), dtype=np.int8) + for row in range(ROWS): + src[row] = f32_to_bf16_bits(q_row * ROW_SCALES[row]) + x_f32 = bf16_bits_to_f32(src[row]) + scale = (np.max(np.abs(x_f32)) / INT8_MAX).astype(np.float32) + golden_scale[row] = scale + quant = np.round(x_f32 / scale).astype(np.float32) + golden_out[row] = np.clip(quant, -128, 127).astype(np.int8) + + scale = np.full(ROWS, SENTINEL_F32, dtype=np.float32) + out = np.full((ROWS, COLS), SENTINEL_U8, dtype=np.uint8) + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + scale.tofile(output_dir / "v2.bin") + out.tofile(output_dir / "v3.bin") + golden_scale.tofile(output_dir / "golden_v2.bin") + golden_out.view(np.uint8).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff 
--git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-bf16-4x32/kernel.pto b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-bf16-4x32/kernel.pto new file mode 100644 index 0000000000..80172d52e1 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-bf16-4x32/kernel.pto @@ -0,0 +1,112 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_dynamic_quant_pertoken_bf16_4x32_kernel(%src_gm: !pto.ptr, + %scale_gm: !pto.ptr, + %out_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c0_i32 = arith.constant 0 : i32 + %c256_i32 = arith.constant 256 : i32 + %c1_bf16 = arith.constant 1.000000e+00 : bf16 + %c-128_f32 = arith.constant -1.280000e+02 : f32 + %c127_f32 = arith.constant 1.270000e+02 : f32 + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c16_i64 = arith.constant 16 : i64 + %c128_i64 = arith.constant 128 : i64 + %c256_i64 = arith.constant 256 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c2048_i64 = arith.constant 2048 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_scale = pto.castptr %c1024_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub 
%src_gm, %ub_src, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %pad_bf16 = pto.vmi.broadcast %c1_bf16 + : bf16 -> !pto.vmi.vreg<256xbf16> + pto.vmi.store %pad_bf16, %ub_src[%c128] + : !pto.vmi.vreg<256xbf16>, !pto.ptr + pto.mem_bar "VST_VLD" + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %x_bf16 = pto.vmi.load %ub_src[%c0] + : !pto.ptr -> !pto.vmi.vreg<256xbf16> + %x = pto.vmi.extf %x_bf16 + : !pto.vmi.vreg<256xbf16> -> !pto.vmi.vreg<256xf32> + %abs = pto.vmi.absf %x + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<8xf32> + %max_int8 = pto.vmi.broadcast %c127_f32 + : f32 -> !pto.vmi.vreg<8xf32> + %scale = pto.vmi.divf %amax, %max_int8 + : !pto.vmi.vreg<8xf32>, !pto.vmi.vreg<8xf32> + -> !pto.vmi.vreg<8xf32> + pto.vmi.group_store %scale, %ub_scale[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xf32>, !pto.ptr + %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 8} + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<256xf32> + %scaled = pto.vmi.divf %x, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %lo = pto.vmi.broadcast %c-128_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %hi = pto.vmi.broadcast %c127_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %clamped_lo = pto.vmi.maxf %scaled, %lo + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %clamped = pto.vmi.minf %clamped_lo, %hi + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %i32 = pto.vmi.fptosi %clamped + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xi32> + %zero_i32 = pto.vmi.broadcast %c0_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %byte_bias = pto.vmi.broadcast %c256_i32 + : i32 
-> !pto.vmi.vreg<256xi32> + %neg = pto.vmi.cmpi "slt", %i32, %zero_i32 + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.mask<256xpred> + %wrapped = pto.vmi.addi %i32, %byte_bias + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %byte_i32 = pto.vmi.select %neg, %wrapped, %i32 + : !pto.vmi.mask<256xpred>, !pto.vmi.vreg<256xi32>, + !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %u8 = pto.vmi.trunci %byte_i32 + : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xui8> + pto.vmi.store %u8, %ub_out[%c0] + : !pto.vmi.vreg<256xui8>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_scale, %scale_gm, %c16_i64 + nburst(%c1_i64, %c16_i64, %c16_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_out, %out_gm, %c128_i64 + nburst(%c1_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-bf16-4x32/launch.cpp b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-bf16-4x32/launch.cpp new file mode 100644 index 0000000000..f8514dfdb6 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-bf16-4x32/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_dynamic_quant_pertoken_bf16_4x32_kernel(__gm__ bfloat16_t *src, + __gm__ float *scale, + __gm__ uint8_t *out); + +void LaunchVmi_dynamic_quant_pertoken_bf16_4x32_kernel(uint16_t *src, + float *scale, + uint8_t *out, + void *stream) { + vmi_dynamic_quant_pertoken_bf16_4x32_kernel<<<1, nullptr, stream>>>( + (__gm__ bfloat16_t *)src, (__gm__ float *)scale, (__gm__ uint8_t *)out); +} diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-bf16-4x32/main.cpp b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-bf16-4x32/main.cpp new file mode 100644 index 0000000000..f11096d3e6 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-bf16-4x32/main.cpp @@ -0,0 +1,95 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_dynamic_quant_pertoken_bf16_4x32_kernel(uint16_t *src, + float *scale, + uint8_t *out, + void *stream); + +int main() { + constexpr size_t kRows = 4; + constexpr size_t kCols = 32; + constexpr size_t kElems = kRows * kCols; + constexpr size_t kScaleElems = kRows; + size_t srcBytes = kElems * sizeof(uint16_t); + size_t scaleBytes = kScaleElems * sizeof(float); + size_t outBytes = kElems * sizeof(uint8_t); + uint16_t *srcHost = nullptr; + float *scaleHost = nullptr; + uint8_t *outHost = nullptr; + uint16_t *srcDevice = nullptr; + float *scaleDevice = nullptr; + uint8_t *outDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&scaleHost), scaleBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&outHost), outBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&scaleDevice, scaleBytes, 
ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outDevice, outBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", scaleBytes, scaleHost, scaleBytes); + ReadFile("./v3.bin", outBytes, outHost, outBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(scaleDevice, scaleBytes, scaleHost, scaleBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outDevice, outBytes, outHost, outBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_dynamic_quant_pertoken_bf16_4x32_kernel(srcDevice, scaleDevice, + outDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(scaleHost, scaleBytes, scaleDevice, scaleBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outHost, outBytes, outDevice, outBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", scaleHost, scaleBytes); + WriteFile("./v3.bin", outHost, outBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(scaleDevice); + aclrtFree(outDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(scaleHost); + aclrtFreeHost(outHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-bf16-4x32/ptoas.flags b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-bf16-4x32/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-bf16-4x32/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-16x128/compare.py b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-16x128/compare.py new file mode 100644 index 0000000000..0142cbc20f --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-16x128/compare.py @@ -0,0 +1,50 @@ 
+#!/usr/bin/env python3
+# Copyright (c) 2026 Huawei Technologies Co., Ltd.
+# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+# CANN Open Software License Agreement Version 2.0 (the "License").
+# Please refer to the License for details. You may not use this file except in compliance with the License.
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+# See LICENSE in the root of the software repository for the full text of the License.
+# Checks v2.bin / v3.bin in the current directory against golden_v2.bin / golden_v3.bin.
+import sys
+
+import numpy as np
+
+# Exits with status 2 after reporting the first mismatching index; prints a pass line otherwise.
+def main() -> None:
+    golden_scale = np.fromfile("golden_v2.bin", dtype=np.float32)
+    scale = np.fromfile("v2.bin", dtype=np.float32)
+    golden_out = np.fromfile("golden_v3.bin", dtype=np.uint8)
+    out = np.fromfile("v3.bin", dtype=np.uint8)
+    # Scales are float32 and compared with rtol/atol of 1e-6; a shape mismatch also fails.
+    if golden_scale.shape != scale.shape or not np.allclose(
+        golden_scale, scale, rtol=1.0e-6, atol=1.0e-6
+    ):
+        if golden_scale.shape != scale.shape:
+            idx = -1  # shapes differ, so there is no element to point at
+        else:
+            diff = np.nonzero(~np.isclose(golden_scale, scale, rtol=1.0e-6, atol=1.0e-6))[0]
+            idx = int(diff[0]) if diff.size else -1
+        print(
+            f"[ERROR] scale compare failed idx={idx} "
+            f"golden={golden_scale[idx] if idx >= 0 else 'n/a'} "
+            f"output={scale[idx] if idx >= 0 else 'n/a'}"
+        )
+        sys.exit(2)
+    # Quantized bytes must match exactly.
+    if golden_out.shape != out.shape or not np.array_equal(golden_out, out):
+        diff = np.nonzero(golden_out != out)[0] if golden_out.shape == out.shape else []
+        idx = int(diff[0]) if len(diff) else -1
+        print(
+            f"[ERROR] int8 compare failed idx={idx} "
+            f"golden={int(golden_out[idx]) if idx >= 0 else 'n/a'} "
+            f"output={int(out[idx]) if idx >= 0 else 'n/a'}"
+        )
+        sys.exit(2)
+
+    print("[INFO] compare passed")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-16x128/golden.py
b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-16x128/golden.py
new file mode 100644
index 0000000000..e17afddf0a
--- /dev/null
+++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-16x128/golden.py
@@ -0,0 +1,82 @@
+#!/usr/bin/env python3
+# Copyright (c) 2026 Huawei Technologies Co., Ltd.
+# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+# CANN Open Software License Agreement Version 2.0 (the "License").
+# Please refer to the License for details. You may not use this file except in compliance with the License.
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+# See LICENSE in the root of the software repository for the full text of the License.
+# Input/golden generator: row-wise dynamic int8 quantization of float16 data (ROWS x COLS).
+import argparse
+from pathlib import Path
+
+import numpy as np
+
+ROWS = 16
+COLS = 128
+INT8_MAX = np.float32(127.0)
+SENTINEL_F32 = np.float32(-777.0)  # fill value for the scale buffer written to v2.bin
+SENTINEL_U8 = np.uint8(0xA5)  # fill value for the byte buffer written to v3.bin
+# Q_VALUES is tiled along a row in generate(); row r is that pattern times ROW_SCALES[r].
+Q_VALUES = np.array(
+    [-127, -96, -64, -32, -7, -1, 0, 1, 7, 16, 31, 63, 95, 120, 127],
+    dtype=np.float32,
+)
+ROW_SCALES = np.array(
+    [
+        0.25,
+        0.5,
+        1.0,
+        2.0,
+        0.375,
+        0.75,
+        1.5,
+        3.0,
+        0.125,
+        0.625,
+        1.25,
+        2.5,
+        0.3125,
+        0.9375,
+        1.875,
+        3.75,
+    ],
+    dtype=np.float32,
+)
+
+# Write the source tile, sentinel-filled output buffers, and expected results into output_dir.
+def generate(output_dir: Path) -> None:
+    repeats = (COLS + len(Q_VALUES) - 1) // len(Q_VALUES)  # ceil(COLS / len(Q_VALUES))
+    q_row = np.tile(Q_VALUES, repeats)[:COLS].astype(np.float32)
+
+    src = np.empty((ROWS, COLS), dtype=np.float16)
+    golden_scale = np.empty(ROWS, dtype=np.float32)
+    golden_out = np.empty((ROWS, COLS), dtype=np.int8)
+    for row in range(ROWS):
+        src[row] = (q_row * ROW_SCALES[row]).astype(np.float16)
+        x_f32 = src[row].astype(np.float32)  # reference math uses the float16-rounded values
+        scale = (np.max(np.abs(x_f32)) / INT8_MAX).astype(np.float32)  # per-row scale: max(|x|) / 127
+        golden_scale[row] = scale
+        quant = np.round(x_f32 / scale).astype(np.float32)  # np.round rounds halves to even
+        golden_out[row] = np.clip(quant, -128, 127).astype(np.int8)
+    # From here on `scale`/`out` are the sentinel-filled buffers, not the per-row values above.
+    scale = np.full(ROWS, SENTINEL_F32, dtype=np.float32)
+    out = np.full((ROWS, COLS), SENTINEL_U8, dtype=np.uint8)
+
+    output_dir.mkdir(parents=True, exist_ok=True)
+    src.view(np.uint16).tofile(output_dir / "v1.bin")  # float16 source written as raw 16-bit words
+    scale.tofile(output_dir / "v2.bin")
+    out.tofile(output_dir / "v3.bin")
+    golden_scale.tofile(output_dir / "golden_v2.bin")
+    golden_out.view(np.uint8).tofile(output_dir / "golden_v3.bin")  # int8 results written as raw bytes
+
+# CLI entry point: --output-dir selects the destination directory (default: current directory).
+def main() -> None:
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--output-dir", type=Path, default=Path("."))
+    args = parser.parse_args()
+    generate(args.output_dir)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-16x128/kernel.pto b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-16x128/kernel.pto
new file mode 100644
index 0000000000..babe48feb5
--- /dev/null
+++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-16x128/kernel.pto
@@ -0,0 +1,118 @@
+// Copyright (c) 2026 Huawei Technologies Co., Ltd.
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+// CANN Open Software License Agreement Version 2.0 (the "License").
+// Please refer to the License for details. You may not use this file except in compliance with the License.
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+// See LICENSE in the root of the software repository for the full text of the License.
+ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_dynamic_quant_pertoken_f16_16x128_kernel(%src_gm: !pto.ptr, + %scale_gm: !pto.ptr, + %out_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c8 = arith.constant 8 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c0_i32 = arith.constant 0 : i32 + %c256_i32 = arith.constant 256 : i32 + %c0_f32 = arith.constant 0.000000e+00 : f32 + %c-128_f32 = arith.constant -1.280000e+02 : f32 + %c127_f32 = arith.constant 1.270000e+02 : f32 + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_scale = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c4096_i64 + nburst(%c1_i64, %c4096_i64, %c4096_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %scale_gm, %ub_scale, %c0_i64, %c64_i64 + nburst(%c1_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %out_gm, %ub_out, %c0_i64, %c2048_i64 + nburst(%c1_i64, %c2048_i64, %c2048_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + scf.for %tile = %c0 to %c8 step %c1 { + %row = arith.muli %tile, %c2 : index + %elem_offset = arith.muli %row, %c128 : index + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %x_f16 = pto.vmi.load %ub_src[%elem_offset] + : !pto.ptr -> !pto.vmi.vreg<256xf16> + %x = pto.vmi.extf %x_f16 + : !pto.vmi.vreg<256xf16> -> !pto.vmi.vreg<256xf32> + %abs = pto.vmi.absf %x + : 
!pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 2} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<2xf32> + %max_int8 = pto.vmi.broadcast %c127_f32 + : f32 -> !pto.vmi.vreg<2xf32> + %scale = pto.vmi.divf %amax, %max_int8 + : !pto.vmi.vreg<2xf32>, !pto.vmi.vreg<2xf32> + -> !pto.vmi.vreg<2xf32> + pto.vmi.group_store %scale, %ub_scale[%row], %c1 {num_groups = 2} + : !pto.vmi.vreg<2xf32>, !pto.ptr + %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 2} + : !pto.vmi.vreg<2xf32> -> !pto.vmi.vreg<256xf32> + %scaled = pto.vmi.divf %x, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %lo = pto.vmi.broadcast %c-128_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %hi = pto.vmi.broadcast %c127_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %clamped_lo = pto.vmi.maxf %scaled, %lo + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %clamped = pto.vmi.minf %clamped_lo, %hi + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %i32 = pto.vmi.fptosi %clamped + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xi32> + %zero_i32 = pto.vmi.broadcast %c0_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %byte_bias = pto.vmi.broadcast %c256_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %neg = pto.vmi.cmpi "slt", %i32, %zero_i32 + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.mask<256xpred> + %wrapped = pto.vmi.addi %i32, %byte_bias + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %byte_i32 = pto.vmi.select %neg, %wrapped, %i32 + : !pto.vmi.mask<256xpred>, !pto.vmi.vreg<256xi32>, + !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %u8 = pto.vmi.trunci %byte_i32 + : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xui8> + pto.vmi.store %u8, %ub_out[%elem_offset] + : !pto.vmi.vreg<256xui8>, !pto.ptr + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", 
"PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_scale, %scale_gm, %c64_i64 + nburst(%c1_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_out, %out_gm, %c2048_i64 + nburst(%c1_i64, %c2048_i64, %c2048_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-16x128/launch.cpp b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-16x128/launch.cpp new file mode 100644 index 0000000000..5f63392588 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-16x128/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_dynamic_quant_pertoken_f16_16x128_kernel(__gm__ half *src, + __gm__ float *scale, + __gm__ uint8_t *out); + +void LaunchVmi_dynamic_quant_pertoken_f16_16x128_kernel(uint16_t *src, + float *scale, + uint8_t *out, + void *stream) { + vmi_dynamic_quant_pertoken_f16_16x128_kernel<<<1, nullptr, stream>>>( + (__gm__ half *)src, (__gm__ float *)scale, (__gm__ uint8_t *)out); +} diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-16x128/main.cpp b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-16x128/main.cpp new file mode 100644 index 0000000000..1385903f12 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-16x128/main.cpp @@ -0,0 +1,95 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_dynamic_quant_pertoken_f16_16x128_kernel(uint16_t *src, + float *scale, + uint8_t *out, + void *stream); + +int main() { + constexpr size_t kRows = 16; + constexpr size_t kCols = 128; + constexpr size_t kElems = kRows * kCols; + constexpr size_t kScaleElems = kRows; + size_t srcBytes = kElems * sizeof(uint16_t); + size_t scaleBytes = kScaleElems * sizeof(float); + size_t outBytes = kElems * sizeof(uint8_t); + uint16_t *srcHost = nullptr; + float *scaleHost = nullptr; + uint8_t *outHost = nullptr; + uint16_t *srcDevice = nullptr; + float *scaleDevice = nullptr; + uint8_t *outDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&scaleHost), scaleBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&outHost), outBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&scaleDevice, scaleBytes, 
ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outDevice, outBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", scaleBytes, scaleHost, scaleBytes); + ReadFile("./v3.bin", outBytes, outHost, outBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(scaleDevice, scaleBytes, scaleHost, scaleBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outDevice, outBytes, outHost, outBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_dynamic_quant_pertoken_f16_16x128_kernel(srcDevice, scaleDevice, + outDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(scaleHost, scaleBytes, scaleDevice, scaleBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outHost, outBytes, outDevice, outBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", scaleHost, scaleBytes); + WriteFile("./v3.bin", outHost, outBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(scaleDevice); + aclrtFree(outDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(scaleHost); + aclrtFreeHost(outHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-16x128/ptoas.flags b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-16x128/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-16x128/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-4x32/compare.py b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-4x32/compare.py new file mode 100644 index 0000000000..0142cbc20f --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-4x32/compare.py @@ -0,0 +1,50 @@ 
+#!/usr/bin/env python3
+# Copyright (c) 2026 Huawei Technologies Co., Ltd.
+# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+# CANN Open Software License Agreement Version 2.0 (the "License").
+# Please refer to the License for details. You may not use this file except in compliance with the License.
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+# See LICENSE in the root of the software repository for the full text of the License.
+# Checks v2.bin / v3.bin in the current directory against golden_v2.bin / golden_v3.bin.
+import sys
+
+import numpy as np
+
+# Exits with status 2 after reporting the first mismatching index; prints a pass line otherwise.
+def main() -> None:
+    golden_scale = np.fromfile("golden_v2.bin", dtype=np.float32)
+    scale = np.fromfile("v2.bin", dtype=np.float32)
+    golden_out = np.fromfile("golden_v3.bin", dtype=np.uint8)
+    out = np.fromfile("v3.bin", dtype=np.uint8)
+    # Scales are float32 and compared with rtol/atol of 1e-6; a shape mismatch also fails.
+    if golden_scale.shape != scale.shape or not np.allclose(
+        golden_scale, scale, rtol=1.0e-6, atol=1.0e-6
+    ):
+        if golden_scale.shape != scale.shape:
+            idx = -1  # shapes differ, so there is no element to point at
+        else:
+            diff = np.nonzero(~np.isclose(golden_scale, scale, rtol=1.0e-6, atol=1.0e-6))[0]
+            idx = int(diff[0]) if diff.size else -1
+        print(
+            f"[ERROR] scale compare failed idx={idx} "
+            f"golden={golden_scale[idx] if idx >= 0 else 'n/a'} "
+            f"output={scale[idx] if idx >= 0 else 'n/a'}"
+        )
+        sys.exit(2)
+    # Quantized bytes must match exactly.
+    if golden_out.shape != out.shape or not np.array_equal(golden_out, out):
+        diff = np.nonzero(golden_out != out)[0] if golden_out.shape == out.shape else []
+        idx = int(diff[0]) if len(diff) else -1
+        print(
+            f"[ERROR] int8 compare failed idx={idx} "
+            f"golden={int(golden_out[idx]) if idx >= 0 else 'n/a'} "
+            f"output={int(out[idx]) if idx >= 0 else 'n/a'}"
+        )
+        sys.exit(2)
+
+    print("[INFO] compare passed")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-4x32/golden.py
b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-4x32/golden.py
new file mode 100644
index 0000000000..351c27ef75
--- /dev/null
+++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-4x32/golden.py
@@ -0,0 +1,70 @@
+#!/usr/bin/env python3
+# Copyright (c) 2026 Huawei Technologies Co., Ltd.
+# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+# CANN Open Software License Agreement Version 2.0 (the "License").
+# Please refer to the License for details. You may not use this file except in compliance with the License.
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+# See LICENSE in the root of the software repository for the full text of the License.
+# Input/golden generator: row-wise dynamic int8 quantization of float16 data (ROWS x COLS).
+import argparse
+from pathlib import Path
+
+import numpy as np
+
+ROWS = 4
+COLS = 32
+INT8_MAX = np.float32(127.0)
+SENTINEL_F32 = np.float32(-777.0)  # fill value for the scale buffer written to v2.bin
+SENTINEL_U8 = np.uint8(0xA5)  # fill value for the byte buffer written to v3.bin
+# Q_VALUES is tiled along a row in generate(); row r is that pattern times ROW_SCALES[r].
+Q_VALUES = np.array(
+    [-127, -96, -64, -32, -7, -1, 0, 1, 7, 16, 31, 63, 95, 120, 127],
+    dtype=np.float32,
+)
+ROW_SCALES = np.array(
+    [
+        0.25,
+        0.5,
+        1.0,
+        2.0,
+    ],
+    dtype=np.float32,
+)
+
+# Write the source tile, sentinel-filled output buffers, and expected results into output_dir.
+def generate(output_dir: Path) -> None:
+    repeats = (COLS + len(Q_VALUES) - 1) // len(Q_VALUES)  # ceil(COLS / len(Q_VALUES))
+    q_row = np.tile(Q_VALUES, repeats)[:COLS].astype(np.float32)
+
+    src = np.empty((ROWS, COLS), dtype=np.float16)
+    golden_scale = np.empty(ROWS, dtype=np.float32)
+    golden_out = np.empty((ROWS, COLS), dtype=np.int8)
+    for row in range(ROWS):
+        src[row] = (q_row * ROW_SCALES[row]).astype(np.float16)
+        x_f32 = src[row].astype(np.float32)  # reference math uses the float16-rounded values
+        scale = (np.max(np.abs(x_f32)) / INT8_MAX).astype(np.float32)  # per-row scale: max(|x|) / 127
+        golden_scale[row] = scale
+        quant = np.round(x_f32 / scale).astype(np.float32)  # np.round rounds halves to even
+        golden_out[row] = np.clip(quant, -128, 127).astype(np.int8)
+    # From here on `scale`/`out` are the sentinel-filled buffers, not the per-row values above.
+    scale = np.full(ROWS, SENTINEL_F32, dtype=np.float32)
+    out = np.full((ROWS, COLS), SENTINEL_U8, dtype=np.uint8)
+
+    output_dir.mkdir(parents=True, exist_ok=True)
+    src.view(np.uint16).tofile(output_dir / "v1.bin")  # float16 source written as raw 16-bit words
+    scale.tofile(output_dir / "v2.bin")
+    out.tofile(output_dir / "v3.bin")
+    golden_scale.tofile(output_dir / "golden_v2.bin")
+    golden_out.view(np.uint8).tofile(output_dir / "golden_v3.bin")  # int8 results written as raw bytes
+
+# CLI entry point: --output-dir selects the destination directory (default: current directory).
+def main() -> None:
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--output-dir", type=Path, default=Path("."))
+    args = parser.parse_args()
+    generate(args.output_dir)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-4x32/kernel.pto b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-4x32/kernel.pto
new file mode 100644
index 0000000000..761ff9be84
--- /dev/null
+++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-4x32/kernel.pto
@@ -0,0 +1,112 @@
+// Copyright (c) 2026 Huawei Technologies Co., Ltd.
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+// CANN Open Software License Agreement Version 2.0 (the "License").
+// Please refer to the License for details. You may not use this file except in compliance with the License.
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+// See LICENSE in the root of the software repository for the full text of the License.
+ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_dynamic_quant_pertoken_f16_4x32_kernel(%src_gm: !pto.ptr, + %scale_gm: !pto.ptr, + %out_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c0_i32 = arith.constant 0 : i32 + %c256_i32 = arith.constant 256 : i32 + %c1_f16 = arith.constant 1.000000e+00 : f16 + %c-128_f32 = arith.constant -1.280000e+02 : f32 + %c127_f32 = arith.constant 1.270000e+02 : f32 + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c16_i64 = arith.constant 16 : i64 + %c128_i64 = arith.constant 128 : i64 + %c256_i64 = arith.constant 256 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c2048_i64 = arith.constant 2048 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_scale = pto.castptr %c1024_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %pad_f16 = pto.vmi.broadcast %c1_f16 + : f16 -> !pto.vmi.vreg<256xf16> + pto.vmi.store %pad_f16, %ub_src[%c128] + : !pto.vmi.vreg<256xf16>, !pto.ptr + pto.mem_bar "VST_VLD" + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %x_f16 = pto.vmi.load %ub_src[%c0] + : !pto.ptr -> !pto.vmi.vreg<256xf16> + %x = pto.vmi.extf %x_f16 + : !pto.vmi.vreg<256xf16> -> !pto.vmi.vreg<256xf32> + %abs = pto.vmi.absf %x + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<8xf32> + %max_int8 = pto.vmi.broadcast %c127_f32 + : f32 -> !pto.vmi.vreg<8xf32> + %scale = 
pto.vmi.divf %amax, %max_int8 + : !pto.vmi.vreg<8xf32>, !pto.vmi.vreg<8xf32> + -> !pto.vmi.vreg<8xf32> + pto.vmi.group_store %scale, %ub_scale[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xf32>, !pto.ptr + %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 8} + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<256xf32> + %scaled = pto.vmi.divf %x, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %lo = pto.vmi.broadcast %c-128_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %hi = pto.vmi.broadcast %c127_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %clamped_lo = pto.vmi.maxf %scaled, %lo + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %clamped = pto.vmi.minf %clamped_lo, %hi + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %i32 = pto.vmi.fptosi %clamped + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xi32> + %zero_i32 = pto.vmi.broadcast %c0_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %byte_bias = pto.vmi.broadcast %c256_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %neg = pto.vmi.cmpi "slt", %i32, %zero_i32 + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.mask<256xpred> + %wrapped = pto.vmi.addi %i32, %byte_bias + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %byte_i32 = pto.vmi.select %neg, %wrapped, %i32 + : !pto.vmi.mask<256xpred>, !pto.vmi.vreg<256xi32>, + !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %u8 = pto.vmi.trunci %byte_i32 + : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xui8> + pto.vmi.store %u8, %ub_out[%c0] + : !pto.vmi.vreg<256xui8>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_scale, %scale_gm, %c16_i64 + nburst(%c1_i64, %c16_i64, %c16_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_out, %out_gm, %c128_i64 + nburst(%c1_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier 
#pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-4x32/launch.cpp b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-4x32/launch.cpp new file mode 100644 index 0000000000..4c4675e167 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-4x32/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_dynamic_quant_pertoken_f16_4x32_kernel(__gm__ half *src, + __gm__ float *scale, + __gm__ uint8_t *out); + +void LaunchVmi_dynamic_quant_pertoken_f16_4x32_kernel(uint16_t *src, + float *scale, + uint8_t *out, + void *stream) { + 
vmi_dynamic_quant_pertoken_f16_4x32_kernel<<<1, nullptr, stream>>>( + (__gm__ half *)src, (__gm__ float *)scale, (__gm__ uint8_t *)out); +} diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-4x32/main.cpp b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-4x32/main.cpp new file mode 100644 index 0000000000..e0616d1d83 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-4x32/main.cpp @@ -0,0 +1,95 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_dynamic_quant_pertoken_f16_4x32_kernel(uint16_t *src, + float *scale, + uint8_t *out, + void *stream); + +int main() { + constexpr size_t kRows = 4; + constexpr size_t kCols = 32; + constexpr size_t kElems = kRows * kCols; + constexpr size_t kScaleElems = kRows; + size_t srcBytes = kElems * sizeof(uint16_t); + size_t scaleBytes = kScaleElems * sizeof(float); + size_t outBytes = kElems * sizeof(uint8_t); + uint16_t *srcHost = nullptr; + float *scaleHost = nullptr; + uint8_t *outHost = nullptr; + uint16_t *srcDevice = nullptr; + float *scaleDevice = nullptr; + uint8_t *outDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&scaleHost), scaleBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&outHost), outBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&scaleDevice, scaleBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outDevice, outBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", scaleBytes, scaleHost, scaleBytes); + ReadFile("./v3.bin", outBytes, outHost, outBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, 
srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(scaleDevice, scaleBytes, scaleHost, scaleBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outDevice, outBytes, outHost, outBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_dynamic_quant_pertoken_f16_4x32_kernel(srcDevice, scaleDevice, + outDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(scaleHost, scaleBytes, scaleDevice, scaleBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outHost, outBytes, outDevice, outBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", scaleHost, scaleBytes); + WriteFile("./v3.bin", outHost, outBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(scaleDevice); + aclrtFree(outDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(scaleHost); + aclrtFreeHost(outHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-4x32/ptoas.flags b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-4x32/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-4x32/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-bf16-8x64/compare.py b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-bf16-8x64/compare.py new file mode 100644 index 0000000000..35bf41a58d --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-bf16-8x64/compare.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. 
You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden_scale = np.fromfile("golden_v3.bin", dtype=np.float32) + scale = np.fromfile("v3.bin", dtype=np.float32) + golden_out = np.fromfile("golden_v4.bin", dtype=np.uint8) + out = np.fromfile("v4.bin", dtype=np.uint8) + + if golden_scale.shape != scale.shape or not np.allclose( + golden_scale, scale, rtol=1.0e-6, atol=1.0e-6 + ): + if golden_scale.shape != scale.shape: + idx = -1 + else: + diff = np.nonzero(~np.isclose(golden_scale, scale, rtol=1.0e-6, atol=1.0e-6))[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] scale compare failed idx={idx} " + f"golden={golden_scale[idx] if idx >= 0 else 'n/a'} " + f"output={scale[idx] if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + if golden_out.shape != out.shape or not np.array_equal(golden_out, out): + diff = np.nonzero(golden_out != out)[0] if golden_out.shape == out.shape else [] + idx = int(diff[0]) if len(diff) else -1 + print( + f"[ERROR] int8 compare failed idx={idx} " + f"golden={int(golden_out[idx]) if idx >= 0 else 'n/a'} " + f"output={int(out[idx]) if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-bf16-8x64/golden.py b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-bf16-8x64/golden.py new file mode 100644 index 0000000000..45952cc50d --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-bf16-8x64/golden.py @@ -0,0 +1,101 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 8 +COLS = 64 +SCALE_SLOTS = 16 +INT8_MAX = np.float32(127.0) +SENTINEL_F32 = np.float32(-777.0) +SENTINEL_U8 = np.uint8(0xA5) +SMOOTH_VALUES = np.array( + [0.5, 0.75, 1.0, 1.25, 1.5, 0.625, 0.875, 1.125], + dtype=np.float32, +) + +Q_VALUES = np.array( + [-127, -96, -64, -32, -7, -1, 0, 1, 7, 16, 31, 63, 95, 120, 127], + dtype=np.float32, +) +ROW_SCALES = np.array( + [ + 0.25, + 0.5, + 1.0, + 2.0, + 0.375, + 0.75, + 1.5, + 3.0, + 0.125, + 0.625, + 1.25, + 2.5, + 0.3125, + 0.9375, + 1.875, + 3.75, + ], + dtype=np.float32, +) + + +def f32_to_bf16_bits(values: np.ndarray) -> np.ndarray: + bits = values.astype(np.float32).view(np.uint32) + lsb = (bits >> 16) & 1 + rounded = bits + np.uint32(0x7FFF) + lsb + return (rounded >> 16).astype(np.uint16) + + +def bf16_bits_to_f32(values: np.ndarray) -> np.ndarray: + return (values.astype(np.uint32) << 16).view(np.float32) + + +def generate(output_dir: Path) -> None: + repeats = (COLS + len(Q_VALUES) - 1) // len(Q_VALUES) + q_row = np.tile(Q_VALUES, repeats)[:COLS].astype(np.float32) + smooth_repeats = (COLS + len(SMOOTH_VALUES) - 1) // len(SMOOTH_VALUES) + smooth = f32_to_bf16_bits(np.tile(SMOOTH_VALUES, smooth_repeats)[:COLS]) + + src = np.empty((ROWS, COLS), dtype=np.uint16) + golden_scale = np.full(SCALE_SLOTS, SENTINEL_F32, dtype=np.float32) + golden_out = np.empty((ROWS, 
COLS), dtype=np.int8) + for row in range(ROWS): + src[row] = f32_to_bf16_bits(q_row * ROW_SCALES[row]) + x_f32 = bf16_bits_to_f32(src[row]) * bf16_bits_to_f32(smooth) + scale = (np.max(np.abs(x_f32)) / INT8_MAX).astype(np.float32) + golden_scale[(row // 4) * 8 + (row % 4)] = scale + quant = np.round(x_f32 / scale).astype(np.float32) + golden_out[row] = np.clip(quant, -128, 127).astype(np.int8) + + scale = np.full(SCALE_SLOTS, SENTINEL_F32, dtype=np.float32) + out = np.full((ROWS, COLS), SENTINEL_U8, dtype=np.uint8) + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + smooth.tofile(output_dir / "v2.bin") + scale.tofile(output_dir / "v3.bin") + out.tofile(output_dir / "v4.bin") + golden_scale.tofile(output_dir / "golden_v3.bin") + golden_out.view(np.uint8).tofile(output_dir / "golden_v4.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-bf16-8x64/kernel.pto b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-bf16-8x64/kernel.pto new file mode 100644 index 0000000000..35b641663a --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-bf16-8x64/kernel.pto @@ -0,0 +1,156 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_dynamic_quant_pertoken_smooth_bf16_8x64_kernel(%src_gm: !pto.ptr, + %smooth_gm: !pto.ptr, + %scale_gm: !pto.ptr, + %out_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c4 = arith.constant 4 : index + %c8 = arith.constant 8 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c512 = arith.constant 512 : index + %c0_i32 = arith.constant 0 : i32 + %c256_i32 = arith.constant 256 : i32 + %c-128_f32 = arith.constant -1.280000e+02 : f32 + %c127_f32 = arith.constant 1.270000e+02 : f32 + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c128_i64 = arith.constant 128 : i64 + %c512_i64 = arith.constant 512 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c12288_i64 = arith.constant 12288 : i64 + %c12416_i64 = arith.constant 12416 : i64 + %c12544_i64 = arith.constant 12544 : i64 + %c12672_i64 = arith.constant 12672 : i64 + %c12800_i64 = arith.constant 12800 : i64 + %c12928_i64 = arith.constant 12928 : i64 + %c13056_i64 = arith.constant 13056 : i64 + %c13184_i64 = arith.constant 13184 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_scale = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_smooth = pto.castptr %c12288_i64 : i64 -> !pto.ptr + %ub_smooth1 = pto.castptr %c12416_i64 : i64 -> !pto.ptr + %ub_smooth2 = pto.castptr %c12544_i64 : i64 -> !pto.ptr + %ub_smooth3 = pto.castptr %c12672_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c1024_i64 + nburst(%c1_i64, %c1024_i64, 
%c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %smooth_gm, %ub_smooth, %c0_i64, %c128_i64 + nburst(%c1_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %smooth_gm, %ub_smooth1, %c0_i64, %c128_i64 + nburst(%c1_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %smooth_gm, %ub_smooth2, %c0_i64, %c128_i64 + nburst(%c1_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %smooth_gm, %ub_smooth3, %c0_i64, %c128_i64 + nburst(%c1_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %scale_gm, %ub_scale, %c0_i64, %c64_i64 + nburst(%c1_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %out_gm, %ub_out, %c0_i64, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + scf.for %tile = %c0 to %c2 step %c1 { + %row = arith.muli %tile, %c4 : index + %elem_offset = arith.muli %row, %c64 : index + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %x_bf16 = pto.vmi.load %ub_src[%elem_offset] + : !pto.ptr -> !pto.vmi.vreg<256xbf16> + %x = pto.vmi.extf %x_bf16 + : !pto.vmi.vreg<256xbf16> -> !pto.vmi.vreg<256xf32> + %smooth_bf16 = pto.vmi.load %ub_smooth[%c0] + : !pto.ptr -> !pto.vmi.vreg<256xbf16> + %smooth = pto.vmi.extf %smooth_bf16 + : !pto.vmi.vreg<256xbf16> -> !pto.vmi.vreg<256xf32> + %smoothed = pto.vmi.mulf %x, %smooth + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %abs = pto.vmi.absf %smoothed + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 4} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<4xf32> + %max_int8 = pto.vmi.broadcast %c127_f32 + : f32 
-> !pto.vmi.vreg<4xf32> + %scale = pto.vmi.divf %amax, %max_int8 + : !pto.vmi.vreg<4xf32>, !pto.vmi.vreg<4xf32> + -> !pto.vmi.vreg<4xf32> + %scale_offset = arith.muli %tile, %c8 : index + pto.vmi.group_store %scale, %ub_scale[%scale_offset], %c1 {num_groups = 4} + : !pto.vmi.vreg<4xf32>, !pto.ptr + %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 4} + : !pto.vmi.vreg<4xf32> -> !pto.vmi.vreg<256xf32> + %scaled = pto.vmi.divf %smoothed, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %lo = pto.vmi.broadcast %c-128_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %hi = pto.vmi.broadcast %c127_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %clamped_lo = pto.vmi.maxf %scaled, %lo + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %clamped = pto.vmi.minf %clamped_lo, %hi + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %i32 = pto.vmi.fptosi %clamped + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xi32> + %zero_i32 = pto.vmi.broadcast %c0_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %byte_bias = pto.vmi.broadcast %c256_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %neg = pto.vmi.cmpi "slt", %i32, %zero_i32 + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.mask<256xpred> + %wrapped = pto.vmi.addi %i32, %byte_bias + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %byte_i32 = pto.vmi.select %neg, %wrapped, %i32 + : !pto.vmi.mask<256xpred>, !pto.vmi.vreg<256xi32>, + !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %u8 = pto.vmi.trunci %byte_i32 + : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xui8> + pto.vmi.store %u8, %ub_out[%elem_offset] + : !pto.vmi.vreg<256xui8>, !pto.ptr + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_scale, %scale_gm, %c64_i64 + nburst(%c1_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_out, 
%out_gm, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-bf16-8x64/launch.cpp b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-bf16-8x64/launch.cpp new file mode 100644 index 0000000000..c601fb3637 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-bf16-8x64/launch.cpp @@ -0,0 +1,45 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_dynamic_quant_pertoken_smooth_bf16_8x64_kernel(__gm__ bfloat16_t *src, + __gm__ bfloat16_t *smooth, + __gm__ float *scale, + __gm__ uint8_t *out); + +void LaunchVmi_dynamic_quant_pertoken_smooth_bf16_8x64_kernel( + uint16_t *src, uint16_t *smooth, float *scale, uint8_t *out, + void *stream) { + vmi_dynamic_quant_pertoken_smooth_bf16_8x64_kernel<<<1, nullptr, stream>>>( + (__gm__ bfloat16_t *)src, (__gm__ bfloat16_t *)smooth, (__gm__ float *)scale, + (__gm__ uint8_t *)out); +} diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-bf16-8x64/main.cpp b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-bf16-8x64/main.cpp new file mode 100644 index 0000000000..ec2867832a --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-bf16-8x64/main.cpp @@ -0,0 +1,105 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_dynamic_quant_pertoken_smooth_bf16_8x64_kernel( + uint16_t *src, uint16_t *smooth, float *scale, uint8_t *out, + void *stream); + +int main() { + constexpr size_t kRows = 8; + constexpr size_t kCols = 64; + constexpr size_t kElems = kRows * kCols; + constexpr size_t kSmoothElems = kCols; + constexpr size_t kScaleElems = 16; + size_t srcBytes = kElems * sizeof(uint16_t); + size_t smoothBytes = kSmoothElems * sizeof(uint16_t); + size_t scaleBytes = kScaleElems * sizeof(float); + size_t outBytes = kElems * sizeof(uint8_t); + uint16_t *srcHost = nullptr; + uint16_t *smoothHost = nullptr; + float *scaleHost = nullptr; + uint8_t *outHost = nullptr; + uint16_t *srcDevice = nullptr; + uint16_t *smoothDevice = nullptr; + float *scaleDevice = nullptr; + uint8_t *outDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&smoothHost), smoothBytes)); + 
ACL_CHECK(aclrtMallocHost((void **)(&scaleHost), scaleBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&outHost), outBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&smoothDevice, smoothBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&scaleDevice, scaleBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outDevice, outBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", smoothBytes, smoothHost, smoothBytes); + ReadFile("./v3.bin", scaleBytes, scaleHost, scaleBytes); + ReadFile("./v4.bin", outBytes, outHost, outBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(smoothDevice, smoothBytes, smoothHost, smoothBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(scaleDevice, scaleBytes, scaleHost, scaleBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outDevice, outBytes, outHost, outBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_dynamic_quant_pertoken_smooth_bf16_8x64_kernel( + srcDevice, smoothDevice, scaleDevice, outDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(scaleHost, scaleBytes, scaleDevice, scaleBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outHost, outBytes, outDevice, outBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", scaleHost, scaleBytes); + WriteFile("./v4.bin", outHost, outBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(smoothDevice); + aclrtFree(scaleDevice); + aclrtFree(outDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(smoothHost); + aclrtFreeHost(scaleHost); + aclrtFreeHost(outHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-bf16-8x64/ptoas.flags 
b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-bf16-8x64/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-bf16-8x64/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-16x128/compare.py b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-16x128/compare.py new file mode 100644 index 0000000000..35bf41a58d --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-16x128/compare.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+
+import sys
+
+import numpy as np
+
+
+def main() -> None:
+    """Check v3.bin and v4.bin against golden_v3.bin and golden_v4.bin.
+
+    v3 is read as float32 and compared within a tolerance, v4 as raw bytes and
+    compared exactly; the first mismatch is printed and the process exits with 2.
+    """
+    golden_scale, scale = (np.fromfile(f, dtype=np.float32) for f in ("golden_v3.bin", "v3.bin"))
+    golden_out, out = (np.fromfile(f, dtype=np.uint8) for f in ("golden_v4.bin", "v4.bin"))
+
+    # idx stays -1 when the element counts differ, so no single element can be blamed.
+    same_shape = golden_scale.shape == scale.shape
+    off = ~np.isclose(golden_scale, scale, rtol=1.0e-6, atol=1.0e-6) if same_shape else None
+    if not same_shape or off.any():
+        idx = int(np.flatnonzero(off)[0]) if same_shape else -1
+        print(
+            f"[ERROR] scale compare failed idx={idx} "
+            f"golden={golden_scale[idx] if idx >= 0 else 'n/a'} "
+            f"output={scale[idx] if idx >= 0 else 'n/a'}"
+        )
+        sys.exit(2)
+
+    same_shape = golden_out.shape == out.shape
+    if not same_shape or np.any(golden_out != out):
+        idx = int(np.flatnonzero(golden_out != out)[0]) if same_shape else -1
+        print(
+            f"[ERROR] int8 compare failed idx={idx} "
+            f"golden={int(golden_out[idx]) if idx >= 0 else 'n/a'} "
+            f"output={int(out[idx]) if idx >= 0 else 'n/a'}"
+        )
+        sys.exit(2)
+
+    print("[INFO] compare passed")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-16x128/golden.py b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-16x128/golden.py
new file mode 100644
index 0000000000..a0592fd234
--- /dev/null
+++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-16x128/golden.py
@@ -0,0 +1,90 @@
+#!/usr/bin/env python3
+# Copyright (c) 2026 Huawei Technologies Co., Ltd.
+# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+# CANN Open Software License Agreement Version 2.0 (the "License").
+# Please refer to the License for details. You may not use this file except in compliance with the License.
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+# See LICENSE in the root of the software repository for the full text of the License.
+
+import argparse
+from pathlib import Path
+
+import numpy as np
+
+ROWS = 16
+COLS = 128
+SCALE_SLOTS = 64
+INT8_MAX = np.float32(127.0)
+SENTINEL_F32 = np.float32(-777.0)
+SENTINEL_U8 = np.uint8(0xA5)
+SMOOTH_VALUES = np.array(
+    [0.5, 0.75, 1.0, 1.25, 1.5, 0.625, 0.875, 1.125],
+    dtype=np.float32,
+)
+
+Q_VALUES = np.array(
+    [-127, -96, -64, -32, -7, -1, 0, 1, 7, 16, 31, 63, 95, 120, 127],
+    dtype=np.float32,
+)
+ROW_SCALES = np.array(
+    [
+        0.25,
+        0.5,
+        1.0,
+        2.0,
+        0.375,
+        0.75,
+        1.5,
+        3.0,
+        0.125,
+        0.625,
+        1.25,
+        2.5,
+        0.3125,
+        0.9375,
+        1.875,
+        3.75,
+    ],
+    dtype=np.float32,
+)
+
+
+def generate(output_dir: Path) -> None:
+    """Write inputs v1..v4.bin and references golden_v3/golden_v4.bin into output_dir."""
+    q_pattern = np.resize(Q_VALUES, COLS).astype(np.float32)
+    smooth_f16 = np.resize(SMOOTH_VALUES, COLS).astype(np.float16)
+    smooth_f32 = smooth_f16.astype(np.float32)
+    src = np.empty((ROWS, COLS), dtype=np.float16)
+    golden_scale = np.full(SCALE_SLOTS, SENTINEL_F32, dtype=np.float32)
+    golden_out = np.empty((ROWS, COLS), dtype=np.int8)
+    for row in range(ROWS):
+        src[row] = (q_pattern * ROW_SCALES[row]).astype(np.float16)
+        smoothed = src[row].astype(np.float32) * smooth_f32
+        row_scale = (np.abs(smoothed).max() / INT8_MAX).astype(np.float32)
+        tile, lane = divmod(row, 2)  # two rows per tile; tiles are 8 scale slots apart
+        golden_scale[tile * 8 + lane] = row_scale
+        rounded = np.round(smoothed / row_scale).astype(np.float32)
+        golden_out[row] = np.clip(rounded, -128, 127).astype(np.int8)
+    # v3/v4 start out filled with sentinels; the golden files hold the expected values.
+    output_dir.mkdir(parents=True, exist_ok=True)
+    outputs = (
+        ("v1.bin", src.view(np.uint16)),
+        ("v2.bin", smooth_f16.view(np.uint16)),
+        ("v3.bin", np.full(SCALE_SLOTS, SENTINEL_F32, dtype=np.float32)),
+        ("v4.bin", np.full((ROWS, COLS), SENTINEL_U8, dtype=np.uint8)),
+        ("golden_v3.bin", golden_scale),
+        ("golden_v4.bin", golden_out.view(np.uint8)),
+    )
+    for file_name, payload in outputs:
+        payload.tofile(output_dir / file_name)
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--output-dir", type=Path, default=Path("."))
+    generate(parser.parse_args().output_dir)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-16x128/kernel.pto b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-16x128/kernel.pto
new file mode 100644
index 0000000000..2fd9708cc9
--- /dev/null
+++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-16x128/kernel.pto
@@ -0,0 +1,160 @@
+// Copyright (c) 2026 Huawei Technologies Co., Ltd.
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+// CANN Open Software License Agreement Version 2.0 (the "License").
+// Please refer to the License for details. You may not use this file except in compliance with the License.
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+// See LICENSE in the root of the software repository for the full text of the License.
+// Per-row dynamic int8 quantization with a smooth factor: 16 rows x 128 f16 columns, two rows (256 lanes) per loop iteration.
+module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} {  // NOTE(review): #pto.kernel_kind, !pto.ptr and #pto.pipe show no <...> parameters in this text - confirm
+  func.func @vmi_dynamic_quant_pertoken_smooth_f16_16x128_kernel(%src_gm: !pto.ptr,
+      %smooth_gm: !pto.ptr,
+      %scale_gm: !pto.ptr,
+      %out_gm: !pto.ptr) attributes {pto.kernel} {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c2 = arith.constant 2 : index
+    %c8 = arith.constant 8 : index
+    %c128 = arith.constant 128 : index
+    %c256 = arith.constant 256 : index
+    %c512 = arith.constant 512 : index
+    %c0_i32 = arith.constant 0 : i32
+    %c256_i32 = arith.constant 256 : i32
+    %c-128_f32 = arith.constant -1.280000e+02 : f32
+    %c127_f32 = arith.constant 1.270000e+02 : f32
+    %c0_i64 = arith.constant 0 : i64
+    %c1_i64 = arith.constant 1 : i64
+    %c32_i64 = arith.constant 32 : i64
+    %c64_i64 = arith.constant 64 : i64
+    %c128_i64 = arith.constant 128 : i64
+    %c256_i64 = arith.constant 256 : i64
+    %c512_i64 = arith.constant 512 : i64
+    %c1024_i64 = arith.constant 1024 : i64
+    %c2048_i64 = arith.constant 2048 : i64
+    %c4096_i64 = arith.constant 4096 : i64
+    %c8192_i64 = arith.constant 8192 : i64
+    %c12288_i64 = arith.constant 12288 : i64
+    %c12416_i64 = arith.constant 12416 : i64
+    %c12544_i64 = arith.constant 12544 : i64
+    %c12672_i64 = arith.constant 12672 : i64
+    %c12800_i64 = arith.constant 12800 : i64
+    %c12928_i64 = arith.constant 12928 : i64
+    %c13056_i64 = arith.constant 13056 : i64
+    %c13184_i64 = arith.constant 13184 : i64
+    // Staging buffers are pointers made from fixed addresses: src 0, scale 4096, out 8192, smooth slots from 12288 (128 apart).
+    %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr
+    %ub_scale = pto.castptr %c4096_i64 : i64 -> !pto.ptr
+    %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr
+    %ub_smooth = pto.castptr %c12288_i64 : i64 -> !pto.ptr
+    %ub_smooth1 = pto.castptr %c12416_i64 : i64 -> !pto.ptr
+    %ub_smooth2 = pto.castptr %c12544_i64 : i64 -> !pto.ptr
+    %ub_smooth3 = pto.castptr %c12672_i64 : i64 -> !pto.ptr
+    %ub_smooth4 = pto.castptr %c12800_i64 : i64 -> !pto.ptr
+    %ub_smooth5 = pto.castptr %c12928_i64 : i64 -> !pto.ptr
+    %ub_smooth6 = pto.castptr %c13056_i64 : i64 -> !pto.ptr
+    %ub_smooth7 = pto.castptr %c13184_i64 : i64 -> !pto.ptr
+    // Inbound copies. %smooth_gm goes to %ub_smooth/2/4/6 (256 apart, length 256), i.e. back to back - presumably so the 256-lane load of %ub_smooth sees the 128-column row repeated (TODO confirm the length unit).
+    pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c4096_i64
+        nburst(%c1_i64, %c4096_i64, %c4096_i64)
+        : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
+    pto.mte_gm_ub %smooth_gm, %ub_smooth, %c0_i64, %c256_i64
+        nburst(%c1_i64, %c256_i64, %c256_i64)
+        : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
+    pto.mte_gm_ub %smooth_gm, %ub_smooth2, %c0_i64, %c256_i64
+        nburst(%c1_i64, %c256_i64, %c256_i64)
+        : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
+    pto.mte_gm_ub %smooth_gm, %ub_smooth4, %c0_i64, %c256_i64
+        nburst(%c1_i64, %c256_i64, %c256_i64)
+        : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
+    pto.mte_gm_ub %smooth_gm, %ub_smooth6, %c0_i64, %c256_i64
+        nburst(%c1_i64, %c256_i64, %c256_i64)
+        : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
+    pto.mte_gm_ub %scale_gm, %ub_scale, %c0_i64, %c256_i64
+        nburst(%c1_i64, %c256_i64, %c256_i64)
+        : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
+    pto.mte_gm_ub %out_gm, %ub_out, %c0_i64, %c2048_i64
+        nburst(%c1_i64, %c2048_i64, %c2048_i64)
+        : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
+    // Presumably makes the vector pipe wait for the MTE2 copies above - confirm the flag semantics.
+    pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]
+    pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]
+    // 8 iterations; iteration %tile starts at row 2 * %tile and handles 256 consecutive elements (two 128-column rows).
+    pto.vecscope {
+      scf.for %tile = %c0 to %c8 step %c1 {
+        %row = arith.muli %tile, %c2 : index  // first row of this tile
+        %elem_offset = arith.muli %row, %c128 : index  // element offset of that row
+        %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred>
+        %x_f16 = pto.vmi.load %ub_src[%elem_offset]
+            : !pto.ptr -> !pto.vmi.vreg<256xf16>
+        %x = pto.vmi.extf %x_f16
+            : !pto.vmi.vreg<256xf16> -> !pto.vmi.vreg<256xf32>
+        %smooth_f16 = pto.vmi.load %ub_smooth[%c0]
+            : !pto.ptr -> !pto.vmi.vreg<256xf16>
+        %smooth = pto.vmi.extf %smooth_f16
+            : !pto.vmi.vreg<256xf16> -> !pto.vmi.vreg<256xf32>
+        %smoothed = pto.vmi.mulf %x, %smooth  // x * smooth
+            : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32>
+            -> !pto.vmi.vreg<256xf32>
+        %abs = pto.vmi.absf %smoothed
+            : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32>
+        %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 2}  // one max of |x * smooth| per group, 2 groups over 256 lanes
+            : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred>
+            -> !pto.vmi.vreg<2xf32>
+        %max_int8 = pto.vmi.broadcast %c127_f32
+            : f32 -> !pto.vmi.vreg<2xf32>
+        %scale = pto.vmi.divf %amax, %max_int8  // scale = amax / 127
+            : !pto.vmi.vreg<2xf32>, !pto.vmi.vreg<2xf32>
+            -> !pto.vmi.vreg<2xf32>
+        %scale_offset = arith.muli %tile, %c8 : index  // tiles are 8 scale slots apart
+        pto.vmi.group_store %scale, %ub_scale[%scale_offset], %c1 {num_groups = 2}
+            : !pto.vmi.vreg<2xf32>, !pto.ptr
+        %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 2}
+            : !pto.vmi.vreg<2xf32> -> !pto.vmi.vreg<256xf32>
+        %scaled = pto.vmi.divf %smoothed, %scale_vec  // (x * smooth) / scale
+            : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32>
+            -> !pto.vmi.vreg<256xf32>
+        %lo = pto.vmi.broadcast %c-128_f32
+            : f32 -> !pto.vmi.vreg<256xf32>
+        %hi = pto.vmi.broadcast %c127_f32
+            : f32 -> !pto.vmi.vreg<256xf32>
+        %clamped_lo = pto.vmi.maxf %scaled, %lo
+            : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32>
+            -> !pto.vmi.vreg<256xf32>
+        %clamped = pto.vmi.minf %clamped_lo, %hi  // now within [-128, 127]
+            : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32>
+            -> !pto.vmi.vreg<256xf32>
+        %i32 = pto.vmi.fptosi %clamped
+            : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xi32>
+        %zero_i32 = pto.vmi.broadcast %c0_i32
+            : i32 -> !pto.vmi.vreg<256xi32>
+        %byte_bias = pto.vmi.broadcast %c256_i32
+            : i32 -> !pto.vmi.vreg<256xi32>
+        %neg = pto.vmi.cmpi "slt", %i32, %zero_i32
+            : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32>
+            -> !pto.vmi.mask<256xpred>
+        %wrapped = pto.vmi.addi %i32, %byte_bias
+            : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32>
+            -> !pto.vmi.vreg<256xi32>
+        %byte_i32 = pto.vmi.select %neg, %wrapped, %i32  // negatives get +256 so every value lies in 0..255
+            : !pto.vmi.mask<256xpred>, !pto.vmi.vreg<256xi32>,
+            !pto.vmi.vreg<256xi32>
+            -> !pto.vmi.vreg<256xi32>
+        %u8 = pto.vmi.trunci %byte_i32
+            : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xui8>
+        pto.vmi.store %u8, %ub_out[%elem_offset]
+            : !pto.vmi.vreg<256xui8>, !pto.ptr
+      }
+    }
+    // Presumably makes the outbound MTE3 copies wait for the vector stores; the scale and output buffers are then copied back.
+    pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
+    pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
+    pto.mte_ub_gm %ub_scale, %scale_gm, %c256_i64
+        nburst(%c1_i64, %c256_i64, %c256_i64)
+        : !pto.ptr, !pto.ptr, i64, i64, i64, i64
+    pto.mte_ub_gm %ub_out, %out_gm, %c2048_i64
+        nburst(%c1_i64, %c2048_i64, %c2048_i64)
+        : !pto.ptr, !pto.ptr, i64, i64, i64, i64
+    pto.barrier #pto.pipe
+    return
+  }
+}
diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-16x128/launch.cpp b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-16x128/launch.cpp
new file mode 100644
index 0000000000..0a95b0c6dd
--- /dev/null
+++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-16x128/launch.cpp
@@ -0,0 +1,45 @@
+// Copyright (c) 2026 Huawei Technologies Co., Ltd.
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+// CANN Open Software License Agreement Version 2.0 (the "License").
+// Please refer to the License for details. You may not use this file except in compliance with the License.
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+// See LICENSE in the root of the software repository for the full text of the License.
+// NOTE(review): the bare "#include" below carries no header name in this text - restore it from the checked-in file.
+#ifndef __VEC_SCOPE__
+#define __VEC_SCOPE__
+#endif
+#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201)
+typedef struct { unsigned char v; } hifloat8_t;
+typedef struct { unsigned char v; } float8_e4m3_t;
+typedef struct { unsigned char v; } float8_e5m2_t;
+typedef struct { unsigned char v; } float8_e8m0_t;
+typedef struct { unsigned char v; } float4_e1m2x2_t;
+typedef struct { unsigned char v; } float4_e2m1x2_t;
+#endif
+#include
+#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP)
+struct MrgSortExecutedNumList {
+  uint16_t mrgSortList0;
+  uint16_t mrgSortList1;
+  uint16_t mrgSortList2;
+  uint16_t mrgSortList3;
+};
+#endif
+#ifndef __CPU_SIM
+#include "acl/acl.h"
+#endif
+// Kernel entry point; only declared in this file.
+extern "C" __global__ [aicore] void
+vmi_dynamic_quant_pertoken_smooth_f16_16x128_kernel(__gm__ half *src,
+                                                    __gm__ half *smooth,
+                                                    __gm__ float *scale,
+                                                    __gm__ uint8_t *out);
+// Host wrapper: casts the raw pointers to the kernel's parameter types and launches it with <<<1, nullptr, stream>>>.
+void LaunchVmi_dynamic_quant_pertoken_smooth_f16_16x128_kernel(
+    uint16_t *src, uint16_t *smooth, float *scale, uint8_t *out,
+    void *stream) {
+  vmi_dynamic_quant_pertoken_smooth_f16_16x128_kernel<<<1, nullptr, stream>>>(
+      (__gm__ half *)src, (__gm__ half *)smooth, (__gm__ float *)scale,
+      (__gm__ uint8_t *)out);
+}
diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-16x128/main.cpp b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-16x128/main.cpp
new file mode 100644
index 0000000000..d17c9c9d3f
--- /dev/null
+++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-16x128/main.cpp
@@ -0,0 +1,134 @@
+// Copyright (c) 2026 Huawei Technologies Co., Ltd.
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+// CANN Open Software License Agreement Version 2.0 (the "License").
+// Please refer to the License for details. You may not use this file except in compliance with the License.
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+// See LICENSE in the root of the software repository for the full text of the License.
+
+#include "acl/acl.h"
+#include "test_common.h"
+#include <cstdint>
+#include <cstdio>
+#include <cstdlib>
+
+using namespace PtoTestCommon;
+
+// Evaluates an ACL call; on failure prints the call and its error code, sets
+// `rc` to 1 and jumps to `cleanup`. Both names come from the enclosing function.
+#define ACL_CHECK(expr) \
+  do { \
+    const aclError _ret = (expr); \
+    if (_ret != ACL_SUCCESS) { \
+      std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \
+                   (int)_ret, __FILE__, __LINE__); \
+      rc = 1; \
+      goto cleanup; \
+    } \
+  } while (0)
+
+// Same failure handling for the file helpers, whose result is tested as a bool
+// (assumes ReadFile/WriteFile return false on failure - confirm in test_common.h).
+#define IO_CHECK(expr) \
+  do { \
+    if (!(expr)) { \
+      std::fprintf(stderr, "[ERROR] %s failed (%s:%d)\n", #expr, __FILE__, \
+                   __LINE__); \
+      rc = 1; \
+      goto cleanup; \
+    } \
+  } while (0)
+
+// Host-side launcher; only declared here.
+void LaunchVmi_dynamic_quant_pertoken_smooth_f16_16x128_kernel(
+    uint16_t *src, uint16_t *smooth, float *scale, uint8_t *out,
+    void *stream);
+
+// Uploads v1..v4.bin, launches the kernel once, and writes the scale and output
+// buffers back to v3.bin / v4.bin. Returns 0 on success and 1 as soon as an ACL
+// call or a file read/write fails.
+int main() {
+  constexpr size_t kRows = 16;
+  constexpr size_t kCols = 128;
+  constexpr size_t kElems = kRows * kCols;
+  constexpr size_t kSmoothElems = kCols;
+  constexpr size_t kScaleElems = 64;
+  size_t srcBytes = kElems * sizeof(uint16_t);
+  size_t smoothBytes = kSmoothElems * sizeof(uint16_t);
+  size_t scaleBytes = kScaleElems * sizeof(float);
+  size_t outBytes = kElems * sizeof(uint8_t);
+  uint16_t *srcHost = nullptr;
+  uint16_t *smoothHost = nullptr;
+  float *scaleHost = nullptr;
+  uint8_t *outHost = nullptr;
+  uint16_t *srcDevice = nullptr;
+  uint16_t *smoothDevice = nullptr;
+  float *scaleDevice = nullptr;
+  uint8_t *outDevice = nullptr;
+  int rc = 0;
+  bool aclInited = false;
+  bool deviceSet = false;
+  int deviceId = 0;
+  aclrtStream stream = nullptr;
+
+  ACL_CHECK(aclInit(nullptr));
+  aclInited = true;
+  if (const char *envDevice = std::getenv("ACL_DEVICE_ID"))
+    deviceId = std::atoi(envDevice);
+  ACL_CHECK(aclrtSetDevice(deviceId));
+  deviceSet = true;
+  ACL_CHECK(aclrtCreateStream(&stream));
+  ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes));
+  ACL_CHECK(aclrtMallocHost((void **)(&smoothHost), smoothBytes));
+  ACL_CHECK(aclrtMallocHost((void **)(&scaleHost), scaleBytes));
+  ACL_CHECK(aclrtMallocHost((void **)(&outHost), outBytes));
+  ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST));
+  ACL_CHECK(aclrtMalloc((void **)&smoothDevice, smoothBytes, ACL_MEM_MALLOC_HUGE_FIRST));
+  ACL_CHECK(aclrtMalloc((void **)&scaleDevice, scaleBytes, ACL_MEM_MALLOC_HUGE_FIRST));
+  ACL_CHECK(aclrtMalloc((void **)&outDevice, outBytes, ACL_MEM_MALLOC_HUGE_FIRST));
+
+  // The scale/out files (v3/v4) are read and uploaded as well, so the device
+  // buffers start from the file contents before the kernel runs.
+  IO_CHECK(ReadFile("./v1.bin", srcBytes, srcHost, srcBytes));
+  IO_CHECK(ReadFile("./v2.bin", smoothBytes, smoothHost, smoothBytes));
+  IO_CHECK(ReadFile("./v3.bin", scaleBytes, scaleHost, scaleBytes));
+  IO_CHECK(ReadFile("./v4.bin", outBytes, outHost, outBytes));
+  ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE));
+  ACL_CHECK(aclrtMemcpy(smoothDevice, smoothBytes, smoothHost, smoothBytes,
+                        ACL_MEMCPY_HOST_TO_DEVICE));
+  ACL_CHECK(aclrtMemcpy(scaleDevice, scaleBytes, scaleHost, scaleBytes, ACL_MEMCPY_HOST_TO_DEVICE));
+  ACL_CHECK(aclrtMemcpy(outDevice, outBytes, outHost, outBytes, ACL_MEMCPY_HOST_TO_DEVICE));
+  LaunchVmi_dynamic_quant_pertoken_smooth_f16_16x128_kernel(
+      srcDevice, smoothDevice, scaleDevice, outDevice, stream);
+  ACL_CHECK(aclrtSynchronizeStream(stream));
+  ACL_CHECK(aclrtMemcpy(scaleHost, scaleBytes, scaleDevice, scaleBytes, ACL_MEMCPY_DEVICE_TO_HOST));
+  ACL_CHECK(aclrtMemcpy(outHost, outBytes, outDevice, outBytes, ACL_MEMCPY_DEVICE_TO_HOST));
+  IO_CHECK(WriteFile("./v3.bin", scaleHost, scaleBytes));
+  IO_CHECK(WriteFile("./v4.bin", outHost, outBytes));
+
+cleanup:
+  // Reached on both success and failure: release only what was acquired.
+  if (srcDevice)
+    aclrtFree(srcDevice);
+  if (smoothDevice)
+    aclrtFree(smoothDevice);
+  if (scaleDevice)
+    aclrtFree(scaleDevice);
+  if (outDevice)
+    aclrtFree(outDevice);
+  if (srcHost)
+    aclrtFreeHost(srcHost);
+  if (smoothHost)
+    aclrtFreeHost(smoothHost);
+  if (scaleHost)
+    aclrtFreeHost(scaleHost);
+  if (outHost)
+    aclrtFreeHost(outHost);
+  if (stream)
+    aclrtDestroyStream(stream);
+  if (deviceSet)
+    aclrtResetDevice(deviceId);
+  if (aclInited)
+    aclFinalize();
+  return rc;
+}
diff --git
a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-16x128/ptoas.flags b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-16x128/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-16x128/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-8x64/compare.py b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-8x64/compare.py new file mode 100644 index 0000000000..35bf41a58d --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-8x64/compare.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+
+import sys
+
+import numpy as np
+
+
+def main() -> None:
+    """Check v3.bin and v4.bin against golden_v3.bin and golden_v4.bin.
+
+    v3 is read as float32 and compared within a tolerance, v4 as raw bytes and
+    compared exactly; the first mismatch is printed and the process exits with 2.
+    """
+    golden_scale, scale = (np.fromfile(f, dtype=np.float32) for f in ("golden_v3.bin", "v3.bin"))
+    golden_out, out = (np.fromfile(f, dtype=np.uint8) for f in ("golden_v4.bin", "v4.bin"))
+
+    # idx stays -1 when the element counts differ, so no single element can be blamed.
+    same_shape = golden_scale.shape == scale.shape
+    off = ~np.isclose(golden_scale, scale, rtol=1.0e-6, atol=1.0e-6) if same_shape else None
+    if not same_shape or off.any():
+        idx = int(np.flatnonzero(off)[0]) if same_shape else -1
+        print(
+            f"[ERROR] scale compare failed idx={idx} "
+            f"golden={golden_scale[idx] if idx >= 0 else 'n/a'} "
+            f"output={scale[idx] if idx >= 0 else 'n/a'}"
+        )
+        sys.exit(2)
+
+    same_shape = golden_out.shape == out.shape
+    if not same_shape or np.any(golden_out != out):
+        idx = int(np.flatnonzero(golden_out != out)[0]) if same_shape else -1
+        print(
+            f"[ERROR] int8 compare failed idx={idx} "
+            f"golden={int(golden_out[idx]) if idx >= 0 else 'n/a'} "
+            f"output={int(out[idx]) if idx >= 0 else 'n/a'}"
+        )
+        sys.exit(2)
+
+    print("[INFO] compare passed")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-8x64/golden.py b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-8x64/golden.py
new file mode 100644
index 0000000000..df65413c04
--- /dev/null
+++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-8x64/golden.py
@@ -0,0 +1,90 @@
+#!/usr/bin/env python3
+# Copyright (c) 2026 Huawei Technologies Co., Ltd.
+# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+# CANN Open Software License Agreement Version 2.0 (the "License").
+# Please refer to the License for details. You may not use this file except in compliance with the License.
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+# See LICENSE in the root of the software repository for the full text of the License.
+
+import argparse
+from pathlib import Path
+
+import numpy as np
+
+ROWS = 8
+COLS = 64
+SCALE_SLOTS = 16
+INT8_MAX = np.float32(127.0)
+SENTINEL_F32 = np.float32(-777.0)
+SENTINEL_U8 = np.uint8(0xA5)
+SMOOTH_VALUES = np.array(
+    [0.5, 0.75, 1.0, 1.25, 1.5, 0.625, 0.875, 1.125],
+    dtype=np.float32,
+)
+
+Q_VALUES = np.array(
+    [-127, -96, -64, -32, -7, -1, 0, 1, 7, 16, 31, 63, 95, 120, 127],
+    dtype=np.float32,
+)
+ROW_SCALES = np.array(
+    [
+        0.25,
+        0.5,
+        1.0,
+        2.0,
+        0.375,
+        0.75,
+        1.5,
+        3.0,
+        0.125,
+        0.625,
+        1.25,
+        2.5,
+        0.3125,
+        0.9375,
+        1.875,
+        3.75,
+    ],
+    dtype=np.float32,
+)
+
+
+def generate(output_dir: Path) -> None:
+    """Write inputs v1..v4.bin and references golden_v3/golden_v4.bin into output_dir."""
+    q_pattern = np.resize(Q_VALUES, COLS).astype(np.float32)
+    smooth_f16 = np.resize(SMOOTH_VALUES, COLS).astype(np.float16)
+    smooth_f32 = smooth_f16.astype(np.float32)
+    src = np.empty((ROWS, COLS), dtype=np.float16)
+    golden_scale = np.full(SCALE_SLOTS, SENTINEL_F32, dtype=np.float32)
+    golden_out = np.empty((ROWS, COLS), dtype=np.int8)
+    for row in range(ROWS):
+        src[row] = (q_pattern * ROW_SCALES[row]).astype(np.float16)
+        smoothed = src[row].astype(np.float32) * smooth_f32
+        row_scale = (np.abs(smoothed).max() / INT8_MAX).astype(np.float32)
+        tile, lane = divmod(row, 4)  # four rows per tile; tiles are 8 scale slots apart
+        golden_scale[tile * 8 + lane] = row_scale
+        rounded = np.round(smoothed / row_scale).astype(np.float32)
+        golden_out[row] = np.clip(rounded, -128, 127).astype(np.int8)
+    # v3/v4 start out filled with sentinels; the golden files hold the expected values.
+    output_dir.mkdir(parents=True, exist_ok=True)
+    outputs = (
+        ("v1.bin", src.view(np.uint16)),
+        ("v2.bin", smooth_f16.view(np.uint16)),
+        ("v3.bin", np.full(SCALE_SLOTS, SENTINEL_F32, dtype=np.float32)),
+        ("v4.bin", np.full((ROWS, COLS), SENTINEL_U8, dtype=np.uint8)),
+        ("golden_v3.bin", golden_scale),
+        ("golden_v4.bin", golden_out.view(np.uint8)),
+    )
+    for file_name, payload in outputs:
+        payload.tofile(output_dir / file_name)
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--output-dir", type=Path, default=Path("."))
+    generate(parser.parse_args().output_dir)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-8x64/kernel.pto b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-8x64/kernel.pto
new file mode 100644
index 0000000000..37d6ec43b2
--- /dev/null
+++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-8x64/kernel.pto
@@ -0,0 +1,156 @@
+// Copyright (c) 2026 Huawei Technologies Co., Ltd.
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+// CANN Open Software License Agreement Version 2.0 (the "License").
+// Please refer to the License for details. You may not use this file except in compliance with the License.
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+// See LICENSE in the root of the software repository for the full text of the License.
+// Per-row dynamic int8 quantization with a smooth factor: 8 rows x 64 f16 columns, four rows (256 lanes) per loop iteration.
+module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} {  // NOTE(review): #pto.kernel_kind, !pto.ptr and #pto.pipe show no <...> parameters in this text - confirm
+  func.func @vmi_dynamic_quant_pertoken_smooth_f16_8x64_kernel(%src_gm: !pto.ptr,
+      %smooth_gm: !pto.ptr,
+      %scale_gm: !pto.ptr,
+      %out_gm: !pto.ptr) attributes {pto.kernel} {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c2 = arith.constant 2 : index
+    %c4 = arith.constant 4 : index
+    %c8 = arith.constant 8 : index
+    %c64 = arith.constant 64 : index
+    %c128 = arith.constant 128 : index
+    %c256 = arith.constant 256 : index
+    %c512 = arith.constant 512 : index
+    %c0_i32 = arith.constant 0 : i32
+    %c256_i32 = arith.constant 256 : i32
+    %c-128_f32 = arith.constant -1.280000e+02 : f32
+    %c127_f32 = arith.constant 1.270000e+02 : f32
+    %c0_i64 = arith.constant 0 : i64
+    %c1_i64 = arith.constant 1 : i64
+    %c32_i64 = arith.constant 32 : i64
+    %c64_i64 = arith.constant 64 : i64
+    %c128_i64 = arith.constant 128 : i64
+    %c512_i64 = arith.constant 512 : i64
+    %c1024_i64 = arith.constant 1024 : i64
+    %c4096_i64 = arith.constant 4096 : i64
+    %c8192_i64 = arith.constant 8192 : i64
+    %c12288_i64 = arith.constant 12288 : i64
+    %c12416_i64 = arith.constant 12416 : i64
+    %c12544_i64 = arith.constant 12544 : i64
+    %c12672_i64 = arith.constant 12672 : i64
+    %c12800_i64 = arith.constant 12800 : i64
+    %c12928_i64 = arith.constant 12928 : i64
+    %c13056_i64 = arith.constant 13056 : i64
+    %c13184_i64 = arith.constant 13184 : i64
+    // Staging buffers are pointers made from fixed addresses: src 0, scale 4096, out 8192, smooth slots from 12288 (128 apart).
+    %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr
+    %ub_scale = pto.castptr %c4096_i64 : i64 -> !pto.ptr
+    %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr
+    %ub_smooth = pto.castptr %c12288_i64 : i64 -> !pto.ptr
+    %ub_smooth1 = pto.castptr %c12416_i64 : i64 -> !pto.ptr
+    %ub_smooth2 = pto.castptr %c12544_i64 : i64 -> !pto.ptr
+    %ub_smooth3 = pto.castptr %c12672_i64 : i64 -> !pto.ptr
+    // Inbound copies. %smooth_gm goes to %ub_smooth/1/2/3 (128 apart, length 128), i.e. back to back - presumably so the 256-lane load of %ub_smooth sees the 64-column row repeated (TODO confirm the length unit).
+    pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c1024_i64
+        nburst(%c1_i64, %c1024_i64, %c1024_i64)
+        : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
+    pto.mte_gm_ub %smooth_gm, %ub_smooth, %c0_i64, %c128_i64
+        nburst(%c1_i64, %c128_i64, %c128_i64)
+        : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
+    pto.mte_gm_ub %smooth_gm, %ub_smooth1, %c0_i64, %c128_i64
+        nburst(%c1_i64, %c128_i64, %c128_i64)
+        : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
+    pto.mte_gm_ub %smooth_gm, %ub_smooth2, %c0_i64, %c128_i64
+        nburst(%c1_i64, %c128_i64, %c128_i64)
+        : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
+    pto.mte_gm_ub %smooth_gm, %ub_smooth3, %c0_i64, %c128_i64
+        nburst(%c1_i64, %c128_i64, %c128_i64)
+        : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
+    pto.mte_gm_ub %scale_gm, %ub_scale, %c0_i64, %c64_i64
+        nburst(%c1_i64, %c64_i64, %c64_i64)
+        : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
+    pto.mte_gm_ub %out_gm, %ub_out, %c0_i64, %c512_i64
+        nburst(%c1_i64, %c512_i64, %c512_i64)
+        : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
+    // Presumably makes the vector pipe wait for the MTE2 copies above - confirm the flag semantics.
+    pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]
+    pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]
+    // 2 iterations; iteration %tile starts at row 4 * %tile and handles 256 consecutive elements (four 64-column rows).
+    pto.vecscope {
+      scf.for %tile = %c0 to %c2 step %c1 {
+        %row = arith.muli %tile, %c4 : index  // first row of this tile
+        %elem_offset = arith.muli %row, %c64 : index  // element offset of that row
+        %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred>
+        %x_f16 = pto.vmi.load %ub_src[%elem_offset]
+            : !pto.ptr -> !pto.vmi.vreg<256xf16>
+        %x = pto.vmi.extf %x_f16
+            : !pto.vmi.vreg<256xf16> -> !pto.vmi.vreg<256xf32>
+        %smooth_f16 = pto.vmi.load %ub_smooth[%c0]
+            : !pto.ptr -> !pto.vmi.vreg<256xf16>
+        %smooth = pto.vmi.extf %smooth_f16
+            : !pto.vmi.vreg<256xf16> -> !pto.vmi.vreg<256xf32>
+        %smoothed = pto.vmi.mulf %x, %smooth  // x * smooth
+            : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32>
+            -> !pto.vmi.vreg<256xf32>
+        %abs = pto.vmi.absf %smoothed
+            : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32>
+        %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 4}  // one max of |x * smooth| per group, 4 groups over 256 lanes
+            : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred>
+            -> !pto.vmi.vreg<4xf32>
+        %max_int8 = pto.vmi.broadcast %c127_f32
+            : f32 -> !pto.vmi.vreg<4xf32>
+        %scale = pto.vmi.divf %amax, %max_int8  // scale = amax / 127
+            : !pto.vmi.vreg<4xf32>, !pto.vmi.vreg<4xf32>
+            -> !pto.vmi.vreg<4xf32>
+        %scale_offset = arith.muli %tile, %c8 : index  // tiles are 8 scale slots apart
+        pto.vmi.group_store %scale, %ub_scale[%scale_offset], %c1 {num_groups = 4}
+            : !pto.vmi.vreg<4xf32>, !pto.ptr
+        %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 4}
+            : !pto.vmi.vreg<4xf32> -> !pto.vmi.vreg<256xf32>
+        %scaled = pto.vmi.divf %smoothed, %scale_vec  // (x * smooth) / scale
+            : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32>
+            -> !pto.vmi.vreg<256xf32>
+        %lo = pto.vmi.broadcast %c-128_f32
+            : f32 -> !pto.vmi.vreg<256xf32>
+        %hi = pto.vmi.broadcast %c127_f32
+            : f32 -> !pto.vmi.vreg<256xf32>
+        %clamped_lo = pto.vmi.maxf %scaled, %lo
+            : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32>
+            -> !pto.vmi.vreg<256xf32>
+        %clamped = pto.vmi.minf %clamped_lo, %hi  // now within [-128, 127]
+            : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32>
+            -> !pto.vmi.vreg<256xf32>
+        %i32 = pto.vmi.fptosi %clamped
+            : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xi32>
+        %zero_i32 = pto.vmi.broadcast %c0_i32
+            : i32 -> !pto.vmi.vreg<256xi32>
+        %byte_bias = pto.vmi.broadcast %c256_i32
+            : i32 -> !pto.vmi.vreg<256xi32>
+        %neg = pto.vmi.cmpi "slt", %i32, %zero_i32
+            : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32>
+            -> !pto.vmi.mask<256xpred>
+        %wrapped = pto.vmi.addi %i32, %byte_bias
+            : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32>
+            -> !pto.vmi.vreg<256xi32>
+        %byte_i32 = pto.vmi.select %neg, %wrapped, %i32  // negatives get +256 so every value lies in 0..255
+            : !pto.vmi.mask<256xpred>, !pto.vmi.vreg<256xi32>,
+            !pto.vmi.vreg<256xi32>
+            -> !pto.vmi.vreg<256xi32>
+        %u8 = pto.vmi.trunci %byte_i32
+            : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xui8>
+        pto.vmi.store %u8, %ub_out[%elem_offset]
+            : !pto.vmi.vreg<256xui8>, !pto.ptr
+      }
+    }
+    // Presumably makes the outbound MTE3 copies wait for the vector stores; the scale and output buffers are then copied back.
+    pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
+    pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
+    pto.mte_ub_gm %ub_scale, %scale_gm, %c64_i64
+        nburst(%c1_i64, %c64_i64, %c64_i64)
+        : !pto.ptr, !pto.ptr, i64, i64, i64, i64
+    pto.mte_ub_gm %ub_out, %out_gm, %c512_i64
+        nburst(%c1_i64, %c512_i64, %c512_i64)
+        : !pto.ptr, !pto.ptr, i64, i64, i64, i64
+    pto.barrier #pto.pipe
+    return
+  }
+}
diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-8x64/launch.cpp b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-8x64/launch.cpp
new file mode 100644
index 0000000000..851f3282a0
--- /dev/null
+++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-8x64/launch.cpp
@@ -0,0 +1,45 @@
+// Copyright (c) 2026 Huawei Technologies Co., Ltd.
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+// CANN Open Software License Agreement Version 2.0 (the "License").
+// Please refer to the License for details. You may not use this file except in compliance with the License.
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+// See LICENSE in the root of the software repository for the full text of the License.
+
+#ifndef __VEC_SCOPE__
+#define __VEC_SCOPE__
+#endif
+#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201)
+typedef struct { unsigned char v; } hifloat8_t;
+typedef struct { unsigned char v; } float8_e4m3_t;
+typedef struct { unsigned char v; } float8_e5m2_t;
+typedef struct { unsigned char v; } float8_e8m0_t;
+typedef struct { unsigned char v; } float4_e1m2x2_t;
+typedef struct { unsigned char v; } float4_e2m1x2_t;
+#endif  // one-byte stand-ins for the fp8/fp4 type names on __NPU_ARCH__ 2201 device builds
+#include  // NOTE(review): header name missing here (lost in extraction?) - restore it from the repository
+#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP)
+struct MrgSortExecutedNumList {
+  uint16_t mrgSortList0;
+  uint16_t mrgSortList1;
+  uint16_t mrgSortList2;
+  uint16_t mrgSortList3;
+};
+#endif  // host-only fallback, used when TMRGSORT_HPP has not been defined
+#ifndef __CPU_SIM
+#include "acl/acl.h"
+#endif
+// Kernel entry point; only declared in this file.
+extern "C" __global__ [aicore] void
+vmi_dynamic_quant_pertoken_smooth_f16_8x64_kernel(__gm__ half *src,
+                                                  __gm__ half *smooth,
+                                                  __gm__ float *scale,
+                                                  __gm__ uint8_t *out);
+// Host launcher: casts the host-typed pointers to __gm__ pointers and launches the kernel with <<<1, nullptr, stream>>>.
+void LaunchVmi_dynamic_quant_pertoken_smooth_f16_8x64_kernel(
+    uint16_t *src, uint16_t *smooth, float *scale, uint8_t *out,
+    void *stream) {
+  vmi_dynamic_quant_pertoken_smooth_f16_8x64_kernel<<<1, nullptr, stream>>>(
+      (__gm__ half *)src, (__gm__ half *)smooth, (__gm__ float *)scale,
+      (__gm__ uint8_t *)out);
+}
diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-8x64/main.cpp b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-8x64/main.cpp
new file mode 100644
index 0000000000..a65f375355
--- /dev/null
+++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-8x64/main.cpp
@@ -0,0 +1,105 @@
+// Copyright (c) 2026 Huawei Technologies Co., Ltd.
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+// CANN Open Software License Agreement Version 2.0 (the "License").
+// Please refer to the License for details. You may not use this file except in compliance with the License.
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+// See LICENSE in the root of the software repository for the full text of the License.
+
+#include "acl/acl.h"
+#include "test_common.h"
+#include <cstdint>
+#include <cstdio>
+#include <cstdlib>
+
+using namespace PtoTestCommon;
+// On an ACL failure: log the failing call, set rc = 1 and jump to the cleanup label of the enclosing function.
+#define ACL_CHECK(expr)                                                     \
+  do {                                                                      \
+    const aclError _ret = (expr);                                           \
+    if (_ret != ACL_SUCCESS) {                                              \
+      std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr,        \
+                   (int)_ret, __FILE__, __LINE__);                          \
+      rc = 1;                                                               \
+      goto cleanup;                                                         \
+    }                                                                       \
+  } while (0)
+// Host launcher for the kernel under test; only declared in this file.
+void LaunchVmi_dynamic_quant_pertoken_smooth_f16_8x64_kernel(
+    uint16_t *src, uint16_t *smooth, float *scale, uint8_t *out,
+    void *stream);
+// Copies v1..v4.bin to the device, runs the kernel once, writes scale (v3.bin) and out (v4.bin) back; returns 0, or 1 on an ACL failure.
+int main() {
+  constexpr size_t kRows = 8;
+  constexpr size_t kCols = 64;
+  constexpr size_t kElems = kRows * kCols;
+  constexpr size_t kSmoothElems = kCols;
+  constexpr size_t kScaleElems = 16;
+  size_t srcBytes = kElems * sizeof(uint16_t);
+  size_t smoothBytes = kSmoothElems * sizeof(uint16_t);
+  size_t scaleBytes = kScaleElems * sizeof(float);
+  size_t outBytes = kElems * sizeof(uint8_t);
+  uint16_t *srcHost = nullptr;
+  uint16_t *smoothHost = nullptr;
+  float *scaleHost = nullptr;
+  uint8_t *outHost = nullptr;
+  uint16_t *srcDevice = nullptr;
+  uint16_t *smoothDevice = nullptr;
+  float *scaleDevice = nullptr;
+  uint8_t *outDevice = nullptr;
+  int rc = 0;
+  bool aclInited = false;
+  bool deviceSet = false;
+  int deviceId = 0;
+  aclrtStream stream = nullptr;
+  // Runtime, device and stream setup; ACL_DEVICE_ID overrides the default device 0.
+  ACL_CHECK(aclInit(nullptr));
+  aclInited = true;
+  if (const char *envDevice = std::getenv("ACL_DEVICE_ID"))
+    deviceId = std::atoi(envDevice);
+  ACL_CHECK(aclrtSetDevice(deviceId));
+  deviceSet = true;
+  ACL_CHECK(aclrtCreateStream(&stream));
+  ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes));
+  ACL_CHECK(aclrtMallocHost((void **)(&smoothHost), smoothBytes));
+  ACL_CHECK(aclrtMallocHost((void **)(&scaleHost), scaleBytes));
+  ACL_CHECK(aclrtMallocHost((void **)(&outHost), outBytes));
+  ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST));
+  ACL_CHECK(aclrtMalloc((void **)&smoothDevice, smoothBytes, ACL_MEM_MALLOC_HUGE_FIRST));
+  ACL_CHECK(aclrtMalloc((void **)&scaleDevice, scaleBytes, ACL_MEM_MALLOC_HUGE_FIRST));
+  ACL_CHECK(aclrtMalloc((void **)&outDevice, outBytes, ACL_MEM_MALLOC_HUGE_FIRST));
+  // NOTE(review): the ReadFile results are not checked - confirm whether a failed read should abort the run.
+  ReadFile("./v1.bin", srcBytes, srcHost, srcBytes);
+  ReadFile("./v2.bin", smoothBytes, smoothHost, smoothBytes);
+  ReadFile("./v3.bin", scaleBytes, scaleHost, scaleBytes);
+  ReadFile("./v4.bin", outBytes, outHost, outBytes);
+  ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE));
+  ACL_CHECK(aclrtMemcpy(smoothDevice, smoothBytes, smoothHost, smoothBytes,
+                        ACL_MEMCPY_HOST_TO_DEVICE));
+  ACL_CHECK(aclrtMemcpy(scaleDevice, scaleBytes, scaleHost, scaleBytes, ACL_MEMCPY_HOST_TO_DEVICE));
+  ACL_CHECK(aclrtMemcpy(outDevice, outBytes, outHost, outBytes, ACL_MEMCPY_HOST_TO_DEVICE));
+  LaunchVmi_dynamic_quant_pertoken_smooth_f16_8x64_kernel(
+      srcDevice, smoothDevice, scaleDevice, outDevice, stream);
+  ACL_CHECK(aclrtSynchronizeStream(stream));
+  ACL_CHECK(aclrtMemcpy(scaleHost, scaleBytes, scaleDevice, scaleBytes, ACL_MEMCPY_DEVICE_TO_HOST));
+  ACL_CHECK(aclrtMemcpy(outHost, outBytes, outDevice, outBytes, ACL_MEMCPY_DEVICE_TO_HOST));
+  WriteFile("./v3.bin", scaleHost, scaleBytes);
+  WriteFile("./v4.bin", outHost, outBytes);
+  // Reached on success and from every ACL_CHECK failure: release only what was actually acquired.
+cleanup:
+  if (srcDevice) aclrtFree(srcDevice);
+  if (smoothDevice) aclrtFree(smoothDevice);
+  if (scaleDevice) aclrtFree(scaleDevice);
+  if (outDevice) aclrtFree(outDevice);
+  if (srcHost) aclrtFreeHost(srcHost);
+  if (smoothHost) aclrtFreeHost(smoothHost);
+  if (scaleHost) aclrtFreeHost(scaleHost);
+  if (outHost) aclrtFreeHost(outHost);
+  if (stream)
+    aclrtDestroyStream(stream);
+  if (deviceSet)
+    aclrtResetDevice(deviceId);
+  if (aclInited)
+    aclFinalize();
+  return rc;
+}
diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-8x64/ptoas.flags
b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-8x64/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-8x64/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/kernels/simdvf-per-block-cast-to-fp8/compare.py b/test/vpto/cases/vmi/kernels/simdvf-per-block-cast-to-fp8/compare.py new file mode 100644 index 0000000000..2d18033478 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/simdvf-per-block-cast-to-fp8/compare.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+
+import sys
+
+import numpy as np
+
+
+def main() -> None:
+    """Byte-compare v3.bin against golden_v3.bin; exit with status 2 on any difference."""
+    golden_out = np.fromfile("golden_v3.bin", dtype=np.uint8)
+    out = np.fromfile("v3.bin", dtype=np.uint8)
+    if golden_out.shape != out.shape:
+        # Element-wise != is invalid for mismatched sizes, so report the sizes instead.
+        print(f"[ERROR] fp8 compare failed size golden={golden_out.size} output={out.size}")
+        sys.exit(2)
+    diff = np.nonzero(golden_out != out)[0]
+    if diff.size:
+        idx = int(diff[0])
+        golden, output = int(golden_out[idx]), int(out[idx])
+        print(f"[ERROR] fp8 compare failed idx={idx} golden={golden} output={output}")
+        sys.exit(2)
+    print("[INFO] compare passed")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/test/vpto/cases/vmi/kernels/simdvf-per-block-cast-to-fp8/golden.py b/test/vpto/cases/vmi/kernels/simdvf-per-block-cast-to-fp8/golden.py
new file mode 100644
index 0000000000..22ff7c2109
--- /dev/null
+++ b/test/vpto/cases/vmi/kernels/simdvf-per-block-cast-to-fp8/golden.py
@@ -0,0 +1,75 @@
+#!/usr/bin/env python3
+# Copyright (c) 2026 Huawei Technologies Co., Ltd.
+# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+# CANN Open Software License Agreement Version 2.0 (the "License").
+# Please refer to the License for details. You may not use this file except in compliance with the License.
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+# See LICENSE in the root of the software repository for the full text of the License.
+
+import argparse
+from pathlib import Path
+
+import numpy as np
+
+ROWS = 16
+COLS = 256
+SCALE_ROWS = 4
+SCALE_COLS = 8
+TOKENS_PER_SCALE_ROW = 4
+CHANNELS_PER_SCALE = 32
+SENTINEL_U8 = np.uint8(0xA5)
+
+Q_VALUES = np.array(
+    [0.0, 1.0, -1.0, 0.5, 2.0, -2.0, 4.0, -4.0, 448.0], dtype=np.float32
+)
+F8E4M3FN_BYTES = np.array(
+    [0x00, 0x38, 0xB8, 0x30, 0x40, 0xC0, 0x48, 0xC8, 0x7E], dtype=np.uint8
+)
+
+
+def generate(output_dir: Path) -> None:
+    """Write v1.bin (float16 source), v2.bin (float32 scales), v3.bin (sentinel
+    output buffer) and golden_v3.bin (expected bytes) into output_dir."""
+    scale = np.array(
+        [
+            [0.25, 0.5, 1.0, 2.0, 0.25, 0.5, 1.0, 2.0],
+            [0.5, 1.0, 2.0, 4.0, 0.5, 1.0, 2.0, 4.0],
+            [1.0, 2.0, 4.0, 0.25, 1.0, 2.0, 4.0, 0.25],
+            [2.0, 4.0, 0.25, 0.5, 2.0, 4.0, 0.25, 0.5],
+        ],
+        dtype=np.float32,
+    )
+
+    # One block of CHANNELS_PER_SCALE entries, cycling through both tables.
+    reps = -(-CHANNELS_PER_SCALE // len(Q_VALUES))
+    q_block = np.tile(Q_VALUES, reps)[:CHANNELS_PER_SCALE].astype(np.float32)
+    f8_block = np.tile(F8E4M3FN_BYTES, reps)[:CHANNELS_PER_SCALE]
+
+    # Expand the scale table so every element is divided by its block's scale.
+    row_scale = np.repeat(scale, TOKENS_PER_SCALE_ROW, axis=0)
+    elem_scale = np.repeat(row_scale, CHANNELS_PER_SCALE, axis=1)
+    src = (np.tile(q_block, SCALE_COLS) / elem_scale).astype(np.float16)
+    golden_out = np.tile(f8_block, (ROWS, SCALE_COLS))
+    out = np.full((ROWS, COLS), SENTINEL_U8, dtype=np.uint8)
+
+    output_dir.mkdir(parents=True, exist_ok=True)
+    for name, data in (
+        ("v1.bin", src.view(np.uint16)),
+        ("v2.bin", scale.reshape(-1)),
+        ("v3.bin", out),
+        ("golden_v3.bin", golden_out),
+    ):
+        data.tofile(output_dir / name)
+
+
+def main() -> None:
+    """Parse --output-dir (default: current directory) and generate the files."""
+    cli = argparse.ArgumentParser()
+    cli.add_argument("--output-dir", type=Path, default=Path("."))
+    args = cli.parse_args()
+    generate(args.output_dir)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/test/vpto/cases/vmi/kernels/simdvf-per-block-cast-to-fp8/kernel.pto
b/test/vpto/cases/vmi/kernels/simdvf-per-block-cast-to-fp8/kernel.pto
new file mode 100644
index 0000000000..0e6c80648e
--- /dev/null
+++ b/test/vpto/cases/vmi/kernels/simdvf-per-block-cast-to-fp8/kernel.pto
@@ -0,0 +1,78 @@
+// Copyright (c) 2026 Huawei Technologies Co., Ltd.
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+// CANN Open Software License Agreement Version 2.0 (the "License").
+// Please refer to the License for details. You may not use this file except in compliance with the License.
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+// See LICENSE in the root of the software repository for the full text of the License.
+// VMI test kernel: out8 = truncf(extf(src) * group_broadcast(scale)), 256 lanes per row, 4 token rows per scale row.
+module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} {
+  func.func @vmi_simdvf_per_block_cast_to_fp8_kernel(%src_gm: !pto.ptr,
+                                                      %scale_gm: !pto.ptr,
+                                                      %out8_gm: !pto.ptr) attributes {pto.kernel} {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c4 = arith.constant 4 : index
+    %c8 = arith.constant 8 : index
+    %c32 = arith.constant 32 : index  // NOTE(review): %c32 is not referenced below
+    %num_per_tokens = arith.constant 4 : index
+    %num_sf_rows_per_block = arith.constant 4 : index
+    %num_per_channels = arith.constant 32 : index
+    %block_k = arith.muli %c8, %num_per_channels : index  // 8 * 32 = 256 elements per token row
+    %c0_i64 = arith.constant 0 : i64
+    %c1_i64 = arith.constant 1 : i64
+    %c128_i64 = arith.constant 128 : i64
+    %c4096_i64 = arith.constant 4096 : i64
+    %c8192_i64 = arith.constant 8192 : i64
+    %c12288_i64 = arith.constant 12288 : i64
+    // NOTE(review): castptr operands are presumably UB byte addresses; both out8 pointers are built from %c12288_i64.
+    %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr
+    %ub_scale = pto.castptr %c8192_i64 : i64 -> !pto.ptr
+    %ub_out8_u8 = pto.castptr %c12288_i64 : i64 -> !pto.ptr
+    %ub_out8_f8 = pto.castptr %c12288_i64 : i64 -> !pto.ptr
+    // Copy the source (8192) and the scale table (128) from GM into UB - sizes presumably in bytes.
+    pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c8192_i64
+      nburst(%c1_i64, %c8192_i64, %c8192_i64)
+      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
+    pto.mte_gm_ub %scale_gm, %ub_scale, %c0_i64, %c128_i64
+      nburst(%c1_i64, %c128_i64, %c128_i64)
+      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
+    // Presumably orders the MTE2 copies before the vector pipe (flag names PIPE_MTE2 -> PIPE_V) - confirm.
+    pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]
+    pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]
+    // Per scale row: load 8 scale slots, broadcast them over 256 lanes, then scale and convert 4 token rows.
+    pto.vecscope {
+      scf.for %sf_i = %c0 to %num_sf_rows_per_block step %c1 {
+        %sf_row_offset = arith.muli %sf_i, %c8 : index
+        %sf_slots = pto.vmi.group_slot_load %ub_scale[%sf_row_offset], %c1 {num_groups = 8}
+          : !pto.ptr -> !pto.vmi.vreg<8xf32>
+        %sf = pto.vmi.group_broadcast %sf_slots {num_groups = 8}
+          : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<256xf32>
+        scf.for %token_j = %c0 to %num_per_tokens step %c1 {
+          %sf_token_row_base = arith.muli %sf_i, %c4 : index
+          %token_row = arith.addi %sf_token_row_base, %token_j : index
+          %row_elem_offset = arith.muli %token_row, %block_k : index
+          %x16 = pto.vmi.load %ub_src[%row_elem_offset]
+            : !pto.ptr -> !pto.vmi.vreg<256xf16>
+          %x32 = pto.vmi.extf %x16
+            : !pto.vmi.vreg<256xf16> -> !pto.vmi.vreg<256xf32>
+          %scaled = pto.vmi.mulf %x32, %sf
+            : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32>
+            -> !pto.vmi.vreg<256xf32>
+          %out8 = pto.vmi.truncf %scaled
+            : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E4M3FN>
+          pto.vmi.store %out8, %ub_out8_f8[%row_elem_offset]
+            : !pto.vmi.vreg<256xf8E4M3FN>, !pto.ptr
+        }
+      }
+    }
+    // Presumably orders the vector stores before the MTE3 copy of the 4096-sized output back to GM - confirm.
+    pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
+    pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
+    pto.mte_ub_gm %ub_out8_u8, %out8_gm, %c4096_i64
+      nburst(%c1_i64, %c4096_i64, %c4096_i64)
+      : !pto.ptr, !pto.ptr, i64, i64, i64, i64
+    pto.barrier #pto.pipe
+    return
+  }
+}
diff --git a/test/vpto/cases/vmi/kernels/simdvf-per-block-cast-to-fp8/launch.cpp b/test/vpto/cases/vmi/kernels/simdvf-per-block-cast-to-fp8/launch.cpp
new file mode 100644
index 0000000000..f053235e1b
--- /dev/null
+++ b/test/vpto/cases/vmi/kernels/simdvf-per-block-cast-to-fp8/launch.cpp
@@ -0,0 +1,43 @@
+// Copyright (c) 
2026 Huawei Technologies Co., Ltd.
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+// CANN Open Software License Agreement Version 2.0 (the "License").
+// Please refer to the License for details. You may not use this file except in compliance with the License.
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+// See LICENSE in the root of the software repository for the full text of the License.
+
+#ifndef __VEC_SCOPE__
+#define __VEC_SCOPE__
+#endif
+#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201)
+typedef struct { unsigned char v; } hifloat8_t;
+typedef struct { unsigned char v; } float8_e4m3_t;
+typedef struct { unsigned char v; } float8_e5m2_t;
+typedef struct { unsigned char v; } float8_e8m0_t;
+typedef struct { unsigned char v; } float4_e1m2x2_t;
+typedef struct { unsigned char v; } float4_e2m1x2_t;
+#endif  // one-byte stand-ins for the fp8/fp4 type names on __NPU_ARCH__ 2201 device builds
+#include  // NOTE(review): header name missing here (lost in extraction?) - restore it from the repository
+#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP)
+struct MrgSortExecutedNumList {
+  uint16_t mrgSortList0;
+  uint16_t mrgSortList1;
+  uint16_t mrgSortList2;
+  uint16_t mrgSortList3;
+};
+#endif  // host-only fallback, used when TMRGSORT_HPP has not been defined
+#ifndef __CPU_SIM
+#include "acl/acl.h"
+#endif
+// Kernel entry point; only declared in this file.
+extern "C" __global__ [aicore] void
+vmi_simdvf_per_block_cast_to_fp8_kernel(__gm__ half *src,
+                                        __gm__ float *scale,
+                                        __gm__ uint8_t *out8);
+// Host launcher: casts the host-typed pointers to __gm__ pointers and launches the kernel with <<<1, nullptr, stream>>>.
+void LaunchVmi_simdvf_per_block_cast_to_fp8_kernel(uint16_t *src, float *scale,
+                                                   uint8_t *out8,
+                                                   void *stream) {
+  vmi_simdvf_per_block_cast_to_fp8_kernel<<<1, nullptr, stream>>>(
+      (__gm__ half *)src, (__gm__ float *)scale, (__gm__ uint8_t *)out8);
+}
diff --git a/test/vpto/cases/vmi/kernels/simdvf-per-block-cast-to-fp8/main.cpp b/test/vpto/cases/vmi/kernels/simdvf-per-block-cast-to-fp8/main.cpp
new file mode 100644
index 0000000000..2fa1b028d4
--- /dev/null
+++ 
b/test/vpto/cases/vmi/kernels/simdvf-per-block-cast-to-fp8/main.cpp
@@ -0,0 +1,94 @@
+// Copyright (c) 2026 Huawei Technologies Co., Ltd.
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+// CANN Open Software License Agreement Version 2.0 (the "License").
+// Please refer to the License for details. You may not use this file except in compliance with the License.
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+// See LICENSE in the root of the software repository for the full text of the License.
+
+#include "acl/acl.h"
+#include "test_common.h"
+#include <cstdint>
+#include <cstdio>
+#include <cstdlib>
+
+using namespace PtoTestCommon;
+// On an ACL failure: log the failing call, set rc = 1 and jump to the cleanup label of the enclosing function.
+#define ACL_CHECK(expr)                                                     \
+  do {                                                                      \
+    const aclError _ret = (expr);                                           \
+    if (_ret != ACL_SUCCESS) {                                              \
+      std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr,        \
+                   (int)_ret, __FILE__, __LINE__);                          \
+      rc = 1;                                                               \
+      goto cleanup;                                                         \
+    }                                                                       \
+  } while (0)
+// Host launcher for the kernel under test; only declared in this file.
+void LaunchVmi_simdvf_per_block_cast_to_fp8_kernel(uint16_t *src, float *scale,
+                                                   uint8_t *out8,
+                                                   void *stream);
+// Copies v1..v3.bin to the device, runs the kernel once and writes out8 back to v3.bin; returns 0, or 1 on an ACL failure.
+int main() {
+  constexpr size_t kRows = 16;
+  constexpr size_t kCols = 256;
+  constexpr size_t kElems = kRows * kCols;
+  constexpr size_t kScaleRows = 4;
+  constexpr size_t kScaleCols = 8;
+  constexpr size_t kScaleElems = kScaleRows * kScaleCols;
+  size_t srcBytes = kElems * sizeof(uint16_t);
+  size_t scaleBytes = kScaleElems * sizeof(float);
+  size_t outBytes = kElems * sizeof(uint8_t);
+  uint16_t *srcHost = nullptr;
+  float *scaleHost = nullptr;
+  uint8_t *outHost = nullptr;
+  uint16_t *srcDevice = nullptr;
+  float *scaleDevice = nullptr;
+  uint8_t *outDevice = nullptr;
+  int rc = 0;
+  bool aclInited = false;
+  bool deviceSet = false;
+  int deviceId = 0;
+  aclrtStream stream = nullptr;
+  // Runtime, device and stream setup; ACL_DEVICE_ID overrides the default device 0.
+  ACL_CHECK(aclInit(nullptr));
+  aclInited = true;
+  if (const char *envDevice = std::getenv("ACL_DEVICE_ID"))
+    deviceId = std::atoi(envDevice);
+  ACL_CHECK(aclrtSetDevice(deviceId));
+  deviceSet = true;
+  ACL_CHECK(aclrtCreateStream(&stream));
+  ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes));
+  ACL_CHECK(aclrtMallocHost((void **)(&scaleHost), scaleBytes));
+  ACL_CHECK(aclrtMallocHost((void **)(&outHost), outBytes));
+  ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST));
+  ACL_CHECK(aclrtMalloc((void **)&scaleDevice, scaleBytes, ACL_MEM_MALLOC_HUGE_FIRST));
+  ACL_CHECK(aclrtMalloc((void **)&outDevice, outBytes, ACL_MEM_MALLOC_HUGE_FIRST));
+  // NOTE(review): the ReadFile results are not checked - confirm whether a failed read should abort the run.
+  ReadFile("./v1.bin", srcBytes, srcHost, srcBytes);
+  ReadFile("./v2.bin", scaleBytes, scaleHost, scaleBytes);
+  ReadFile("./v3.bin", outBytes, outHost, outBytes);
+  ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE));
+  ACL_CHECK(aclrtMemcpy(scaleDevice, scaleBytes, scaleHost, scaleBytes, ACL_MEMCPY_HOST_TO_DEVICE));
+  ACL_CHECK(aclrtMemcpy(outDevice, outBytes, outHost, outBytes, ACL_MEMCPY_HOST_TO_DEVICE));
+  LaunchVmi_simdvf_per_block_cast_to_fp8_kernel(srcDevice, scaleDevice,
+                                                outDevice, stream);
+  ACL_CHECK(aclrtSynchronizeStream(stream));
+  ACL_CHECK(aclrtMemcpy(outHost, outBytes, outDevice, outBytes, ACL_MEMCPY_DEVICE_TO_HOST));
+  WriteFile("./v3.bin", outHost, outBytes);
+  // Reached on success and from every ACL_CHECK failure: release only what was actually acquired.
+cleanup:
+  if (srcDevice) aclrtFree(srcDevice);
+  if (scaleDevice) aclrtFree(scaleDevice);
+  if (outDevice) aclrtFree(outDevice);
+  if (srcHost) aclrtFreeHost(srcHost);
+  if (scaleHost) aclrtFreeHost(scaleHost);
+  if (outHost) aclrtFreeHost(outHost);
+  if (stream)
+    aclrtDestroyStream(stream);
+  if (deviceSet)
+    aclrtResetDevice(deviceId);
+  if (aclInited)
+    aclFinalize();
+  return rc;
+}
diff --git a/test/vpto/cases/vmi/kernels/simdvf-per-block-cast-to-fp8/ptoas.flags b/test/vpto/cases/vmi/kernels/simdvf-per-block-cast-to-fp8/ptoas.flags
new file mode 100644
index 0000000000..a79aede1ca
--- /dev/null
+++ b/test/vpto/cases/vmi/kernels/simdvf-per-block-cast-to-fp8/ptoas.flags
@@ -0,0 +1 @@
+--pto-arch a5 --pto-backend=vpto --enable-vmi
diff --git a/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e4m3-4x8/compare.py b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e4m3-4x8/compare.py
new file mode 100644
index 0000000000..bc45e33883
--- /dev/null
+++ b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e4m3-4x8/compare.py
@@ -0,0 +1,48 @@
+#!/usr/bin/env python3
+# Copyright (c) 2026 Huawei Technologies Co., Ltd.
+# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+# CANN Open Software License Agreement Version 2.0 (the "License").
+# Please refer to the License for details. You may not use this file except in compliance with the License.
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+# See LICENSE in the root of the software repository for the full text of the License.
+
+import sys
+
+import numpy as np
+
+
+def check_u8(name: str) -> bool:
+    """Return True when {name}.bin matches golden_{name}.bin byte for byte.
+
+    A size or content mismatch is reported on stdout and yields False.
+    """
+    golden = np.fromfile(f"golden_{name}.bin", dtype=np.uint8)
+    output = np.fromfile(f"{name}.bin", dtype=np.uint8)
+    if golden.shape != output.shape:
+        # Element-wise != is invalid for mismatched sizes, so report the sizes instead.
+        print(
+            f"[ERROR] compare failed {name} "
+            f"size golden={golden.size} output={output.size}"
+        )
+        return False
+    diff = np.nonzero(golden != output)[0]
+    if not diff.size:
+        return True
+    idx = int(diff[0])
+    print(
+        f"[ERROR] compare failed {name} idx={idx} "
+        f"golden=0x{int(golden[idx]):02x} output=0x{int(output[idx]):02x}"
+    )
+    return False
+
+
+def main() -> None:
+    """Check v2 then v3; exit with status 2 at the first output that differs."""
+    if not check_u8("v2") or not check_u8("v3"):
+        sys.exit(2)
+    print("[INFO] compare passed")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e4m3-4x8/golden.py b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e4m3-4x8/golden.py
new file mode 100644
index 0000000000..9c6b791567
--- /dev/null
+++ b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e4m3-4x8/golden.py
@@ -0,0 +1,59 @@
+#!/usr/bin/env python3
+# Copyright (c) 2026 Huawei 
Technologies Co., Ltd.
+# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+# CANN Open Software License Agreement Version 2.0 (the "License").
+# Please refer to the License for details. You may not use this file except in compliance with the License.
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+# See LICENSE in the root of the software repository for the full text of the License.
+
+import argparse
+from pathlib import Path
+
+import numpy as np
+
+ROWS = 4
+INPUT_COLS = 8
+OUT_COLS = 4
+SCALE_BYTES = 4
+SENTINEL_U8 = np.uint8(0xA5)
+
+Q_VALUES = np.array([1.0, -1.0, 0.5, 448.0], dtype=np.float32)
+F8E4M3FN_BYTES = np.array([0x38, 0xB8, 0x30, 0x7E], dtype=np.uint8)
+
+
+def f32_to_bf16_bits(values: np.ndarray) -> np.ndarray:
+    """Return the bfloat16 bit patterns (uint16) of values, rounding ties to even."""
+    bits = values.astype(np.float32).view(np.uint32)
+    # Adding 0x7FFF plus the lowest kept bit rounds ties toward an even result.
+    bias = np.uint32(0x7FFF) + ((bits >> 16) & 1)
+    return ((bits + bias) >> 16).astype(np.uint16)
+
+
+def generate(output_dir: Path) -> None:
+    """Write the source, the sentinel-filled outputs and the golden .bin files into output_dir."""
+    x2 = f32_to_bf16_bits(Q_VALUES / np.float32(4096.0))
+    x1 = f32_to_bf16_bits(np.full(OUT_COLS, np.float32(16.0), dtype=np.float32))
+    files = {
+        "v1.bin": np.tile(np.concatenate([x2, x1]), (ROWS, 1)),
+        "v2.bin": np.full((ROWS, OUT_COLS), SENTINEL_U8, dtype=np.uint8),
+        "v3.bin": np.full(SCALE_BYTES, SENTINEL_U8, dtype=np.uint8),
+        "golden_v2.bin": np.tile(F8E4M3FN_BYTES, (ROWS, 1)).astype(np.uint8),
+        "golden_v3.bin": np.full(SCALE_BYTES, np.uint8(0x77), dtype=np.uint8),
+    }
+
+    output_dir.mkdir(parents=True, exist_ok=True)
+    for name, data in files.items():
+        data.tofile(output_dir / name)
+
+
+def main() -> None:
+    """Parse --output-dir (default: current directory) and generate the files."""
+    cli = argparse.ArgumentParser()
+    cli.add_argument("--output-dir", type=Path, default=Path("."))
+    args = cli.parse_args()
+    generate(args.output_dir)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e4m3-4x8/kernel.pto b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e4m3-4x8/kernel.pto
new file mode 100644
index 0000000000..0500e465c8
--- /dev/null
+++ b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e4m3-4x8/kernel.pto
@@ -0,0 +1,155 @@
+// Copyright (c) 2026 Huawei Technologies Co., Ltd.
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+// CANN Open Software License Agreement Version 2.0 (the "License").
+// Please refer to the License for details. You may not use this file except in compliance with the License.
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+// See LICENSE in the root of the software repository for the full text of the License.
+// VMI test kernel: swiglu = x1 / (1 + exp(-x1)) * x2, scaled per group by a power of two and stored as f8E4M3FN plus one scale byte per group.
+module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} {
+  func.func @vmi_swiglu_mx_quant_bf16_e4m3_4x8_kernel(%src_gm: !pto.ptr,
+                                                       %out_gm: !pto.ptr,
+                                                       %scale_gm: !pto.ptr) attributes {pto.kernel} {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c4 = arith.constant 4 : index
+    %c128 = arith.constant 128 : index  // NOTE(review): %c128 is not referenced below
+    %c256 = arith.constant 256 : index  // NOTE(review): %c256 is not referenced below
+    %c0_i64 = arith.constant 0 : i64
+    %c1_i64 = arith.constant 1 : i64
+    %c4_i64 = arith.constant 4 : i64
+    %c8_i64 = arith.constant 8 : i64
+    %c16_i64 = arith.constant 16 : i64
+    %c32_i64 = arith.constant 32 : i64
+    %c64_i64 = arith.constant 64 : i64
+    %c128_i64 = arith.constant 128 : i64
+    %c4096_i64 = arith.constant 4096 : i64
+    %c8192_i64 = arith.constant 8192 : i64
+    %c2139095040_i32 = arith.constant 2139095040 : i32  // 0x7F800000, the f32 exponent field
+    %c23_i32 = arith.constant 23 : i32
+    %c8_i32 = arith.constant 8 : i32
+    %c254_i32 = arith.constant 254 : i32
+    %c0_bf16 = arith.constant 0.000000e+00 : bf16
+    %c0_f32 = arith.constant 0.000000e+00 : f32
+    %c1_f32 = arith.constant 1.000000e+00 : f32
+    // NOTE(review): if castptr operands are UB byte addresses, %ub_scale (128) lies inside the region zero-filled through %ub_x2 - confirm the overlap is intended.
+    %ub_x2 = pto.castptr %c0_i64 : i64 -> !pto.ptr
+    %ub_x1 = pto.castptr %c4096_i64 : i64 -> !pto.ptr
+    %ub_out_u8 = pto.castptr %c8192_i64 : i64 -> !pto.ptr
+    %ub_out_f8 = pto.castptr %c8192_i64 : i64 -> !pto.ptr
+    %ub_scale = pto.castptr %c128_i64 : i64 -> !pto.ptr
+    %src_x1_gm = pto.addptr %src_gm, %c4 : !pto.ptr -> !pto.ptr
+    // Zero-fill both 256-lane bf16 input buffers before the GM copies below.
+    pto.vecscope {
+      %zero_pad = pto.vmi.broadcast %c0_bf16
+        : bf16 -> !pto.vmi.vreg<256xbf16>
+      pto.vmi.store %zero_pad, %ub_x2[%c0]
+        : !pto.vmi.vreg<256xbf16>, !pto.ptr
+      pto.vmi.store %zero_pad, %ub_x1[%c0]
+        : !pto.vmi.vreg<256xbf16>, !pto.ptr
+    }
+    pto.set_flag["PIPE_V", "PIPE_MTE2", "EVENT_ID0"]
+    pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVENT_ID0"]
+    // Copy x2 from %src_gm and x1 from %src_x1_gm (= %src_gm advanced by %c4) into UB with nburst(4, 16, 64).
+    pto.mte_gm_ub %src_gm, %ub_x2, %c0_i64, %c8_i64
+      nburst(%c4_i64, %c16_i64, %c64_i64)
+      pad(%c0_bf16, %c0_i64, %c0_i64)
+      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64,
+        pad bf16, i64, i64
+    pto.mte_gm_ub %src_x1_gm, %ub_x1, %c0_i64, %c8_i64
+      nburst(%c4_i64, %c16_i64, %c64_i64)
+      pad(%c0_bf16, %c0_i64, %c0_i64)
+      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64,
+        pad bf16, i64, i64
+    // Presumably orders the MTE2 copies before the vector pipe (flag names PIPE_MTE2 -> PIPE_V) - confirm.
+    pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]
+    pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]
+    // Mask built from %c4 with 8 groups of 32 lanes - presumably enables the first 4 lanes of each group; confirm.
+    pto.vecscope {
+      %mask = pto.vmi.create_group_mask %c4 {num_groups = 8, group_size = 32}
+        : index -> !pto.vmi.mask<256xpred>
+      %x2_bf16 = pto.vmi.load %ub_x2[%c0]
+        : !pto.ptr -> !pto.vmi.vreg<256xbf16>
+      %x1_bf16 = pto.vmi.load %ub_x1[%c0]
+        : !pto.ptr -> !pto.vmi.vreg<256xbf16>
+      %x2 = pto.vmi.extf %x2_bf16
+        : !pto.vmi.vreg<256xbf16> -> !pto.vmi.vreg<256xf32>
+      %x1 = pto.vmi.extf %x1_bf16
+        : !pto.vmi.vreg<256xbf16> -> !pto.vmi.vreg<256xf32>
+      %zero = pto.vmi.broadcast %c0_f32
+        : f32 -> !pto.vmi.vreg<256xf32>
+      %neg_x1 = pto.vmi.subf %zero, %x1
+        : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32>
+        -> !pto.vmi.vreg<256xf32>
+      %exp_neg = pto.vmi.exp %neg_x1
+        : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32>
+      %one = pto.vmi.broadcast %c1_f32
+        : f32 -> !pto.vmi.vreg<256xf32>
+      %den = pto.vmi.addf %one, %exp_neg
+        : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32>
+        -> !pto.vmi.vreg<256xf32>
+      %silu_x1 = pto.vmi.divf %x1, %den  // x1 / (1 + exp(-x1))
+        : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32>
+        -> !pto.vmi.vreg<256xf32>
+      %swiglu = pto.vmi.mulf %silu_x1, %x2
+        : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32>
+        -> !pto.vmi.vreg<256xf32>
+      %abs = pto.vmi.absf %swiglu
+        : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32>
+      %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 8}
+        : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred>
+        -> !pto.vmi.vreg<8xf32>
+      %amax_bits = pto.vmi.bitcast %amax
+        : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<8xi32>
+      %exp_mask = pto.vmi.broadcast %c2139095040_i32
+        : i32 -> !pto.vmi.vreg<8xi32>
+      %shift = pto.vmi.broadcast %c23_i32
+        : i32 -> !pto.vmi.vreg<8xi32>
+      %emax = pto.vmi.broadcast %c8_i32
+        : i32 -> !pto.vmi.vreg<8xi32>
+      %scale_exp_bias = pto.vmi.broadcast %c254_i32
+        : i32 -> !pto.vmi.vreg<8xi32>
+      %exp_bits = pto.vmi.andi %amax_bits, %exp_mask
+        : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32>
+        -> !pto.vmi.vreg<8xi32>
+      %exp = pto.vmi.shrui %exp_bits, %shift  // biased f32 exponent of the group maximum
+        : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32>
+        -> !pto.vmi.vreg<8xi32>
+      %e8m0_i32 = pto.vmi.subi %exp, %emax  // stored scale byte = exponent - 8
+        : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32>
+        -> !pto.vmi.vreg<8xi32>
+      %e8m0_u8 = pto.vmi.trunci %e8m0_i32
+        : !pto.vmi.vreg<8xi32> -> !pto.vmi.vreg<8xui8>
+      pto.vmi.group_store %e8m0_u8, %ub_scale[%c0], %c1 {num_groups = 8}
+        : !pto.vmi.vreg<8xui8>, !pto.ptr
+      %scale_exp = pto.vmi.subi %scale_exp_bias, %e8m0_i32  // 254 - e8m0
+        : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32>
+        -> !pto.vmi.vreg<8xi32>
+      %scale_bits = pto.vmi.shli %scale_exp, %shift  // move it into the f32 exponent field
+        : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32>
+        -> !pto.vmi.vreg<8xi32>
+      %scale = pto.vmi.bitcast %scale_bits
+        : !pto.vmi.vreg<8xi32> -> !pto.vmi.vreg<8xf32>
+      %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 8}
+        : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<256xf32>
+      %scaled = pto.vmi.mulf %swiglu, %scale_vec
+        : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32>
+        -> !pto.vmi.vreg<256xf32>
+      %q8 = pto.vmi.truncf %scaled
+        : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E4M3FN>
+      pto.vmi.store %q8, %ub_out_f8[%c0]
+        : !pto.vmi.vreg<256xf8E4M3FN>, !pto.ptr
+    }
+    // Presumably orders the vector stores before the MTE3 copies of the output and scale bytes to GM - confirm.
+    pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
+    pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
+    pto.mte_ub_gm %ub_out_u8, %out_gm, %c4_i64
+      nburst(%c4_i64, %c32_i64, %c4_i64)
+      : !pto.ptr, !pto.ptr, i64, i64, i64, i64
+    pto.mte_ub_gm %ub_scale, %scale_gm, %c4_i64
+      nburst(%c1_i64, %c0_i64, %c0_i64)
+      : !pto.ptr, !pto.ptr, i64, i64, i64, i64
+    pto.barrier #pto.pipe
+    return
+  }
+}
diff --git a/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e4m3-4x8/launch.cpp b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e4m3-4x8/launch.cpp
new file mode 100644
index 0000000000..09fbbaa897
--- /dev/null
+++ b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e4m3-4x8/launch.cpp
@@ 
-0,0 +1,43 @@
+// Copyright (c) 2026 Huawei Technologies Co., Ltd.
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+// CANN Open Software License Agreement Version 2.0 (the "License").
+// Please refer to the License for details. You may not use this file except in compliance with the License.
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+// See LICENSE in the root of the software repository for the full text of the License.
+
+#ifndef __VEC_SCOPE__
+#define __VEC_SCOPE__
+#endif
+#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201)
+typedef struct { unsigned char v; } hifloat8_t;
+typedef struct { unsigned char v; } float8_e4m3_t;
+typedef struct { unsigned char v; } float8_e5m2_t;
+typedef struct { unsigned char v; } float8_e8m0_t;
+typedef struct { unsigned char v; } float4_e1m2x2_t;
+typedef struct { unsigned char v; } float4_e2m1x2_t;
+#endif  // one-byte stand-ins for the fp8/fp4 type names on __NPU_ARCH__ 2201 device builds
+#include  // NOTE(review): header name missing here (lost in extraction?) - restore it from the repository
+#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP)
+struct MrgSortExecutedNumList {
+  uint16_t mrgSortList0;
+  uint16_t mrgSortList1;
+  uint16_t mrgSortList2;
+  uint16_t mrgSortList3;
+};
+#endif  // host-only fallback, used when TMRGSORT_HPP has not been defined
+#ifndef __CPU_SIM
+#include "acl/acl.h"
+#endif
+// Kernel entry point; only declared in this file.
+extern "C" __global__ [aicore] void
+vmi_swiglu_mx_quant_bf16_e4m3_4x8_kernel(__gm__ bfloat16_t *src,
+                                         __gm__ uint8_t *out,
+                                         __gm__ uint8_t *scale);
+// Host launcher: casts the host-typed pointers to __gm__ pointers and launches the kernel with <<<1, nullptr, stream>>>.
+void LaunchVmi_swiglu_mx_quant_bf16_e4m3_4x8_kernel(
+    uint16_t *src, uint8_t *out, uint8_t *scale, void *stream) {
+  vmi_swiglu_mx_quant_bf16_e4m3_4x8_kernel<<<1, nullptr, stream>>>(
+      (__gm__ bfloat16_t *)src, (__gm__ uint8_t *)out,
+      (__gm__ uint8_t *)scale);
+}
diff --git a/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e4m3-4x8/main.cpp b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e4m3-4x8/main.cpp
new file mode 100644
index 0000000000..ed9deaa718
--- /dev/null +++ b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e4m3-4x8/main.cpp @@ -0,0 +1,95 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_swiglu_mx_quant_bf16_e4m3_4x8_kernel( + uint16_t *src, uint8_t *out, uint8_t *scale, void *stream); + +int main() { + constexpr size_t kRows = 4; + constexpr size_t kInputCols = 8; + constexpr size_t kOutCols = 4; + constexpr size_t kInputElems = kRows * kInputCols; + constexpr size_t kOutElems = kRows * kOutCols; + constexpr size_t kScaleBytes = kRows; + size_t srcBytes = kInputElems * sizeof(uint16_t); + size_t outBytes = kOutElems * sizeof(uint8_t); + size_t scaleBytes = kScaleBytes; + uint16_t *srcHost = nullptr; + uint8_t *outHost = nullptr; + uint8_t *scaleHost = nullptr; + uint16_t *srcDevice = nullptr; + uint8_t *outDevice = nullptr; + uint8_t *scaleDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if 
(const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&outHost), outBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&scaleHost), scaleBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outDevice, outBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&scaleDevice, scaleBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", outBytes, outHost, outBytes); + ReadFile("./v3.bin", scaleBytes, scaleHost, scaleBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outDevice, outBytes, outHost, outBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(scaleDevice, scaleBytes, scaleHost, scaleBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_swiglu_mx_quant_bf16_e4m3_4x8_kernel( + srcDevice, outDevice, scaleDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(outHost, outBytes, outDevice, outBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(scaleHost, scaleBytes, scaleDevice, scaleBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", outHost, outBytes); + WriteFile("./v3.bin", scaleHost, scaleBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(outDevice); + aclrtFree(scaleDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(outHost); + aclrtFreeHost(scaleHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e4m3-4x8/ptoas.flags b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e4m3-4x8/ptoas.flags new file mode 100644 
index 0000000000..a79aede1ca
--- /dev/null
+++ b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e4m3-4x8/ptoas.flags
@@ -0,0 +1 @@
+--pto-arch a5 --pto-backend=vpto --enable-vmi
diff --git a/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e5m2-4x8/compare.py b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e5m2-4x8/compare.py
new file mode 100644
index 0000000000..bc45e33883
--- /dev/null
+++ b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e5m2-4x8/compare.py
@@ -0,0 +1,52 @@
+#!/usr/bin/env python3
+# Copyright (c) 2026 Huawei Technologies Co., Ltd.
+# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+# CANN Open Software License Agreement Version 2.0 (the "License").
+# Please refer to the License for details. You may not use this file except in compliance with the License.
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+# See LICENSE in the root of the software repository for the full text of the License.
+
+import sys
+
+import numpy as np
+
+
+def check_u8(name: str) -> bool:
+    """Compare ``<name>.bin`` with ``golden_<name>.bin`` byte for byte.
+
+    Both files are read from the current working directory as flat uint8
+    arrays. Returns True when they are identical; otherwise prints a single
+    ``[ERROR]`` line describing the first discrepancy and returns False.
+    """
+    golden = np.fromfile(f"golden_{name}.bin", dtype=np.uint8)
+    output = np.fromfile(f"{name}.bin", dtype=np.uint8)
+    if golden.shape != output.shape:
+        # An element-wise comparison is not meaningful for different lengths
+        # (newer NumPy raises, older releases return a scalar), so report the
+        # size mismatch itself instead of indexing with a bogus position.
+        print(
+            f"[ERROR] compare failed {name} size mismatch "
+            f"golden={golden.size} output={output.size}"
+        )
+        return False
+    diff = np.nonzero(golden != output)[0]
+    if diff.size == 0:
+        return True
+    idx = int(diff[0])
+    print(
+        f"[ERROR] compare failed {name} idx={idx} "
+        f"golden=0x{int(golden[idx]):02x} output=0x{int(output[idx]):02x}"
+    )
+    return False
+
+
+def main() -> None:
+    """Exit with status 2 if either output buffer differs from its golden."""
+    if not check_u8("v2") or not check_u8("v3"):
+        sys.exit(2)
+    print("[INFO] compare passed")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e5m2-4x8/golden.py b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e5m2-4x8/golden.py
new file mode 100644
index 0000000000..9634af6120
--- /dev/null
+++ b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e5m2-4x8/golden.py
@@ -0,0 +1,59 @@
+#!/usr/bin/env python3
+# Copyright (c) 2026 Huawei Technologies Co., Ltd.
+# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+# CANN Open Software License Agreement Version 2.0 (the "License").
+# Please refer to the License for details. You may not use this file except in compliance with the License.
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+# See LICENSE in the root of the software repository for the full text of the License.
+ +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 4 +INPUT_COLS = 8 +OUT_COLS = 4 +SCALE_BYTES = 4 +SENTINEL_U8 = np.uint8(0xA5) + +Q_VALUES = np.array([1.0, -1.0, 0.5, 57344.0], dtype=np.float32) +F8E5M2_BYTES = np.array([0x3C, 0xBC, 0x38, 0x7B], dtype=np.uint8) + + +def f32_to_bf16_bits(values: np.ndarray) -> np.ndarray: + bits = values.astype(np.float32).view(np.uint32) + lsb = (bits >> 16) & 1 + rounded = bits + np.uint32(0x7FFF) + lsb + return (rounded >> 16).astype(np.uint16) + + +def generate(output_dir: Path) -> None: + x2 = f32_to_bf16_bits(Q_VALUES / np.float32(16.0)) + x1 = f32_to_bf16_bits(np.full(OUT_COLS, np.float32(16.0), dtype=np.float32)) + src_row = np.concatenate([x2, x1]) + src = np.tile(src_row, (ROWS, 1)) + golden_out = np.tile(F8E5M2_BYTES, (ROWS, 1)).astype(np.uint8) + golden_scale = np.full(SCALE_BYTES, np.uint8(0x7F), dtype=np.uint8) + + out = np.full((ROWS, OUT_COLS), SENTINEL_U8, dtype=np.uint8) + scale = np.full(SCALE_BYTES, SENTINEL_U8, dtype=np.uint8) + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + out.tofile(output_dir / "v2.bin") + scale.tofile(output_dir / "v3.bin") + golden_out.tofile(output_dir / "golden_v2.bin") + golden_scale.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e5m2-4x8/kernel.pto b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e5m2-4x8/kernel.pto new file mode 100644 index 0000000000..8375e83240 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e5m2-4x8/kernel.pto @@ -0,0 +1,155 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_swiglu_mx_quant_bf16_e5m2_4x8_kernel(%src_gm: !pto.ptr, + %out_gm: !pto.ptr, + %scale_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c4_i64 = arith.constant 4 : i64 + %c8_i64 = arith.constant 8 : i64 + %c16_i64 = arith.constant 16 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c2139095040_i32 = arith.constant 2139095040 : i32 + %c23_i32 = arith.constant 23 : i32 + %c15_i32 = arith.constant 15 : i32 + %c254_i32 = arith.constant 254 : i32 + %c0_bf16 = arith.constant 0.000000e+00 : bf16 + %c0_f32 = arith.constant 0.000000e+00 : f32 + %c1_f32 = arith.constant 1.000000e+00 : f32 + + %ub_x2 = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_x1 = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out_u8 = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_out_f8 = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_scale = pto.castptr %c128_i64 : i64 -> !pto.ptr + %src_x1_gm = pto.addptr %src_gm, %c4 : !pto.ptr -> 
!pto.ptr + + pto.vecscope { + %zero_pad = pto.vmi.broadcast %c0_bf16 + : bf16 -> !pto.vmi.vreg<256xbf16> + pto.vmi.store %zero_pad, %ub_x2[%c0] + : !pto.vmi.vreg<256xbf16>, !pto.ptr + pto.vmi.store %zero_pad, %ub_x1[%c0] + : !pto.vmi.vreg<256xbf16>, !pto.ptr + } + pto.set_flag["PIPE_V", "PIPE_MTE2", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVENT_ID0"] + + pto.mte_gm_ub %src_gm, %ub_x2, %c0_i64, %c8_i64 + nburst(%c4_i64, %c16_i64, %c64_i64) + pad(%c0_bf16, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, + pad bf16, i64, i64 + pto.mte_gm_ub %src_x1_gm, %ub_x1, %c0_i64, %c8_i64 + nburst(%c4_i64, %c16_i64, %c64_i64) + pad(%c0_bf16, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, + pad bf16, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.vmi.create_group_mask %c4 {num_groups = 8, group_size = 32} + : index -> !pto.vmi.mask<256xpred> + %x2_bf16 = pto.vmi.load %ub_x2[%c0] + : !pto.ptr -> !pto.vmi.vreg<256xbf16> + %x1_bf16 = pto.vmi.load %ub_x1[%c0] + : !pto.ptr -> !pto.vmi.vreg<256xbf16> + %x2 = pto.vmi.extf %x2_bf16 + : !pto.vmi.vreg<256xbf16> -> !pto.vmi.vreg<256xf32> + %x1 = pto.vmi.extf %x1_bf16 + : !pto.vmi.vreg<256xbf16> -> !pto.vmi.vreg<256xf32> + %zero = pto.vmi.broadcast %c0_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %neg_x1 = pto.vmi.subf %zero, %x1 + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %exp_neg = pto.vmi.exp %neg_x1 + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %one = pto.vmi.broadcast %c1_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %den = pto.vmi.addf %one, %exp_neg + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %silu_x1 = pto.vmi.divf %x1, %den + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %swiglu = pto.vmi.mulf %silu_x1, %x2 + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> 
!pto.vmi.vreg<256xf32> + %abs = pto.vmi.absf %swiglu + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<8xf32> + %amax_bits = pto.vmi.bitcast %amax + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<8xi32> + %exp_mask = pto.vmi.broadcast %c2139095040_i32 + : i32 -> !pto.vmi.vreg<8xi32> + %shift = pto.vmi.broadcast %c23_i32 + : i32 -> !pto.vmi.vreg<8xi32> + %emax = pto.vmi.broadcast %c15_i32 + : i32 -> !pto.vmi.vreg<8xi32> + %scale_exp_bias = pto.vmi.broadcast %c254_i32 + : i32 -> !pto.vmi.vreg<8xi32> + %exp_bits = pto.vmi.andi %amax_bits, %exp_mask + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> + %exp = pto.vmi.shrui %exp_bits, %shift + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> + %e8m0_i32 = pto.vmi.subi %exp, %emax + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> + %e8m0_u8 = pto.vmi.trunci %e8m0_i32 + : !pto.vmi.vreg<8xi32> -> !pto.vmi.vreg<8xui8> + pto.vmi.group_store %e8m0_u8, %ub_scale[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xui8>, !pto.ptr + %scale_exp = pto.vmi.subi %scale_exp_bias, %e8m0_i32 + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> + %scale_bits = pto.vmi.shli %scale_exp, %shift + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> + %scale = pto.vmi.bitcast %scale_bits + : !pto.vmi.vreg<8xi32> -> !pto.vmi.vreg<8xf32> + %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 8} + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<256xf32> + %scaled = pto.vmi.mulf %swiglu, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %q8 = pto.vmi.truncf %scaled + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E5M2> + pto.vmi.store %q8, %ub_out_f8[%c0] + : !pto.vmi.vreg<256xf8E5M2>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", 
"PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out_u8, %out_gm, %c4_i64 + nburst(%c4_i64, %c32_i64, %c4_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_scale, %scale_gm, %c4_i64 + nburst(%c1_i64, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e5m2-4x8/launch.cpp b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e5m2-4x8/launch.cpp new file mode 100644 index 0000000000..51b4c50290 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e5m2-4x8/launch.cpp @@ -0,0 +1,43 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_swiglu_mx_quant_bf16_e5m2_4x8_kernel(__gm__ bfloat16_t *src, + __gm__ uint8_t *out, + __gm__ uint8_t *scale); + +void LaunchVmi_swiglu_mx_quant_bf16_e5m2_4x8_kernel( + uint16_t *src, uint8_t *out, uint8_t *scale, void *stream) { + vmi_swiglu_mx_quant_bf16_e5m2_4x8_kernel<<<1, nullptr, stream>>>( + (__gm__ bfloat16_t *)src, (__gm__ uint8_t *)out, + (__gm__ uint8_t *)scale); +} diff --git a/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e5m2-4x8/main.cpp b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e5m2-4x8/main.cpp new file mode 100644 index 0000000000..bdfc82d090 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e5m2-4x8/main.cpp @@ -0,0 +1,95 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_swiglu_mx_quant_bf16_e5m2_4x8_kernel( + uint16_t *src, uint8_t *out, uint8_t *scale, void *stream); + +int main() { + constexpr size_t kRows = 4; + constexpr size_t kInputCols = 8; + constexpr size_t kOutCols = 4; + constexpr size_t kInputElems = kRows * kInputCols; + constexpr size_t kOutElems = kRows * kOutCols; + constexpr size_t kScaleBytes = kRows; + size_t srcBytes = kInputElems * sizeof(uint16_t); + size_t outBytes = kOutElems * sizeof(uint8_t); + size_t scaleBytes = kScaleBytes; + uint16_t *srcHost = nullptr; + uint8_t *outHost = nullptr; + uint8_t *scaleHost = nullptr; + uint16_t *srcDevice = nullptr; + uint8_t *outDevice = nullptr; + uint8_t *scaleDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&outHost), outBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&scaleHost), scaleBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, 
ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outDevice, outBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&scaleDevice, scaleBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", outBytes, outHost, outBytes); + ReadFile("./v3.bin", scaleBytes, scaleHost, scaleBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outDevice, outBytes, outHost, outBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(scaleDevice, scaleBytes, scaleHost, scaleBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_swiglu_mx_quant_bf16_e5m2_4x8_kernel( + srcDevice, outDevice, scaleDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(outHost, outBytes, outDevice, outBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(scaleHost, scaleBytes, scaleDevice, scaleBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", outHost, outBytes); + WriteFile("./v3.bin", scaleHost, scaleBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(outDevice); + aclrtFree(scaleDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(outHost); + aclrtFreeHost(scaleHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e5m2-4x8/ptoas.flags b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e5m2-4x8/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e5m2-4x8/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e4m3-64x512/compare.py b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e4m3-64x512/compare.py new file mode 100644 index 0000000000..bc45e33883 --- /dev/null +++ 
b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e4m3-64x512/compare.py
@@ -0,0 +1,52 @@
+#!/usr/bin/env python3
+# Copyright (c) 2026 Huawei Technologies Co., Ltd.
+# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+# CANN Open Software License Agreement Version 2.0 (the "License").
+# Please refer to the License for details. You may not use this file except in compliance with the License.
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+# See LICENSE in the root of the software repository for the full text of the License.
+
+import sys
+
+import numpy as np
+
+
+def check_u8(name: str) -> bool:
+    """Compare ``<name>.bin`` with ``golden_<name>.bin`` byte for byte.
+
+    Both files are read from the current working directory as flat uint8
+    arrays. Returns True when they are identical; otherwise prints a single
+    ``[ERROR]`` line describing the first discrepancy and returns False.
+    """
+    golden = np.fromfile(f"golden_{name}.bin", dtype=np.uint8)
+    output = np.fromfile(f"{name}.bin", dtype=np.uint8)
+    if golden.shape != output.shape:
+        # An element-wise comparison is not meaningful for different lengths
+        # (newer NumPy raises, older releases return a scalar), so report the
+        # size mismatch itself instead of indexing with a bogus position.
+        print(
+            f"[ERROR] compare failed {name} size mismatch "
+            f"golden={golden.size} output={output.size}"
+        )
+        return False
+    diff = np.nonzero(golden != output)[0]
+    if diff.size == 0:
+        return True
+    idx = int(diff[0])
+    print(
+        f"[ERROR] compare failed {name} idx={idx} "
+        f"golden=0x{int(golden[idx]):02x} output=0x{int(output[idx]):02x}"
+    )
+    return False
+
+
+def main() -> None:
+    """Exit with status 2 if either output buffer differs from its golden."""
+    if not check_u8("v2") or not check_u8("v3"):
+        sys.exit(2)
+    print("[INFO] compare passed")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e4m3-64x512/golden.py b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e4m3-64x512/golden.py
new file mode 100644
index 0000000000..b26ff0646b
--- /dev/null
+++ b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e4m3-64x512/golden.py
@@ -0,0 +1,60 @@
+#!/usr/bin/env python3
+# Copyright (c) 2026 Huawei Technologies Co., Ltd.
+# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+# CANN Open Software License Agreement Version 2.0 (the "License").
+# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 64 +INPUT_COLS = 512 +OUT_COLS = 256 +SCALE_BYTES = 512 +SENTINEL_U8 = np.uint8(0xA5) + +Q_VALUES = np.array( + [0.0, 1.0, -1.0, 0.5, 2.0, -2.0, 4.0, -4.0, 448.0], dtype=np.float32 +) +F8E4M3FN_BYTES = np.array( + [0x00, 0x38, 0xB8, 0x30, 0x40, 0xC0, 0x48, 0xC8, 0x7E], dtype=np.uint8 +) + + +def generate(output_dir: Path) -> None: + repeats = (OUT_COLS + len(Q_VALUES) - 1) // len(Q_VALUES) + q_row = np.tile(Q_VALUES, repeats)[:OUT_COLS].astype(np.float32) + f8_row = np.tile(F8E4M3FN_BYTES, repeats)[:OUT_COLS].astype(np.uint8) + + x2 = (q_row / np.float32(4096.0)).astype(np.float16) + x1 = np.full(OUT_COLS, np.float16(16.0), dtype=np.float16) + src_row = np.concatenate([x2, x1]) + src = np.tile(src_row, (ROWS, 1)) + golden_out = np.tile(f8_row, (ROWS, 1)).astype(np.uint8) + golden_scale = np.full(SCALE_BYTES, np.uint8(0x77), dtype=np.uint8) + + out = np.full((ROWS, OUT_COLS), SENTINEL_U8, dtype=np.uint8) + scale = np.full(SCALE_BYTES, SENTINEL_U8, dtype=np.uint8) + + output_dir.mkdir(parents=True, exist_ok=True) + src.view(np.uint16).tofile(output_dir / "v1.bin") + out.tofile(output_dir / "v2.bin") + scale.tofile(output_dir / "v3.bin") + golden_out.tofile(output_dir / "golden_v2.bin") + golden_scale.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git 
a/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e4m3-64x512/kernel.pto b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e4m3-64x512/kernel.pto new file mode 100644 index 0000000000..bd0b116a53 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e4m3-64x512/kernel.pto @@ -0,0 +1,150 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_swiglu_mx_quant_f16_e4m3_64x512_kernel(%src_gm: !pto.ptr, + %out_gm: !pto.ptr, + %scale_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c32 = arith.constant 32 : index + %c8 = arith.constant 8 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c512 = arith.constant 512 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c2_i64 = arith.constant 2 : i64 + %c8_i64 = arith.constant 8 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c16384_i64 = arith.constant 16384 : i64 + 
%c32768_i64 = arith.constant 32768 : i64 + %c49152_i64 = arith.constant 49152 : i64 + %c53248_i64 = arith.constant 53248 : i64 + %c65536_i64 = arith.constant 65536 : i64 + %c98304_i64 = arith.constant 98304 : i64 + %c2139095040_i32 = arith.constant 2139095040 : i32 + %c23_i32 = arith.constant 23 : i32 + %c8_i32 = arith.constant 8 : i32 + %c119_i32 = arith.constant 119 : i32 + %c254_i32 = arith.constant 254 : i32 + %c0_f32 = arith.constant 0.000000e+00 : f32 + %c1_f32 = arith.constant 1.000000e+00 : f32 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out_u8 = pto.castptr %c65536_i64 : i64 -> !pto.ptr + %ub_out_f8 = pto.castptr %c65536_i64 : i64 -> !pto.ptr + %ub_scale = pto.castptr %c98304_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c65536_i64 + nburst(%c1_i64, %c65536_i64, %c65536_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + scf.for %row = %c0 to %c64 step %c1 { + %src_row_off = arith.muli %row, %c512 : index + %x1_off = arith.addi %src_row_off, %c256 : index + %out_off = arith.muli %row, %c256 : index + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %x2_f16 = pto.vmi.load %ub_src[%src_row_off] + : !pto.ptr -> !pto.vmi.vreg<256xf16> + %x1_f16 = pto.vmi.load %ub_src[%x1_off] + : !pto.ptr -> !pto.vmi.vreg<256xf16> + %x2 = pto.vmi.extf %x2_f16 + : !pto.vmi.vreg<256xf16> -> !pto.vmi.vreg<256xf32> + %x1 = pto.vmi.extf %x1_f16 + : !pto.vmi.vreg<256xf16> -> !pto.vmi.vreg<256xf32> + %zero = pto.vmi.broadcast %c0_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %neg_x1 = pto.vmi.subf %zero, %x1 + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %exp_neg = pto.vmi.exp %neg_x1 + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %one = pto.vmi.broadcast %c1_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %den = pto.vmi.addf %one, %exp_neg + : !pto.vmi.vreg<256xf32>, 
!pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %silu_x1 = pto.vmi.divf %x1, %den + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %swiglu = pto.vmi.mulf %silu_x1, %x2 + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %abs = pto.vmi.absf %swiglu + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<8xf32> + %amax_bits = pto.vmi.bitcast %amax + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<8xi32> + %exp_mask = pto.vmi.broadcast %c2139095040_i32 + : i32 -> !pto.vmi.vreg<8xi32> + %shift = pto.vmi.broadcast %c23_i32 + : i32 -> !pto.vmi.vreg<8xi32> + %emax = pto.vmi.broadcast %c8_i32 + : i32 -> !pto.vmi.vreg<8xi32> + %scale_exp_bias = pto.vmi.broadcast %c254_i32 + : i32 -> !pto.vmi.vreg<8xi32> + %exp_bits = pto.vmi.andi %amax_bits, %exp_mask + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> + %exp = pto.vmi.shrui %exp_bits, %shift + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> + %e8m0_i32 = pto.vmi.subi %exp, %emax + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> + %scale_ub_off = arith.muli %row, %c32 : index + pto.vmi.group_store %e8m0_i32, %ub_scale[%scale_ub_off], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xi32>, !pto.ptr + %scale_exp = pto.vmi.subi %scale_exp_bias, %e8m0_i32 + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> + %scale_bits = pto.vmi.shli %scale_exp, %shift + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> + %scale = pto.vmi.bitcast %scale_bits + : !pto.vmi.vreg<8xi32> -> !pto.vmi.vreg<8xf32> + %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 8} + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<256xf32> + %scaled = pto.vmi.mulf %swiglu, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %q8 = 
pto.vmi.truncf %scaled + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E4M3FN> + pto.vmi.store %q8, %ub_out_f8[%out_off] + : !pto.vmi.vreg<256xf8E4M3FN>, !pto.ptr + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out_u8, %out_gm, %c16384_i64 + nburst(%c1_i64, %c16384_i64, %c16384_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_scale, %scale_gm, %c8_i64 + nburst(%c64_i64, %c32_i64, %c8_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e4m3-64x512/launch.cpp b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e4m3-64x512/launch.cpp new file mode 100644 index 0000000000..32dfb4b472 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e4m3-64x512/launch.cpp @@ -0,0 +1,42 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_swiglu_mx_quant_f16_e4m3_64x512_kernel(__gm__ half *src, + __gm__ uint8_t *out, + __gm__ uint8_t *scale); + +void LaunchVmi_swiglu_mx_quant_f16_e4m3_64x512_kernel( + uint16_t *src, uint8_t *out, uint8_t *scale, void *stream) { + vmi_swiglu_mx_quant_f16_e4m3_64x512_kernel<<<1, nullptr, stream>>>( + (__gm__ half *)src, (__gm__ uint8_t *)out, (__gm__ uint8_t *)scale); +} diff --git a/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e4m3-64x512/main.cpp b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e4m3-64x512/main.cpp new file mode 100644 index 0000000000..e20dc8e25f --- /dev/null +++ b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e4m3-64x512/main.cpp @@ -0,0 +1,95 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_swiglu_mx_quant_f16_e4m3_64x512_kernel( + uint16_t *src, uint8_t *out, uint8_t *scale, void *stream); + +int main() { + constexpr size_t kRows = 64; + constexpr size_t kInputCols = 512; + constexpr size_t kOutCols = 256; + constexpr size_t kInputElems = kRows * kInputCols; + constexpr size_t kOutElems = kRows * kOutCols; + constexpr size_t kScaleBytes = 512; + size_t srcBytes = kInputElems * sizeof(uint16_t); + size_t outBytes = kOutElems * sizeof(uint8_t); + size_t scaleBytes = kScaleBytes; + uint16_t *srcHost = nullptr; + uint8_t *outHost = nullptr; + uint8_t *scaleHost = nullptr; + uint16_t *srcDevice = nullptr; + uint8_t *outDevice = nullptr; + uint8_t *scaleDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&outHost), outBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&scaleHost), scaleBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outDevice, outBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&scaleDevice, scaleBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", 
srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", outBytes, outHost, outBytes); + ReadFile("./v3.bin", scaleBytes, scaleHost, scaleBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outDevice, outBytes, outHost, outBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(scaleDevice, scaleBytes, scaleHost, scaleBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_swiglu_mx_quant_f16_e4m3_64x512_kernel( + srcDevice, outDevice, scaleDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(outHost, outBytes, outDevice, outBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(scaleHost, scaleBytes, scaleDevice, scaleBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", outHost, outBytes); + WriteFile("./v3.bin", scaleHost, scaleBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(outDevice); + aclrtFree(scaleDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(outHost); + aclrtFreeHost(scaleHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e4m3-64x512/ptoas.flags b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e4m3-64x512/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e4m3-64x512/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e5m2-128x256/compare.py b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e5m2-128x256/compare.py new file mode 100644 index 0000000000..bc45e33883 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e5m2-128x256/compare.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def check_u8(name: str) -> bool: + golden = np.fromfile(f"golden_{name}.bin", dtype=np.uint8) + output = np.fromfile(f"{name}.bin", dtype=np.uint8) + if golden.shape == output.shape and np.array_equal(golden, output): + return True + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed {name} idx={idx} " + f"golden=0x{int(golden[idx]):02x} output=0x{int(output[idx]):02x}" + ) + return False + + +def main() -> None: + if not check_u8("v2") or not check_u8("v3"): + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e5m2-128x256/golden.py b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e5m2-128x256/golden.py new file mode 100644 index 0000000000..20aea94e4f --- /dev/null +++ b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e5m2-128x256/golden.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. 
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 128 +INPUT_COLS = 256 +OUT_COLS = 128 +SCALE_BYTES = 512 +SENTINEL_U8 = np.uint8(0xA5) + +Q_VALUES = np.array( + [0.0, 1.0, -1.0, 0.5, -0.5, 2.0, -2.0, 4.0, -4.0, 57344.0], + dtype=np.float32, +) +F8E5M2_BYTES = np.array( + [0x00, 0x3C, 0xBC, 0x38, 0xB8, 0x40, 0xC0, 0x44, 0xC4, 0x7B], + dtype=np.uint8, +) + + +def generate(output_dir: Path) -> None: + repeats = (OUT_COLS + len(Q_VALUES) - 1) // len(Q_VALUES) + q_row = np.tile(Q_VALUES, repeats)[:OUT_COLS].astype(np.float32) + f8_row = np.tile(F8E5M2_BYTES, repeats)[:OUT_COLS].astype(np.uint8) + + x2 = (q_row / np.float32(16.0)).astype(np.float16) + x1 = np.full(OUT_COLS, np.float16(16.0), dtype=np.float16) + src_row = np.concatenate([x2, x1]) + src = np.tile(src_row, (ROWS, 1)) + golden_out = np.tile(f8_row, (ROWS, 1)).astype(np.uint8) + golden_scale = np.full(SCALE_BYTES, np.uint8(0x7F), dtype=np.uint8) + + out = np.full((ROWS, OUT_COLS), SENTINEL_U8, dtype=np.uint8) + scale = np.full(SCALE_BYTES, SENTINEL_U8, dtype=np.uint8) + + output_dir.mkdir(parents=True, exist_ok=True) + src.view(np.uint16).tofile(output_dir / "v1.bin") + out.tofile(output_dir / "v2.bin") + scale.tofile(output_dir / "v3.bin") + golden_out.tofile(output_dir / "golden_v2.bin") + golden_scale.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e5m2-128x256/kernel.pto 
b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e5m2-128x256/kernel.pto new file mode 100644 index 0000000000..86aa4e4a35 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e5m2-128x256/kernel.pto @@ -0,0 +1,160 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_swiglu_mx_quant_f16_e5m2_128x256_kernel(%src_gm: !pto.ptr, + %out_gm: !pto.ptr, + %scale_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c32 = arith.constant 32 : index + %c8 = arith.constant 8 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c512 = arith.constant 512 : index + %c32768 = arith.constant 32768 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c2_i64 = arith.constant 2 : i64 + %c4_i64 = arith.constant 4 : i64 + %c8_i64 = arith.constant 8 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c128_i64 = arith.constant 128 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + 
%c16384_i64 = arith.constant 16384 : i64 + %c32768_i64 = arith.constant 32768 : i64 + %c49152_i64 = arith.constant 49152 : i64 + %c53248_i64 = arith.constant 53248 : i64 + %c65536_i64 = arith.constant 65536 : i64 + %c98304_i64 = arith.constant 98304 : i64 + %c131072_i64 = arith.constant 131072 : i64 + %c163840_i64 = arith.constant 163840 : i64 + %c2139095040_i32 = arith.constant 2139095040 : i32 + %c23_i32 = arith.constant 23 : i32 + %c15_i32 = arith.constant 15 : i32 + %c127_i32 = arith.constant 127 : i32 + %c254_i32 = arith.constant 254 : i32 + %c0_f16 = arith.constant 0.000000e+00 : f16 + %c0_f32 = arith.constant 0.000000e+00 : f32 + %c1_f32 = arith.constant 1.000000e+00 : f32 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out_u8 = pto.castptr %c131072_i64 : i64 -> !pto.ptr + %ub_out_f8 = pto.castptr %c131072_i64 : i64 -> !pto.ptr + %ub_scale = pto.castptr %c163840_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c65536_i64 + nburst(%c1_i64, %c65536_i64, %c65536_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %zero_pad = pto.vmi.broadcast %c0_f16 + : f16 -> !pto.vmi.vreg<256xf16> + pto.vmi.store %zero_pad, %ub_src[%c32768] + : !pto.vmi.vreg<256xf16>, !pto.ptr + scf.for %row = %c0 to %c128 step %c1 { + %src_row_off = arith.muli %row, %c256 : index + %x1_off = arith.addi %src_row_off, %c128 : index + %out_off = arith.muli %row, %c256 : index + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %x2_f16 = pto.vmi.load %ub_src[%src_row_off] + : !pto.ptr -> !pto.vmi.vreg<256xf16> + %x1_f16 = pto.vmi.load %ub_src[%x1_off] + : !pto.ptr -> !pto.vmi.vreg<256xf16> + %x2 = pto.vmi.extf %x2_f16 + : !pto.vmi.vreg<256xf16> -> !pto.vmi.vreg<256xf32> + %x1 = pto.vmi.extf %x1_f16 + : !pto.vmi.vreg<256xf16> -> !pto.vmi.vreg<256xf32> + %zero = pto.vmi.broadcast %c0_f32 + : f32 -> !pto.vmi.vreg<256xf32> + 
%neg_x1 = pto.vmi.subf %zero, %x1 + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %exp_neg = pto.vmi.exp %neg_x1 + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %one = pto.vmi.broadcast %c1_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %den = pto.vmi.addf %one, %exp_neg + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %silu_x1 = pto.vmi.divf %x1, %den + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %swiglu = pto.vmi.mulf %silu_x1, %x2 + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %abs = pto.vmi.absf %swiglu + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<8xf32> + %amax_bits = pto.vmi.bitcast %amax + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<8xi32> + %exp_mask = pto.vmi.broadcast %c2139095040_i32 + : i32 -> !pto.vmi.vreg<8xi32> + %shift = pto.vmi.broadcast %c23_i32 + : i32 -> !pto.vmi.vreg<8xi32> + %emax = pto.vmi.broadcast %c15_i32 + : i32 -> !pto.vmi.vreg<8xi32> + %scale_exp_bias = pto.vmi.broadcast %c254_i32 + : i32 -> !pto.vmi.vreg<8xi32> + %exp_bits = pto.vmi.andi %amax_bits, %exp_mask + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> + %exp = pto.vmi.shrui %exp_bits, %shift + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> + %e8m0_i32 = pto.vmi.subi %exp, %emax + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> + %scale_ub_off = arith.muli %row, %c32 : index + pto.vmi.group_store %e8m0_i32, %ub_scale[%scale_ub_off], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xi32>, !pto.ptr + %scale_exp = pto.vmi.subi %scale_exp_bias, %e8m0_i32 + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> + %scale_bits = pto.vmi.shli %scale_exp, %shift + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> + %scale = 
pto.vmi.bitcast %scale_bits + : !pto.vmi.vreg<8xi32> -> !pto.vmi.vreg<8xf32> + %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 8} + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<256xf32> + %scaled = pto.vmi.mulf %swiglu, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %q8 = pto.vmi.truncf %scaled + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E5M2> + pto.vmi.store %q8, %ub_out_f8[%out_off] + : !pto.vmi.vreg<256xf8E5M2>, !pto.ptr + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out_u8, %out_gm, %c128_i64 + nburst(%c128_i64, %c256_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_scale, %scale_gm, %c4_i64 + nburst(%c128_i64, %c32_i64, %c4_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e5m2-128x256/launch.cpp b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e5m2-128x256/launch.cpp new file mode 100644 index 0000000000..dbfc94e8b9 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e5m2-128x256/launch.cpp @@ -0,0 +1,42 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_swiglu_mx_quant_f16_e5m2_128x256_kernel(__gm__ half *src, + __gm__ uint8_t *out, + __gm__ uint8_t *scale); + +void LaunchVmi_swiglu_mx_quant_f16_e5m2_128x256_kernel( + uint16_t *src, uint8_t *out, uint8_t *scale, void *stream) { + vmi_swiglu_mx_quant_f16_e5m2_128x256_kernel<<<1, nullptr, stream>>>( + (__gm__ half *)src, (__gm__ uint8_t *)out, (__gm__ uint8_t *)scale); +} diff --git a/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e5m2-128x256/main.cpp b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e5m2-128x256/main.cpp new file mode 100644 index 0000000000..cfe997291f --- /dev/null +++ b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e5m2-128x256/main.cpp @@ -0,0 +1,95 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_swiglu_mx_quant_f16_e5m2_128x256_kernel( + uint16_t *src, uint8_t *out, uint8_t *scale, void *stream); + +int main() { + constexpr size_t kRows = 128; + constexpr size_t kInputCols = 256; + constexpr size_t kOutCols = 128; + constexpr size_t kInputElems = kRows * kInputCols; + constexpr size_t kOutElems = kRows * kOutCols; + constexpr size_t kScaleBytes = 512; + size_t srcBytes = kInputElems * sizeof(uint16_t); + size_t outBytes = kOutElems * sizeof(uint8_t); + size_t scaleBytes = kScaleBytes; + uint16_t *srcHost = nullptr; + uint8_t *outHost = nullptr; + uint8_t *scaleHost = nullptr; + uint16_t *srcDevice = nullptr; + uint8_t *outDevice = nullptr; + uint8_t *scaleDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&outHost), outBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&scaleHost), scaleBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, 
ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outDevice, outBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&scaleDevice, scaleBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", outBytes, outHost, outBytes); + ReadFile("./v3.bin", scaleBytes, scaleHost, scaleBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outDevice, outBytes, outHost, outBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(scaleDevice, scaleBytes, scaleHost, scaleBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_swiglu_mx_quant_f16_e5m2_128x256_kernel( + srcDevice, outDevice, scaleDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(outHost, outBytes, outDevice, outBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(scaleHost, scaleBytes, scaleDevice, scaleBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", outHost, outBytes); + WriteFile("./v3.bin", scaleHost, scaleBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(outDevice); + aclrtFree(scaleDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(outHost); + aclrtFreeHost(scaleHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e5m2-128x256/ptoas.flags b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e5m2-128x256/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e5m2-128x256/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/kernels/tquant-int8-asym-64x128/compare.py b/test/vpto/cases/vmi/kernels/tquant-int8-asym-64x128/compare.py new file mode 100644 index 0000000000..85beaede4b --- /dev/null +++ 
b/test/vpto/cases/vmi/kernels/tquant-int8-asym-64x128/compare.py @@ -0,0 +1,32 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v4.bin", dtype=np.uint8) + output = np.fromfile("v4.bin", dtype=np.uint8) + if golden.shape == output.shape and np.array_equal(golden, output): + print("[INFO] compare passed") + return + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed v4 idx={idx} " + f"golden={int(golden[idx]) if idx >= 0 else 'n/a'} " + f"output={int(output[idx]) if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/tquant-int8-asym-64x128/golden.py b/test/vpto/cases/vmi/kernels/tquant-int8-asym-64x128/golden.py new file mode 100644 index 0000000000..7e0bce8527 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/tquant-int8-asym-64x128/golden.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. 
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 64 +COLS = 128 +RNG_SEED = 19 + + +def generate(output_dir: Path) -> None: + np.random.seed(RNG_SEED) + src = np.random.uniform(low=-2, high=2, size=(ROWS, COLS)).astype(np.float32) + row_min = np.min(src, axis=1, keepdims=True) + row_max = np.max(src, axis=1, keepdims=True) + scale = ((row_max - row_min) / np.float32(255.0)).astype(np.float32) + inv_scale = np.where(scale != 0, np.float32(1.0) / scale, np.float32(0.0)).astype(np.float32) + offset = np.clip(np.round(-row_min / scale), 0, 255).astype(np.float32) + rounded = np.round(src * inv_scale + offset).astype(np.float32) + golden = np.clip(rounded.astype(np.float16), 0, 255).astype(np.uint8) + dst = np.full((ROWS, COLS), 0xA5, dtype=np.uint8) + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + inv_scale.reshape(ROWS).tofile(output_dir / "v2.bin") + offset.reshape(ROWS).tofile(output_dir / "v3.bin") + dst.tofile(output_dir / "v4.bin") + golden.tofile(output_dir / "golden_v4.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/tquant-int8-asym-64x128/kernel.pto b/test/vpto/cases/vmi/kernels/tquant-int8-asym-64x128/kernel.pto new file mode 100644 index 0000000000..6724cd56d0 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/tquant-int8-asym-64x128/kernel.pto @@ -0,0 +1,106 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_tquant_int8_asym_64x128_kernel(%src_gm: !pto.ptr, + %inv_scale_gm: !pto.ptr, + %offset_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c0_i32 = arith.constant 0 : i32 + %c7_i32 = arith.constant 7 : i32 + %c0_f32 = arith.constant 0.000000e+00 : f32 + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c256_i64 = arith.constant 256 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c32768_i64 = arith.constant 32768 : i64 + %c36864_i64 = arith.constant 36864 : i64 + %c40960_i64 = arith.constant 40960 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_inv_scale = pto.castptr %c32768_i64 : i64 -> !pto.ptr + %ub_offset = pto.castptr %c36864_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c40960_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c32768_i64 + nburst(%c1_i64, %c32768_i64, %c32768_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %inv_scale_gm, %ub_inv_scale, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %offset_gm, 
%ub_offset, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %dst_gm, %ub_dst, %c0_i64, %c8192_i64 + nburst(%c1_i64, %c8192_i64, %c8192_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + scf.for %pair = %c0 to %c32 step %c1 { + %row = arith.muli %pair, %c2 : index + %elem_offset = arith.muli %row, %c128 : index + %x = pto.vmi.load %ub_src[%elem_offset] + : !pto.ptr -> !pto.vmi.vreg<256xf32> + %row_i32 = arith.index_cast %row : index to i32 + %gather_mask = pto.vmi.create_mask %c256 + : index -> !pto.vmi.mask<256xpred> + %zero = pto.vmi.broadcast %c0_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %lane = pto.vmi.iota %c0_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %shift = pto.vmi.broadcast %c7_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %local_group = pto.vmi.shrui %lane, %shift + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %row_base = pto.vmi.broadcast %row_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %param_indices = pto.vmi.addi %row_base, %local_group + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %scale = pto.vmi.gather %ub_inv_scale[%param_indices], %gather_mask, %zero + : !pto.ptr, !pto.vmi.vreg<256xi32>, + !pto.vmi.mask<256xpred>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %offset = pto.vmi.gather %ub_offset[%param_indices], %gather_mask, %zero + : !pto.ptr, !pto.vmi.vreg<256xi32>, + !pto.vmi.mask<256xpred>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %scaled = pto.vmi.mulf %x, %scale + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %shifted = pto.vmi.addf %scaled, %offset + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %i32 = pto.vmi.fptosi %shifted + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xi32> + %u8 = pto.vmi.trunci 
%i32 + : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xui8> + pto.vmi.store %u8, %ub_dst[%elem_offset] + : !pto.vmi.vreg<256xui8>, !pto.ptr + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c8192_i64 + nburst(%c1_i64, %c8192_i64, %c8192_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/kernels/tquant-int8-asym-64x128/launch.cpp b/test/vpto/cases/vmi/kernels/tquant-int8-asym-64x128/launch.cpp new file mode 100644 index 0000000000..aeaba21d8a --- /dev/null +++ b/test/vpto/cases/vmi/kernels/tquant-int8-asym-64x128/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+
+#ifndef __VEC_SCOPE__  // define __VEC_SCOPE__ as empty if nothing has defined it yet
+#define __VEC_SCOPE__
+#endif
+#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201)  // only under __CCE_AICORE__ with __NPU_ARCH__ == 2201
+typedef struct { unsigned char v; } hifloat8_t;  // one-byte placeholder struct; the five typedefs below have the same shape
+typedef struct { unsigned char v; } float8_e4m3_t;
+typedef struct { unsigned char v; } float8_e5m2_t;
+typedef struct { unsigned char v; } float8_e8m0_t;
+typedef struct { unsigned char v; } float4_e1m2x2_t;
+typedef struct { unsigned char v; } float4_e2m1x2_t;
+#endif
+#include  // NOTE(review): header name is missing in this copy of the patch (angle-bracket text looks stripped) - confirm against the repository
+#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP)  // build without __CCE_AICORE__ where TMRGSORT_HPP is not defined
+struct MrgSortExecutedNumList {  // four uint16_t fields; presumably mirrors a definition guarded by TMRGSORT_HPP - confirm
+  uint16_t mrgSortList0;
+  uint16_t mrgSortList1;
+  uint16_t mrgSortList2;
+  uint16_t mrgSortList3;
+};
+#endif
+#ifndef __CPU_SIM  // the ACL header is included unless __CPU_SIM is defined
+#include "acl/acl.h"
+#endif
+
+extern "C" __global__ [aicore] void  // kernel entry point: declared here only, no definition in this file
+vmi_tquant_int8_asym_64x128_kernel(__gm__ float *src,
+                                   __gm__ float *inv_scale,
+                                   __gm__ float *offset, __gm__ uint8_t *dst);
+
+void LaunchVmi_tquant_int8_asym_64x128_kernel(float *src, float *inv_scale,  // host wrapper called from main.cpp
+                                              float *offset, uint8_t *dst,
+                                              void *stream) {
+  vmi_tquant_int8_asym_64x128_kernel<<<1, nullptr, stream>>>(  // launch arguments are 1, nullptr and the caller's stream
+      (__gm__ float *)src, (__gm__ float *)inv_scale, (__gm__ float *)offset,  // each pointer is cast to its __gm__-qualified type
+      (__gm__ uint8_t *)dst);
+}
diff --git a/test/vpto/cases/vmi/kernels/tquant-int8-asym-64x128/main.cpp b/test/vpto/cases/vmi/kernels/tquant-int8-asym-64x128/main.cpp
new file mode 100644
index 0000000000..70bd93e436
--- /dev/null
+++ b/test/vpto/cases/vmi/kernels/tquant-int8-asym-64x128/main.cpp
@@ -0,0 +1,99 @@
+// Copyright (c) 2026 Huawei Technologies Co., Ltd.
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+// CANN Open Software License Agreement Version 2.0 (the "License").
+// Please refer to the License for details. You may not use this file except in compliance with the License.
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+// See LICENSE in the root of the software repository for the full text of the License.
+
+#include "acl/acl.h"
+#include "test_common.h"
+#include  // NOTE(review): the header names of this and the next two includes are missing in this copy of the patch - confirm against the repository
+#include
+#include
+
+using namespace PtoTestCommon;
+
+#define ACL_CHECK(expr) /* on failure: log expr and the error code, set rc = 1, jump to cleanup; needs `rc` and a `cleanup:` label in the caller */ \
+  do { \
+    const aclError _ret = (expr); \
+    if (_ret != ACL_SUCCESS) { \
+      std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \
+                   (int)_ret, __FILE__, __LINE__); \
+      rc = 1; \
+      goto cleanup; \
+    } \
+  } while (0)
+
+void LaunchVmi_tquant_int8_asym_64x128_kernel(float *src, float *inv_scale,  // declared here, defined in launch.cpp
+                                              float *offset, uint8_t *dst,
+                                              void *stream);
+
+int main() {  // returns 0 on success, 1 if any ACL_CHECK call failed
+  constexpr size_t kRows = 64;
+  constexpr size_t kCols = 128;
+  constexpr size_t kElems = kRows * kCols;
+  size_t srcBytes = kElems * sizeof(float);  // one float per element of the kRows x kCols input
+  size_t scaleBytes = kRows * sizeof(float);  // one float per row; also used to size the offset buffers
+  size_t dstBytes = kElems * sizeof(uint8_t);  // one byte per input element
+  float *srcHost = nullptr;  // every pointer starts as nullptr and is declared before the first ACL_CHECK
+  float *scaleHost = nullptr;
+  float *offsetHost = nullptr;
+  uint8_t *dstHost = nullptr;
+  float *srcDevice = nullptr;
+  float *scaleDevice = nullptr;
+  float *offsetDevice = nullptr;
+  uint8_t *dstDevice = nullptr;
+  int rc = 0;  // set to 1 by ACL_CHECK on failure
+  bool aclInited = false;  // aclInited/deviceSet record how far setup got, so cleanup only undoes what succeeded
+  bool deviceSet = false;
+  int deviceId = 0;
+  aclrtStream stream = nullptr;
+
+  ACL_CHECK(aclInit(nullptr));
+  aclInited = true;
+  if (const char *envDevice = std::getenv("ACL_DEVICE_ID"))  // optional device override from the environment
+    deviceId = std::atoi(envDevice);
+  ACL_CHECK(aclrtSetDevice(deviceId));
+  deviceSet = true;
+  ACL_CHECK(aclrtCreateStream(&stream));
+  ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes));
+  ACL_CHECK(aclrtMallocHost((void **)(&scaleHost), scaleBytes));
+  ACL_CHECK(aclrtMallocHost((void **)(&offsetHost), scaleBytes));
+  ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes));
+  ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST));
+  ACL_CHECK(aclrtMalloc((void **)&scaleDevice, scaleBytes, ACL_MEM_MALLOC_HUGE_FIRST));
+  ACL_CHECK(aclrtMalloc((void **)&offsetDevice, scaleBytes, ACL_MEM_MALLOC_HUGE_FIRST));
+  ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST));
+
+  ReadFile("./v1.bin", srcBytes, srcHost, srcBytes);  // any result from the ReadFile calls is ignored here
+  ReadFile("./v2.bin", scaleBytes, scaleHost, scaleBytes);
+  ReadFile("./v3.bin", scaleBytes, offsetHost, scaleBytes);
+  ReadFile("./v4.bin", dstBytes, dstHost, dstBytes);  // the output buffer starts from the current contents of v4.bin
+  ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE));
+  ACL_CHECK(aclrtMemcpy(scaleDevice, scaleBytes, scaleHost, scaleBytes, ACL_MEMCPY_HOST_TO_DEVICE));
+  ACL_CHECK(aclrtMemcpy(offsetDevice, scaleBytes, offsetHost, scaleBytes, ACL_MEMCPY_HOST_TO_DEVICE));
+  ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE));
+  LaunchVmi_tquant_int8_asym_64x128_kernel(srcDevice, scaleDevice,
+                                           offsetDevice, dstDevice, stream);
+  ACL_CHECK(aclrtSynchronizeStream(stream));  // wait for the stream before copying the result back
+  ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST));
+  WriteFile("./v4.bin", dstHost, dstBytes);  // overwrite v4.bin with the device result; skipped when an ACL_CHECK jumped to cleanup
+
+cleanup:  // reached on success and from every failing ACL_CHECK
+  aclrtFree(srcDevice);  // the frees are unconditional; pointers that were never allocated are still nullptr here
+  aclrtFree(scaleDevice);
+  aclrtFree(offsetDevice);
+  aclrtFree(dstDevice);
+  aclrtFreeHost(srcHost);
+  aclrtFreeHost(scaleHost);
+  aclrtFreeHost(offsetHost);
+  aclrtFreeHost(dstHost);
+  if (stream)
+    aclrtDestroyStream(stream);
+  if (deviceSet)
+    aclrtResetDevice(deviceId);
+  if (aclInited)
+    aclFinalize();
+  return rc;
+}
diff --git a/test/vpto/cases/vmi/kernels/tquant-int8-asym-64x128/ptoas.flags b/test/vpto/cases/vmi/kernels/tquant-int8-asym-64x128/ptoas.flags
new file mode 100644
index 0000000000..a79aede1ca
--- /dev/null
+++ b/test/vpto/cases/vmi/kernels/tquant-int8-asym-64x128/ptoas.flags
@@ -0,0 +1 @@
+--pto-arch a5 --pto-backend=vpto --enable-vmi
diff --git a/test/vpto/cases/vmi/kernels/tquant-int8-sym-64x128/compare.py
b/test/vpto/cases/vmi/kernels/tquant-int8-sym-64x128/compare.py new file mode 100644 index 0000000000..85beaede4b --- /dev/null +++ b/test/vpto/cases/vmi/kernels/tquant-int8-sym-64x128/compare.py @@ -0,0 +1,32 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v4.bin", dtype=np.uint8) + output = np.fromfile("v4.bin", dtype=np.uint8) + if golden.shape == output.shape and np.array_equal(golden, output): + print("[INFO] compare passed") + return + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed v4 idx={idx} " + f"golden={int(golden[idx]) if idx >= 0 else 'n/a'} " + f"output={int(output[idx]) if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/tquant-int8-sym-64x128/golden.py b/test/vpto/cases/vmi/kernels/tquant-int8-sym-64x128/golden.py new file mode 100644 index 0000000000..3e5bc93768 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/tquant-int8-sym-64x128/golden.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). 
+# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 64 +COLS = 128 +RNG_SEED = 19 + + +def generate(output_dir: Path) -> None: + np.random.seed(RNG_SEED) + src = np.random.uniform(low=-2, high=2, size=(ROWS, COLS)).astype(np.float32) + scale = (np.max(np.abs(src), axis=1, keepdims=True) / np.float32(127.0)).astype(np.float32) + inv_scale = np.where(scale != 0, np.float32(1.0) / scale, np.float32(0.0)).astype(np.float32) + rounded = np.round(src * inv_scale).astype(np.float32) + golden = np.clip(rounded.astype(np.float16), -128, 127).astype(np.int8) + dst = np.full((ROWS, COLS), 0xA5, dtype=np.uint8) + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + inv_scale.reshape(ROWS).tofile(output_dir / "v2.bin") + dst.tofile(output_dir / "v4.bin") + golden.view(np.uint8).tofile(output_dir / "golden_v4.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/tquant-int8-sym-64x128/kernel.pto b/test/vpto/cases/vmi/kernels/tquant-int8-sym-64x128/kernel.pto new file mode 100644 index 0000000000..70b04d7efe --- /dev/null +++ b/test/vpto/cases/vmi/kernels/tquant-int8-sym-64x128/kernel.pto @@ -0,0 +1,121 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+// CANN Open Software License Agreement Version 2.0 (the "License").
+// Please refer to the License for details. You may not use this file except in compliance with the License.
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+// See LICENSE in the root of the software repository for the full text of the License.
+
+module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} {  // NOTE(review): `#pto.kernel_kind`, `!pto.ptr` and `#pto.pipe` appear without any <...> parameters in this copy of the patch - confirm against the repository
+  func.func @vmi_tquant_int8_sym_64x128_kernel(%src_gm: !pto.ptr,  // three pointer arguments: source, inverse scale, destination
+                                               %inv_scale_gm: !pto.ptr,
+                                               %dst_gm: !pto.ptr) attributes {pto.kernel} {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c2 = arith.constant 2 : index
+    %c32 = arith.constant 32 : index
+    %c128 = arith.constant 128 : index
+    %c256 = arith.constant 256 : index
+    %c0_i32 = arith.constant 0 : i32
+    %c7_i32 = arith.constant 7 : i32
+    %c256_i32 = arith.constant 256 : i32
+    %c-128_f32 = arith.constant -1.280000e+02 : f32
+    %c0_f32 = arith.constant 0.000000e+00 : f32
+    %c127_f32 = arith.constant 1.270000e+02 : f32
+    %c0_i64 = arith.constant 0 : i64
+    %c1_i64 = arith.constant 1 : i64
+    %c256_i64 = arith.constant 256 : i64
+    %c8192_i64 = arith.constant 8192 : i64
+    %c32768_i64 = arith.constant 32768 : i64
+    %c36864_i64 = arith.constant 36864 : i64  // not referenced anywhere else in this kernel
+    %c40960_i64 = arith.constant 40960 : i64
+
+    %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr  // the three %ub_* pointers are built from the fixed integers 0, 32768 and 40960
+    %ub_inv_scale = pto.castptr %c32768_i64 : i64 -> !pto.ptr
+    %ub_dst = pto.castptr %c40960_i64 : i64 -> !pto.ptr
+
+    pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c32768_i64  // copy src into %ub_src; 32768 = 64*128*4, presumably a byte count - confirm
+      nburst(%c1_i64, %c32768_i64, %c32768_i64)
+      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
+    pto.mte_gm_ub %inv_scale_gm, %ub_inv_scale, %c0_i64, %c256_i64  // 256 = 64*4, i.e. one f32 per row if the size is in bytes
+      nburst(%c1_i64, %c256_i64, %c256_i64)
+      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
+    pto.mte_gm_ub %dst_gm, %ub_dst, %c0_i64, %c8192_i64  // the destination region is also loaded from dst_gm first (8192 = 64*128)
+      nburst(%c1_i64, %c8192_i64, %c8192_i64)
+      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
+
+    pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]  // presumably orders the copies above before the vector work below - confirm
+    pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]
+
+    pto.vecscope {
+      scf.for %pair = %c0 to %c32 step %c1 {  // 32 iterations; each one handles rows 2*pair and 2*pair+1 (256 elements)
+        %row = arith.muli %pair, %c2 : index
+        %elem_offset = arith.muli %row, %c128 : index  // element offset of the first of the two rows
+        %x = pto.vmi.load %ub_src[%elem_offset]
+          : !pto.ptr -> !pto.vmi.vreg<256xf32>
+        %row_i32 = arith.index_cast %row : index to i32
+        %gather_mask = pto.vmi.create_mask %c256
+          : index -> !pto.vmi.mask<256xpred>
+        %zero = pto.vmi.broadcast %c0_f32
+          : f32 -> !pto.vmi.vreg<256xf32>
+        %lo = pto.vmi.broadcast %c-128_f32
+          : f32 -> !pto.vmi.vreg<256xf32>
+        %hi = pto.vmi.broadcast %c127_f32
+          : f32 -> !pto.vmi.vreg<256xf32>
+        %lane = pto.vmi.iota %c0_i32  // presumably lane ids 0..255 counting up from 0 - confirm
+          : i32 -> !pto.vmi.vreg<256xi32>
+        %shift = pto.vmi.broadcast %c7_i32
+          : i32 -> !pto.vmi.vreg<256xi32>
+        %local_group = pto.vmi.shrui %lane, %shift  // lane >> 7
+          : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32>
+          -> !pto.vmi.vreg<256xi32>
+        %row_base = pto.vmi.broadcast %row_i32
+          : i32 -> !pto.vmi.vreg<256xi32>
+        %param_indices = pto.vmi.addi %row_base, %local_group  // per-lane index = row + (lane >> 7)
+          : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32>
+          -> !pto.vmi.vreg<256xi32>
+        %scale = pto.vmi.gather %ub_inv_scale[%param_indices], %gather_mask, %zero  // per-lane inverse scale looked up by that index
+          : !pto.ptr, !pto.vmi.vreg<256xi32>,
+            !pto.vmi.mask<256xpred>, !pto.vmi.vreg<256xf32>
+          -> !pto.vmi.vreg<256xf32>
+        %scaled = pto.vmi.mulf %x, %scale
+          : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32>
+          -> !pto.vmi.vreg<256xf32>
+        %clamped_lo = pto.vmi.maxf %scaled, %lo  // maxf with -128 then minf with 127: clamp before the float -> int conversion
+          : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32>
+          -> !pto.vmi.vreg<256xf32>
+        %clamped = pto.vmi.minf %clamped_lo, %hi
+          : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32>
+          -> !pto.vmi.vreg<256xf32>
+        %i32 = pto.vmi.fptosi %clamped
+          : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xi32>
+        %zero_i32 = pto.vmi.broadcast %c0_i32
+          : i32 -> !pto.vmi.vreg<256xi32>
+        %byte_bias = pto.vmi.broadcast %c256_i32
+          : i32 -> !pto.vmi.vreg<256xi32>
+        %neg = pto.vmi.cmpi "slt", %i32, %zero_i32  // mask of lanes whose integer value is below zero
+          : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32>
+          -> !pto.vmi.mask<256xpred>
+        %wrapped = pto.vmi.addi %i32, %byte_bias
+          : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32>
+          -> !pto.vmi.vreg<256xi32>
+        %byte_i32 = pto.vmi.select %neg, %wrapped, %i32  // negative lanes take value + 256, the rest stay unchanged
+          : !pto.vmi.mask<256xpred>, !pto.vmi.vreg<256xi32>,
+            !pto.vmi.vreg<256xi32>
+          -> !pto.vmi.vreg<256xi32>
+        %u8 = pto.vmi.trunci %byte_i32
+          : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xui8>
+        pto.vmi.store %u8, %ub_dst[%elem_offset]  // stored at the same element offset the input was loaded from
+          : !pto.vmi.vreg<256xui8>, !pto.ptr
+      }
+    }
+
+    pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]  // presumably orders the vector stores before the copy-out below - confirm
+    pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
+    pto.mte_ub_gm %ub_dst, %dst_gm, %c8192_i64  // copy the 8192-unit result region back to dst_gm
+      nburst(%c1_i64, %c8192_i64, %c8192_i64)
+      : !pto.ptr, !pto.ptr, i64, i64, i64, i64
+    pto.barrier #pto.pipe
+    return
+  }
+}
diff --git a/test/vpto/cases/vmi/kernels/tquant-int8-sym-64x128/launch.cpp b/test/vpto/cases/vmi/kernels/tquant-int8-sym-64x128/launch.cpp
new file mode 100644
index 0000000000..24c20be676
--- /dev/null
+++ b/test/vpto/cases/vmi/kernels/tquant-int8-sym-64x128/launch.cpp
@@ -0,0 +1,41 @@
+// Copyright (c) 2026 Huawei Technologies Co., Ltd.
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+// CANN Open Software License Agreement Version 2.0 (the "License").
+// Please refer to the License for details. You may not use this file except in compliance with the License.
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+// See LICENSE in the root of the software repository for the full text of the License.
+
+#ifndef __VEC_SCOPE__  // define __VEC_SCOPE__ as empty if nothing has defined it yet
+#define __VEC_SCOPE__
+#endif
+#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201)  // only under __CCE_AICORE__ with __NPU_ARCH__ == 2201
+typedef struct { unsigned char v; } hifloat8_t;  // one-byte placeholder struct; the five typedefs below have the same shape
+typedef struct { unsigned char v; } float8_e4m3_t;
+typedef struct { unsigned char v; } float8_e5m2_t;
+typedef struct { unsigned char v; } float8_e8m0_t;
+typedef struct { unsigned char v; } float4_e1m2x2_t;
+typedef struct { unsigned char v; } float4_e2m1x2_t;
+#endif
+#include  // NOTE(review): header name is missing in this copy of the patch (angle-bracket text looks stripped) - confirm against the repository
+#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP)  // build without __CCE_AICORE__ where TMRGSORT_HPP is not defined
+struct MrgSortExecutedNumList {  // four uint16_t fields; presumably mirrors a definition guarded by TMRGSORT_HPP - confirm
+  uint16_t mrgSortList0;
+  uint16_t mrgSortList1;
+  uint16_t mrgSortList2;
+  uint16_t mrgSortList3;
+};
+#endif
+#ifndef __CPU_SIM  // the ACL header is included unless __CPU_SIM is defined
+#include "acl/acl.h"
+#endif
+
+extern "C" __global__ [aicore] void  // kernel entry point: declared here only, no definition in this file
+vmi_tquant_int8_sym_64x128_kernel(__gm__ float *src, __gm__ float *inv_scale,
+                                  __gm__ uint8_t *dst);
+
+void LaunchVmi_tquant_int8_sym_64x128_kernel(float *src, float *inv_scale,  // host wrapper called from main.cpp
+                                             uint8_t *dst, void *stream) {
+  vmi_tquant_int8_sym_64x128_kernel<<<1, nullptr, stream>>>(  // launch arguments are 1, nullptr and the caller's stream
+      (__gm__ float *)src, (__gm__ float *)inv_scale, (__gm__ uint8_t *)dst);  // each pointer is cast to its __gm__-qualified type
+}
diff --git a/test/vpto/cases/vmi/kernels/tquant-int8-sym-64x128/main.cpp b/test/vpto/cases/vmi/kernels/tquant-int8-sym-64x128/main.cpp
new file mode 100644
index 0000000000..3e4222a58e
--- /dev/null
+++ b/test/vpto/cases/vmi/kernels/tquant-int8-sym-64x128/main.cpp
@@ -0,0 +1,90 @@
+// Copyright (c) 2026 Huawei Technologies Co., Ltd.
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+// CANN Open Software License Agreement Version 2.0 (the "License").
+// Please refer to the License for details. You may not use this file except in compliance with the License.
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+// See LICENSE in the root of the software repository for the full text of the License.
+
+#include "acl/acl.h"
+#include "test_common.h"
+#include  // NOTE(review): the header names of this and the next two includes are missing in this copy of the patch - confirm against the repository
+#include
+#include
+
+using namespace PtoTestCommon;
+
+#define ACL_CHECK(expr) /* on failure: log expr and the error code, set rc = 1, jump to cleanup; needs `rc` and a `cleanup:` label in the caller */ \
+  do { \
+    const aclError _ret = (expr); \
+    if (_ret != ACL_SUCCESS) { \
+      std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \
+                   (int)_ret, __FILE__, __LINE__); \
+      rc = 1; \
+      goto cleanup; \
+    } \
+  } while (0)
+
+void LaunchVmi_tquant_int8_sym_64x128_kernel(float *src, float *inv_scale,  // declared here, defined in launch.cpp
+                                             uint8_t *dst, void *stream);
+
+int main() {  // returns 0 on success, 1 if any ACL_CHECK call failed
+  constexpr size_t kRows = 64;
+  constexpr size_t kCols = 128;
+  constexpr size_t kElems = kRows * kCols;
+  size_t srcBytes = kElems * sizeof(float);  // one float per element of the kRows x kCols input
+  size_t scaleBytes = kRows * sizeof(float);  // one float per row
+  size_t dstBytes = kElems * sizeof(uint8_t);  // one byte per input element
+  float *srcHost = nullptr;  // every pointer starts as nullptr and is declared before the first ACL_CHECK
+  float *scaleHost = nullptr;
+  uint8_t *dstHost = nullptr;
+  float *srcDevice = nullptr;
+  float *scaleDevice = nullptr;
+  uint8_t *dstDevice = nullptr;
+  int rc = 0;  // set to 1 by ACL_CHECK on failure
+  bool aclInited = false;  // aclInited/deviceSet record how far setup got, so cleanup only undoes what succeeded
+  bool deviceSet = false;
+  int deviceId = 0;
+  aclrtStream stream = nullptr;
+
+  ACL_CHECK(aclInit(nullptr));
+  aclInited = true;
+  if (const char *envDevice = std::getenv("ACL_DEVICE_ID"))  // optional device override from the environment
+    deviceId = std::atoi(envDevice);
+  ACL_CHECK(aclrtSetDevice(deviceId));
+  deviceSet = true;
+  ACL_CHECK(aclrtCreateStream(&stream));
+  ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes));
+  ACL_CHECK(aclrtMallocHost((void **)(&scaleHost), scaleBytes));
+  ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes));
+  ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST));
+  ACL_CHECK(aclrtMalloc((void **)&scaleDevice, scaleBytes, ACL_MEM_MALLOC_HUGE_FIRST));
+  ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST));
+
+  ReadFile("./v1.bin", srcBytes, srcHost, srcBytes);  // any result from the ReadFile calls is ignored here
+  ReadFile("./v2.bin", scaleBytes, scaleHost, scaleBytes);
+  ReadFile("./v4.bin", dstBytes, dstHost, dstBytes);  // the output buffer starts from the current contents of v4.bin
+  ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE));
+  ACL_CHECK(aclrtMemcpy(scaleDevice, scaleBytes, scaleHost, scaleBytes, ACL_MEMCPY_HOST_TO_DEVICE));
+  ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE));
+  LaunchVmi_tquant_int8_sym_64x128_kernel(srcDevice, scaleDevice, dstDevice,
+                                          stream);
+  ACL_CHECK(aclrtSynchronizeStream(stream));  // wait for the stream before copying the result back
+  ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST));
+  WriteFile("./v4.bin", dstHost, dstBytes);  // overwrite v4.bin with the device result; skipped when an ACL_CHECK jumped to cleanup
+
+cleanup:  // reached on success and from every failing ACL_CHECK
+  aclrtFree(srcDevice);  // the frees are unconditional; pointers that were never allocated are still nullptr here
+  aclrtFree(scaleDevice);
+  aclrtFree(dstDevice);
+  aclrtFreeHost(srcHost);
+  aclrtFreeHost(scaleHost);
+  aclrtFreeHost(dstHost);
+  if (stream)
+    aclrtDestroyStream(stream);
+  if (deviceSet)
+    aclrtResetDevice(deviceId);
+  if (aclInited)
+    aclFinalize();
+  return rc;
+}
diff --git a/test/vpto/cases/vmi/kernels/tquant-int8-sym-64x128/ptoas.flags b/test/vpto/cases/vmi/kernels/tquant-int8-sym-64x128/ptoas.flags
new file mode 100644
index 0000000000..a79aede1ca
--- /dev/null
+++ b/test/vpto/cases/vmi/kernels/tquant-int8-sym-64x128/ptoas.flags
@@ -0,0 +1 @@
+--pto-arch a5 --pto-backend=vpto --enable-vmi
diff --git a/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x32-nd/compare.py b/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x32-nd/compare.py
new file mode 100644
index 0000000000..bc45e33883
--- /dev/null
+++ b/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x32-nd/compare.py
@@ -0,0 +1,36 @@
+#!/usr/bin/env python3
+# Copyright (c) 2026 Huawei Technologies Co., Ltd.
+# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+# CANN Open Software License Agreement Version 2.0 (the "License").
+# Please refer to the License for details. You may not use this file except in compliance with the License.
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def check_u8(name: str) -> bool: + golden = np.fromfile(f"golden_{name}.bin", dtype=np.uint8) + output = np.fromfile(f"{name}.bin", dtype=np.uint8) + if golden.shape == output.shape and np.array_equal(golden, output): + return True + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed {name} idx={idx} " + f"golden=0x{int(golden[idx]):02x} output=0x{int(output[idx]):02x}" + ) + return False + + +def main() -> None: + if not check_u8("v2") or not check_u8("v3"): + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x32-nd/golden.py b/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x32-nd/golden.py new file mode 100644 index 0000000000..787c4a61d3 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x32-nd/golden.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+ +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 32 +COLS = 32 +E8M0_BYTES = 256 +SENTINEL_U8 = np.uint8(0xA5) + +Q_VALUES = np.array( + [0.0, 1.0, -1.0, 0.5, 2.0, -2.0, 4.0, -4.0, 448.0], dtype=np.float32 +) +F8E4M3FN_BYTES = np.array( + [0x00, 0x38, 0xB8, 0x30, 0x40, 0xC0, 0x48, 0xC8, 0x7E], dtype=np.uint8 +) + + +def generate(output_dir: Path) -> None: + repeats = (COLS + len(Q_VALUES) - 1) // len(Q_VALUES) + q_row = np.tile(Q_VALUES, repeats)[:COLS].astype(np.float32) + f8_row = np.tile(F8E4M3FN_BYTES, repeats)[:COLS].astype(np.uint8) + + src = np.tile(q_row / np.float32(256.0), (ROWS, 1)).astype(np.float32) + golden_fp8 = np.tile(f8_row, (ROWS, 1)).astype(np.uint8) + golden_e8m0 = np.full(E8M0_BYTES, SENTINEL_U8, dtype=np.uint8) + golden_e8m0[:ROWS] = np.uint8(0x77) + + out_fp8 = np.full((ROWS, COLS), SENTINEL_U8, dtype=np.uint8) + out_e8m0 = np.full(E8M0_BYTES, SENTINEL_U8, dtype=np.uint8) + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + out_fp8.tofile(output_dir / "v2.bin") + out_e8m0.tofile(output_dir / "v3.bin") + golden_fp8.tofile(output_dir / "golden_v2.bin") + golden_e8m0.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x32-nd/kernel.pto b/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x32-nd/kernel.pto new file mode 100644 index 0000000000..1b880ec786 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x32-nd/kernel.pto @@ -0,0 +1,112 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. 
You may not use this file except in compliance with the License.
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+// See LICENSE in the root of the software repository for the full text of the License.
+
+module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} {  // NOTE(review): `#pto.kernel_kind`, `!pto.ptr` and `#pto.pipe` appear without any <...> parameters in this copy of the patch - confirm against the repository
+  func.func @vmi_tquant_mxfp8_32x32_nd_kernel(%src_gm: !pto.ptr,  // three pointer arguments: source, fp8 output, e8m0 output
+                                              %out_fp8_gm: !pto.ptr,
+                                              %out_e8m0_gm: !pto.ptr) attributes {pto.kernel} {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c8 = arith.constant 8 : index
+    %c32 = arith.constant 32 : index
+    %c256 = arith.constant 256 : index
+    %c0_i64 = arith.constant 0 : i64
+    %c1_i64 = arith.constant 1 : i64
+    %c4_i64 = arith.constant 4 : i64
+    %c8_i64 = arith.constant 8 : i64
+    %c32_i64 = arith.constant 32 : i64
+    %c256_i64 = arith.constant 256 : i64  // not referenced anywhere else in this kernel
+    %c1024_i64 = arith.constant 1024 : i64
+    %c4096_i64 = arith.constant 4096 : i64
+    %c8192_i64 = arith.constant 8192 : i64
+    %c12288_i64 = arith.constant 12288 : i64
+    %exp_mask_i32 = arith.constant 2139095040 : i32  // 0x7F800000
+    %shift_i32 = arith.constant 23 : i32
+    %emax_i32 = arith.constant 8 : i32
+    %scale_exp_bias_i32 = arith.constant 254 : i32
+
+    %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr
+    %ub_out_fp8_u8 = pto.castptr %c8192_i64 : i64 -> !pto.ptr  // the _u8 and _f8 pointers are both built from the same integer 8192
+    %ub_out_fp8_f8 = pto.castptr %c8192_i64 : i64 -> !pto.ptr
+    %ub_out_e8m0 = pto.castptr %c12288_i64 : i64 -> !pto.ptr
+
+    pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c4096_i64  // copy src into %ub_src; 4096 = 32*32*4, presumably a byte count - confirm
+      nburst(%c1_i64, %c4096_i64, %c4096_i64)
+      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
+    pto.mte_gm_ub %out_fp8_gm, %ub_out_fp8_u8, %c0_i64, %c1024_i64  // the fp8 output region is also loaded from out_fp8_gm first (1024 = 32*32)
+      nburst(%c1_i64, %c1024_i64, %c1024_i64)
+      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
+    pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]  // presumably orders the copies above before the vector work below - confirm
+    pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]
+
+    pto.vecscope {
+      scf.for %row = %c0 to %c32 step %c8 {  // %row = 0, 8, 16, 24; each iteration loads 256 f32 = 8 rows of 32
+        %elem_off = arith.muli %row, %c32 : index
+        %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred>
+        %x = pto.vmi.load %ub_src[%elem_off]
+          : !pto.ptr -> !pto.vmi.vreg<256xf32>
+        %abs = pto.vmi.absf %x
+          : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32>
+        %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 8}  // 256 values -> 8 values; presumably one max per 32-element row - confirm
+          : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred>
+          -> !pto.vmi.vreg<8xf32>
+        %amax_bits = pto.vmi.bitcast %amax
+          : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<8xi32>
+        %exp_mask = pto.vmi.broadcast %exp_mask_i32
+          : i32 -> !pto.vmi.vreg<8xi32>
+        %shift = pto.vmi.broadcast %shift_i32
+          : i32 -> !pto.vmi.vreg<8xi32>
+        %emax = pto.vmi.broadcast %emax_i32
+          : i32 -> !pto.vmi.vreg<8xi32>
+        %scale_exp_bias = pto.vmi.broadcast %scale_exp_bias_i32
+          : i32 -> !pto.vmi.vreg<8xi32>
+        %exp_bits = pto.vmi.andi %amax_bits, %exp_mask  // (bits & 0x7F800000) >> 23 on this and the next op
+          : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32>
+          -> !pto.vmi.vreg<8xi32>
+        %exp = pto.vmi.shrui %exp_bits, %shift
+          : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32>
+          -> !pto.vmi.vreg<8xi32>
+        %e8m0_i32 = pto.vmi.subi %exp, %emax  // extracted exponent field minus 8
+          : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32>
+          -> !pto.vmi.vreg<8xi32>
+        %scale_slot = arith.divui %row, %c8 : index  // 0..3, one slot per loop iteration
+        %scale_ub_off = arith.muli %scale_slot, %c32 : index  // slots are spaced 32 elements apart in %ub_out_e8m0
+        pto.vmi.group_store %e8m0_i32, %ub_out_e8m0[%scale_ub_off], %c1 {num_groups = 8}
+          : !pto.vmi.vreg<8xi32>, !pto.ptr
+        %scale_exp = pto.vmi.subi %scale_exp_bias, %e8m0_i32  // 254 - e8m0
+          : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32>
+          -> !pto.vmi.vreg<8xi32>
+        %scale_bits = pto.vmi.shli %scale_exp, %shift  // shifted left by 23, then reinterpreted as f32 below
+          : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32>
+          -> !pto.vmi.vreg<8xi32>
+        %scaling = pto.vmi.bitcast %scale_bits
+          : !pto.vmi.vreg<8xi32> -> !pto.vmi.vreg<8xf32>
+        %scaling_vec = pto.vmi.group_broadcast %scaling {num_groups = 8}  // 8 scale values expanded back to 256 lanes
+          : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<256xf32>
+        %scaled = pto.vmi.mulf %x, %scaling_vec
+          : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32>
+          -> !pto.vmi.vreg<256xf32>
+        %q8 = pto.vmi.truncf %scaled
+          : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E4M3FN>
+        pto.vmi.store %q8, %ub_out_fp8_f8[%elem_off]  // stored at the same element offset the input was loaded from
+          : !pto.vmi.vreg<256xf8E4M3FN>, !pto.ptr
+      }
+    }
+
+    pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]  // presumably orders the vector stores before the copy-outs below - confirm
+    pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
+    pto.mte_ub_gm %ub_out_fp8_u8, %out_fp8_gm, %c1024_i64  // copy the 1024-unit fp8 region back to out_fp8_gm
+      nburst(%c1_i64, %c1024_i64, %c1024_i64)
+      : !pto.ptr, !pto.ptr, i64, i64, i64, i64
+    pto.mte_ub_gm %ub_out_e8m0, %out_e8m0_gm, %c8_i64  // e8m0 copy-out uses nburst(4, 32, 8); presumably 4 bursts of 8 gathered from the 32-spaced slots - confirm
+      nburst(%c4_i64, %c32_i64, %c8_i64)
+      : !pto.ptr, !pto.ptr, i64, i64, i64, i64
+    pto.barrier #pto.pipe
+    return
+  }
+}
diff --git a/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x32-nd/launch.cpp b/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x32-nd/launch.cpp
new file mode 100644
index 0000000000..1af7bbdc45
--- /dev/null
+++ b/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x32-nd/launch.cpp
@@ -0,0 +1,42 @@
+// Copyright (c) 2026 Huawei Technologies Co., Ltd.
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+// CANN Open Software License Agreement Version 2.0 (the "License").
+// Please refer to the License for details. You may not use this file except in compliance with the License.
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+// See LICENSE in the root of the software repository for the full text of the License.
+
+#ifndef __VEC_SCOPE__  // define __VEC_SCOPE__ as empty if nothing has defined it yet
+#define __VEC_SCOPE__
+#endif
+#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201)  // only under __CCE_AICORE__ with __NPU_ARCH__ == 2201
+typedef struct { unsigned char v; } hifloat8_t;  // one-byte placeholder struct; the five typedefs below have the same shape
+typedef struct { unsigned char v; } float8_e4m3_t;
+typedef struct { unsigned char v; } float8_e5m2_t;
+typedef struct { unsigned char v; } float8_e8m0_t;
+typedef struct { unsigned char v; } float4_e1m2x2_t;
+typedef struct { unsigned char v; } float4_e2m1x2_t;
+#endif
+#include  // NOTE(review): header name is missing in this copy of the patch (angle-bracket text looks stripped) - confirm against the repository
+#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP)  // build without __CCE_AICORE__ where TMRGSORT_HPP is not defined
+struct MrgSortExecutedNumList {  // four uint16_t fields; presumably mirrors a definition guarded by TMRGSORT_HPP - confirm
+  uint16_t mrgSortList0;
+  uint16_t mrgSortList1;
+  uint16_t mrgSortList2;
+  uint16_t mrgSortList3;
+};
+#endif
+#ifndef __CPU_SIM  // the ACL header is included unless __CPU_SIM is defined
+#include "acl/acl.h"
+#endif
+
+extern "C" __global__ [aicore] void  // kernel entry point: declared here only, no definition in this file
+vmi_tquant_mxfp8_32x32_nd_kernel(__gm__ float *src, __gm__ uint8_t *out_fp8,
+                                 __gm__ uint8_t *out_e8m0);
+
+void LaunchVmi_tquant_mxfp8_32x32_nd_kernel(float *src, uint8_t *out_fp8,  // host wrapper called from main.cpp
+                                            uint8_t *out_e8m0, void *stream) {
+  vmi_tquant_mxfp8_32x32_nd_kernel<<<1, nullptr, stream>>>(  // launch arguments are 1, nullptr and the caller's stream
+      (__gm__ float *)src, (__gm__ uint8_t *)out_fp8,  // each pointer is cast to its __gm__-qualified type
+      (__gm__ uint8_t *)out_e8m0);
+}
diff --git a/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x32-nd/main.cpp b/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x32-nd/main.cpp
new file mode 100644
index 0000000000..827877248e
--- /dev/null
+++ b/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x32-nd/main.cpp
@@ -0,0 +1,93 @@
+// Copyright (c) 2026 Huawei Technologies Co., Ltd.
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+// CANN Open Software License Agreement Version 2.0 (the "License").
+// Please refer to the License for details. You may not use this file except in compliance with the License.
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_tquant_mxfp8_32x32_nd_kernel(float *src, uint8_t *out_fp8, + uint8_t *out_e8m0, void *stream); + +int main() { + constexpr size_t kRows = 32; + constexpr size_t kCols = 32; + constexpr size_t kElems = kRows * kCols; + constexpr size_t kE8m0Bytes = 256; + size_t srcBytes = kElems * sizeof(float); + size_t outFp8Bytes = kElems * sizeof(uint8_t); + size_t outE8m0Bytes = kE8m0Bytes * sizeof(uint8_t); + float *srcHost = nullptr; + uint8_t *outFp8Host = nullptr; + uint8_t *outE8m0Host = nullptr; + float *srcDevice = nullptr; + uint8_t *outFp8Device = nullptr; + uint8_t *outE8m0Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&outFp8Host), outFp8Bytes)); + ACL_CHECK(aclrtMallocHost((void **)(&outE8m0Host), outE8m0Bytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outFp8Device, outFp8Bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outE8m0Device, outE8m0Bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", outFp8Bytes, 
outFp8Host, outFp8Bytes); + ReadFile("./v3.bin", outE8m0Bytes, outE8m0Host, outE8m0Bytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outFp8Device, outFp8Bytes, outFp8Host, outFp8Bytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outE8m0Device, outE8m0Bytes, outE8m0Host, outE8m0Bytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_tquant_mxfp8_32x32_nd_kernel(srcDevice, outFp8Device, + outE8m0Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(outFp8Host, outFp8Bytes, outFp8Device, outFp8Bytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outE8m0Host, outE8m0Bytes, outE8m0Device, outE8m0Bytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", outFp8Host, outFp8Bytes); + WriteFile("./v3.bin", outE8m0Host, outE8m0Bytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(outFp8Device); + aclrtFree(outE8m0Device); + aclrtFreeHost(srcHost); + aclrtFreeHost(outFp8Host); + aclrtFreeHost(outE8m0Host); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x32-nd/ptoas.flags b/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x32-nd/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x32-nd/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x64-nz/compare.py b/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x64-nz/compare.py new file mode 100644 index 0000000000..cffec13f08 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x64-nz/compare.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def check_u8(name: str) -> bool: + golden = np.fromfile(f"golden_{name}.bin", dtype=np.uint8) + output = np.fromfile(f"{name}.bin", dtype=np.uint8) + if golden.shape == output.shape and np.array_equal(golden, output): + return True + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed {name} idx={idx} " + f"golden=0x{int(golden[idx]):02x} output=0x{int(output[idx]):02x}" + ) + return False + + +def main() -> None: + if not check_u8("v3") or not check_u8("v4"): + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x64-nz/golden.py b/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x64-nz/golden.py new file mode 100644 index 0000000000..e3daa6ae34 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x64-nz/golden.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. 
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 32 +COLS = 64 +GROUPS = ROWS * COLS // 32 +E8M0_BYTES = 256 +IDX_BYTES = 256 +SENTINEL_U8 = np.uint8(0xA5) + +Q_VALUES = np.array( + [0.0, 1.0, -1.0, 0.5, 2.0, -2.0, 4.0, -4.0, 448.0], dtype=np.float32 +) +F8E4M3FN_BYTES = np.array( + [0x00, 0x38, 0xB8, 0x30, 0x40, 0xC0, 0x48, 0xC8, 0x7E], dtype=np.uint8 +) + + +def make_e8m0_zz_indices() -> np.ndarray: + index_array = np.arange(GROUPS, dtype=np.int64).reshape(ROWS, COLS // 32) + index_reshaped = index_array.reshape(ROWS // 16, 16, (COLS // 32) // 2, 2) + index_zz = np.transpose(index_reshaped, [0, 2, 1, 3]).flatten() + return (index_zz // 2)[::2].astype(np.uint16) + + +def generate(output_dir: Path) -> None: + repeats = (COLS + len(Q_VALUES) - 1) // len(Q_VALUES) + q_row = np.tile(Q_VALUES, repeats)[:COLS].astype(np.float32) + f8_row = np.tile(F8E4M3FN_BYTES, repeats)[:COLS].astype(np.uint8) + + src = np.tile(q_row / np.float32(256.0), (ROWS, 1)).astype(np.float32) + fp8_nd = np.tile(f8_row, (ROWS, 1)).astype(np.uint8) + golden_fp8 = np.transpose(fp8_nd.reshape(ROWS, COLS // 32, 32), [1, 0, 2]).flatten() + + e8m0_nd = np.full((ROWS, COLS // 32), np.uint8(0x77), dtype=np.uint8) + e8m0_zz = np.transpose(e8m0_nd.reshape(ROWS // 16, 16, (COLS // 32) // 2, 2), [0, 2, 1, 3]).flatten() + golden_e8m0 = np.full(E8M0_BYTES, SENTINEL_U8, dtype=np.uint8) + golden_e8m0[:GROUPS] = e8m0_zz + + idx = np.zeros(IDX_BYTES // np.dtype(np.uint16).itemsize, dtype=np.uint16) + zz_indices = make_e8m0_zz_indices() + idx[: zz_indices.size] = zz_indices + + out_fp8 = np.full(ROWS * COLS, SENTINEL_U8, dtype=np.uint8) + out_e8m0 = np.full(E8M0_BYTES, SENTINEL_U8, 
dtype=np.uint8) + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + idx.tofile(output_dir / "v2.bin") + out_fp8.tofile(output_dir / "v3.bin") + out_e8m0.tofile(output_dir / "v4.bin") + golden_fp8.tofile(output_dir / "golden_v3.bin") + golden_e8m0.tofile(output_dir / "golden_v4.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x64-nz/kernel.pto b/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x64-nz/kernel.pto new file mode 100644 index 0000000000..fe9e8ba375 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x64-nz/kernel.pto @@ -0,0 +1,172 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_tquant_mxfp8_32x64_nz_kernel(%src_gm: !pto.ptr, + %idx_gm: !pto.ptr, + %out_fp8_gm: !pto.ptr, + %out_e8m0_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c4 = arith.constant 4 : index + %c8 = arith.constant 8 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c1056 = arith.constant 1056 : index + %c1024 = arith.constant 1024 : index + %c0_i16 = arith.constant 0 : i16 + %c1_i16 = arith.constant 1 : i16 + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c2_i64 = arith.constant 2 : i64 + %c8_i64 = arith.constant 8 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c256_i64 = arith.constant 256 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c131328_i64 = arith.constant 131328 : i64 + %c131584_i64 = arith.constant 131584 : i64 + %c196864_i64 = arith.constant 196864 : i64 + %c197120_i64 = arith.constant 197120 : i64 + %exp_mask_i32 = arith.constant 2139095040 : i32 + %shift_i32 = arith.constant 23 : i32 + %emax_i32 = arith.constant 8 : i32 + %scale_exp_bias_i32 = arith.constant 254 : i32 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_fp8_nd_f8 = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_max = pto.castptr %c131328_i64 : i64 -> !pto.ptr + %ub_e8m0_nd_u8 = pto.castptr %c131584_i64 : i64 -> !pto.ptr + %ub_e8m0_nd_u16 = pto.castptr %c131584_i64 : i64 -> !pto.ptr + %ub_e8m0_zz_u8 = pto.castptr %c196864_i64 : i64 -> !pto.ptr + %ub_e8m0_zz_u16 = pto.castptr %c196864_i64 : i64 -> !pto.ptr + %ub_idx_u16 = pto.castptr %c197120_i64 : i64 -> !pto.ptr + %out_fp8_hi_gm = pto.addptr %out_fp8_gm, %c1024 + : !pto.ptr -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c8192_i64 
+ nburst(%c1_i64, %c8192_i64, %c8192_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %idx_gm, %ub_idx_u16, %c0_i64, %c64_i64 + nburst(%c1_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask_f32 = pto.vmi.create_mask %c256 + : index -> !pto.vmi.mask<256xpred> + scf.for %row = %c0 to %c32 step %c4 { + %src_off = arith.muli %row, %c64 : index + %scale_off = arith.muli %row, %c2 : index + %nd_off = arith.muli %row, %c64 : index + + %x = pto.vmi.load %ub_src[%src_off] + : !pto.ptr -> !pto.vmi.vreg<256xf32> + %abs = pto.vmi.absf %x + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %amax = pto.vmi.group_reduce_maxf %abs, %mask_f32 {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<8xf32> + %amax_bits = pto.vmi.bitcast %amax + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<8xi32> + %exp_mask = pto.vmi.broadcast %exp_mask_i32 + : i32 -> !pto.vmi.vreg<8xi32> + %shift = pto.vmi.broadcast %shift_i32 + : i32 -> !pto.vmi.vreg<8xi32> + %emax = pto.vmi.broadcast %emax_i32 + : i32 -> !pto.vmi.vreg<8xi32> + %scale_exp_bias = pto.vmi.broadcast %scale_exp_bias_i32 + : i32 -> !pto.vmi.vreg<8xi32> + %exp_bits = pto.vmi.andi %amax_bits, %exp_mask + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> + %exp = pto.vmi.shrui %exp_bits, %shift + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> + %e8m0_i32 = pto.vmi.subi %exp, %emax + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> + pto.vmi.group_store %amax, %ub_max[%scale_off], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xf32>, !pto.ptr + %scale_exp = pto.vmi.subi %scale_exp_bias, %e8m0_i32 + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> + %scale_bits = pto.vmi.shli %scale_exp, %shift + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> 
!pto.vmi.vreg<8xi32> + %scaling = pto.vmi.bitcast %scale_bits + : !pto.vmi.vreg<8xi32> -> !pto.vmi.vreg<8xf32> + %scaling_vec = pto.vmi.group_broadcast %scaling {num_groups = 8} + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<256xf32> + %scaled = pto.vmi.mulf %x, %scaling_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %q8 = pto.vmi.truncf %scaled + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E4M3FN> + pto.vmi.store %q8, %ub_fp8_nd_f8[%nd_off] + : !pto.vmi.vreg<256xf8E4M3FN>, !pto.ptr + } + + %max256 = pto.vmi.load %ub_max[%c0] + : !pto.ptr -> !pto.vmi.vreg<256xf32> + %max256_bits = pto.vmi.bitcast %max256 + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xi32> + %exp_mask256 = pto.vmi.broadcast %exp_mask_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %shift256 = pto.vmi.broadcast %shift_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %emax256 = pto.vmi.broadcast %emax_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %exp256_bits = pto.vmi.andi %max256_bits, %exp_mask256 + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %exp256 = pto.vmi.shrui %exp256_bits, %shift256 + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %e8m0_256_i32 = pto.vmi.subi %exp256, %emax256 + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %e8m0_256_u8 = pto.vmi.trunci %e8m0_256_i32 + : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xui8> + pto.vmi.store %e8m0_256_u8, %ub_e8m0_nd_u8[%c0] + : !pto.vmi.vreg<256xui8>, !pto.ptr + + %idx_mask = pto.vmi.create_mask %c32 + : index -> !pto.vmi.mask<128xpred> + %idx_vec = pto.vmi.load %ub_idx_u16[%c0] + : !pto.ptr -> !pto.vmi.vreg<128xui16> + %e8m0_zz = pto.vmi.gather %ub_e8m0_nd_u16[%idx_vec], %idx_mask, %idx_vec + : !pto.ptr, !pto.vmi.vreg<128xui16>, + !pto.vmi.mask<128xpred>, !pto.vmi.vreg<128xui16> + -> !pto.vmi.vreg<128xui16> + pto.vmi.store %e8m0_zz, %ub_e8m0_zz_u16[%c0] + : !pto.vmi.vreg<128xui16>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", 
"EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_fp8_nd_f8, %out_fp8_gm, %c32_i64 + nburst(%c32_i64, %c64_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + %ub_fp8_nd_hi_f8 = pto.addptr %ub_fp8_nd_f8, %c32 + : !pto.ptr -> !pto.ptr + pto.mte_ub_gm %ub_fp8_nd_hi_f8, %out_fp8_hi_gm, %c32_i64 + nburst(%c32_i64, %c64_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_e8m0_zz_u8, %out_e8m0_gm, %c64_i64 + nburst(%c1_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x64-nz/launch.cpp b/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x64-nz/launch.cpp new file mode 100644 index 0000000000..4959bd44b0 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x64-nz/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_tquant_mxfp8_32x64_nz_kernel(__gm__ float *src, __gm__ uint16_t *idx, + __gm__ uint8_t *out_fp8, + __gm__ uint8_t *out_e8m0); + +void LaunchVmi_tquant_mxfp8_32x64_nz_kernel(float *src, uint16_t *idx, + uint8_t *out_fp8, + uint8_t *out_e8m0, void *stream) { + vmi_tquant_mxfp8_32x64_nz_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ uint16_t *)idx, + (__gm__ uint8_t *)out_fp8, (__gm__ uint8_t *)out_e8m0); +} diff --git a/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x64-nz/main.cpp b/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x64-nz/main.cpp new file mode 100644 index 0000000000..9440162a71 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x64-nz/main.cpp @@ -0,0 +1,116 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. 
THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_tquant_mxfp8_32x64_nz_kernel(float *src, uint16_t *idx, + uint8_t *out_fp8, uint8_t *out_e8m0, + void *stream); + +int main() { + constexpr size_t kRows = 32; + constexpr size_t kCols = 64; + constexpr size_t kElems = kRows * kCols; + constexpr size_t kE8m0Bytes = 256; + constexpr size_t kIdxBytes = 256; + size_t srcBytes = kElems * sizeof(float); + size_t idxBytes = kIdxBytes; + size_t outFp8Bytes = kElems * sizeof(uint8_t); + size_t outE8m0Bytes = kE8m0Bytes * sizeof(uint8_t); + float *srcHost = nullptr; + uint16_t *idxHost = nullptr; + uint8_t *outFp8Host = nullptr; + uint8_t *outE8m0Host = nullptr; + float *srcDevice = nullptr; + uint16_t *idxDevice = nullptr; + uint8_t *outFp8Device = nullptr; + uint8_t *outE8m0Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&idxHost), idxBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&outFp8Host), outFp8Bytes)); + 
ACL_CHECK(aclrtMallocHost((void **)(&outE8m0Host), outE8m0Bytes)); + ACL_CHECK( + aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK( + aclrtMalloc((void **)&idxDevice, idxBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outFp8Device, outFp8Bytes, + ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outE8m0Device, outE8m0Bytes, + ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", idxBytes, idxHost, idxBytes); + ReadFile("./v3.bin", outFp8Bytes, outFp8Host, outFp8Bytes); + ReadFile("./v4.bin", outE8m0Bytes, outE8m0Host, outE8m0Bytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(idxDevice, idxBytes, idxHost, idxBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outFp8Device, outFp8Bytes, outFp8Host, outFp8Bytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outE8m0Device, outE8m0Bytes, outE8m0Host, outE8m0Bytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_tquant_mxfp8_32x64_nz_kernel(srcDevice, idxDevice, outFp8Device, + outE8m0Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(outFp8Host, outFp8Bytes, outFp8Device, outFp8Bytes, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outE8m0Host, outE8m0Bytes, outE8m0Device, outE8m0Bytes, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", outFp8Host, outFp8Bytes); + WriteFile("./v4.bin", outE8m0Host, outE8m0Bytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(idxDevice); + aclrtFree(outFp8Device); + aclrtFree(outE8m0Device); + aclrtFreeHost(srcHost); + aclrtFreeHost(idxHost); + aclrtFreeHost(outFp8Host); + aclrtFreeHost(outE8m0Host); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x64-nz/ptoas.flags 
b/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x64-nz/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x64-nz/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/mask-granularity-f32-f16-store/compare.py b/test/vpto/cases/vmi/mask-granularity-f32-f16-store/compare.py new file mode 100644 index 0000000000..24d554e100 --- /dev/null +++ b/test/vpto/cases/vmi/mask-granularity-f32-f16-store/compare.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+ +import sys + +import numpy as np + + +def check_f32() -> bool: + golden = np.fromfile("golden_v2.bin", dtype=np.float32) + output = np.fromfile("v2.bin", dtype=np.float32) + if golden.shape == output.shape and np.allclose(golden, output, atol=1e-5, rtol=1e-5): + return True + diff = np.nonzero(~np.isclose(golden, output, atol=1e-5, rtol=1e-5))[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed v2 idx={idx} " + f"golden={golden[idx] if idx >= 0 else 'n/a'} " + f"output={output[idx] if idx >= 0 else 'n/a'}" + ) + return False + + +def check_f16() -> bool: + golden = np.fromfile("golden_v3.bin", dtype=np.float16) + output = np.fromfile("v3.bin", dtype=np.float16) + if golden.shape == output.shape and np.array_equal(golden, output): + return True + diff = np.nonzero(golden.view(np.uint16) != output.view(np.uint16))[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed v3 idx={idx} " + f"golden={golden[idx] if idx >= 0 else 'n/a'} " + f"output={output[idx] if idx >= 0 else 'n/a'}" + ) + return False + + +def main() -> None: + if not check_f32() or not check_f16(): + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/mask-granularity-f32-f16-store/golden.py b/test/vpto/cases/vmi/mask-granularity-f32-f16-store/golden.py new file mode 100644 index 0000000000..6a28077ea8 --- /dev/null +++ b/test/vpto/cases/vmi/mask-granularity-f32-f16-store/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. 
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ELEMS = 128 +ACTIVE = 96 +SEED = 29 +SENTINEL32 = np.float32(-901.25) +SENTINEL16 = np.float16(-17.5) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + src = rng.uniform(-8.0, 8.0, size=ELEMS).astype(np.float32) + out32 = np.full(ELEMS, SENTINEL32, dtype=np.float32) + out16 = np.full(ELEMS, SENTINEL16, dtype=np.float16) + golden32 = out32.copy() + golden16 = out16.copy() + golden32[:ACTIVE] = src[:ACTIVE] + golden16[:ACTIVE] = src[:ACTIVE].astype(np.float16) + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + out32.tofile(output_dir / "v2.bin") + out16.tofile(output_dir / "v3.bin") + golden32.tofile(output_dir / "golden_v2.bin") + golden16.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/mask-granularity-f32-f16-store/kernel.pto b/test/vpto/cases/vmi/mask-granularity-f32-f16-store/kernel.pto new file mode 100644 index 0000000000..f9362793ec --- /dev/null +++ b/test/vpto/cases/vmi/mask-granularity-f32-f16-store/kernel.pto @@ -0,0 +1,60 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. 
You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_mask_granularity_f32_f16_store_kernel(%src_gm: !pto.ptr, + %out32_gm: !pto.ptr, + %out16_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c96 = arith.constant 96 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out32 = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out16 = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %out32_gm, %ub_out32, %c0_i64, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %out16_gm, %ub_out16, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %x = pto.vmi.load %ub_src[%c0] : !pto.ptr -> !pto.vmi.vreg<128xf32> + %mask = pto.vmi.create_mask %c96 : index -> !pto.vmi.mask<128xpred> + pto.vmi.masked_store %x, %ub_out32[%c0], %mask + : !pto.vmi.vreg<128xf32>, !pto.ptr, !pto.vmi.mask<128xpred> + %h = pto.vmi.truncf %x : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> + pto.vmi.masked_store %h, %ub_out16[%c0], %mask + : 
!pto.vmi.vreg<128xf16>, !pto.ptr, !pto.vmi.mask<128xpred> + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out32, %out32_gm, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_out16, %out16_gm, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/mask-granularity-f32-f16-store/launch.cpp b/test/vpto/cases/vmi/mask-granularity-f32-f16-store/launch.cpp new file mode 100644 index 0000000000..de0c069797 --- /dev/null +++ b/test/vpto/cases/vmi/mask-granularity-f32-f16-store/launch.cpp @@ -0,0 +1,43 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_mask_granularity_f32_f16_store_kernel(__gm__ float *src, + __gm__ float *out32, + __gm__ half *out16); + +void LaunchVmi_mask_granularity_f32_f16_store_kernel(float *src, float *out32, + uint16_t *out16, + void *stream) { + vmi_mask_granularity_f32_f16_store_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ float *)out32, (__gm__ half *)out16); +} diff --git a/test/vpto/cases/vmi/mask-granularity-f32-f16-store/main.cpp b/test/vpto/cases/vmi/mask-granularity-f32-f16-store/main.cpp new file mode 100644 index 0000000000..2a65d8c46d --- /dev/null +++ b/test/vpto/cases/vmi/mask-granularity-f32-f16-store/main.cpp @@ -0,0 +1,91 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include <cstdint> +#include <cstdio> +#include <cstdlib> + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_mask_granularity_f32_f16_store_kernel(float *src, float *out32, + uint16_t *out16, + void *stream); + +int main() { + constexpr size_t kElems = 128; + size_t srcBytes = kElems * sizeof(float); + size_t out32Bytes = kElems * sizeof(float); + size_t out16Bytes = kElems * sizeof(uint16_t); + float *srcHost = nullptr; + float *out32Host = nullptr; + uint16_t *out16Host = nullptr; + float *srcDevice = nullptr; + float *out32Device = nullptr; + uint16_t *out16Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&out32Host), out32Bytes)); + ACL_CHECK(aclrtMallocHost((void **)(&out16Host), out16Bytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&out32Device, out32Bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&out16Device, out16Bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", out32Bytes, out32Host, out32Bytes); + ReadFile("./v3.bin", out16Bytes, out16Host, out16Bytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, 
srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(out32Device, out32Bytes, out32Host, out32Bytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(out16Device, out16Bytes, out16Host, out16Bytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_mask_granularity_f32_f16_store_kernel(srcDevice, out32Device, + out16Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(out32Host, out32Bytes, out32Device, out32Bytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(out16Host, out16Bytes, out16Device, out16Bytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", out32Host, out32Bytes); + WriteFile("./v3.bin", out16Host, out16Bytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(out32Device); + aclrtFree(out16Device); + aclrtFreeHost(srcHost); + aclrtFreeHost(out32Host); + aclrtFreeHost(out16Host); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/mask-granularity-f32-f16-store/ptoas.flags b/test/vpto/cases/vmi/mask-granularity-f32-f16-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/mask-granularity-f32-f16-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/mask-select-store/compare.py b/test/vpto/cases/vmi/mask-select-store/compare.py new file mode 100644 index 0000000000..b9e3290e76 --- /dev/null +++ b/test/vpto/cases/vmi/mask-select-store/compare.py @@ -0,0 +1,32 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. 
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + for name in ("v3", "v4"): + golden = np.fromfile(f"golden_{name}.bin", dtype=np.float32) + output = np.fromfile(f"{name}.bin", dtype=np.float32) + if golden.shape != output.shape or not np.allclose(golden, output, atol=1e-5, rtol=1e-5): + diff = np.nonzero(~np.isclose(golden, output, atol=1e-5, rtol=1e-5))[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed {name} idx={idx} " + f"golden={golden[idx] if idx >= 0 else 'n/a'} " + f"output={output[idx] if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/mask-select-store/golden.py b/test/vpto/cases/vmi/mask-select-store/golden.py new file mode 100644 index 0000000000..19ce1ebe2c --- /dev/null +++ b/test/vpto/cases/vmi/mask-select-store/golden.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+ +import argparse +from pathlib import Path + +import numpy as np + +ELEMS = 64 +ACTIVE = 48 +SEED = 29 +SENTINEL = np.float32(-901.25) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + src = rng.uniform(-8.0, 8.0, size=ELEMS).astype(np.float32) + rhs = rng.uniform(-4.0, 4.0, size=ELEMS).astype(np.float32) + dense = np.full(ELEMS, SENTINEL, dtype=np.float32) + masked = np.full(ELEMS, SENTINEL, dtype=np.float32) + summed = (src + rhs).astype(np.float32) + golden_dense = src.copy() + golden_dense[:ACTIVE] = summed[:ACTIVE] + golden_masked = masked.copy() + golden_masked[:ACTIVE] = summed[:ACTIVE] + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + rhs.tofile(output_dir / "v2.bin") + dense.tofile(output_dir / "v3.bin") + masked.tofile(output_dir / "v4.bin") + golden_dense.tofile(output_dir / "golden_v3.bin") + golden_masked.tofile(output_dir / "golden_v4.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/mask-select-store/kernel.pto b/test/vpto/cases/vmi/mask-select-store/kernel.pto new file mode 100644 index 0000000000..51538fd4e0 --- /dev/null +++ b/test/vpto/cases/vmi/mask-select-store/kernel.pto @@ -0,0 +1,71 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_mask_select_store_kernel(%src_gm: !pto.ptr, + %rhs_gm: !pto.ptr, + %dense_gm: !pto.ptr, + %masked_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c48 = arith.constant 48 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c256_i64 = arith.constant 256 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c12288_i64 = arith.constant 12288 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_dense = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_masked = pto.castptr %c12288_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %rhs_gm, %ub_rhs, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %dense_gm, %ub_dense, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %masked_gm, %ub_masked, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %x = pto.vmi.load %ub_src[%c0] : !pto.ptr -> !pto.vmi.vreg<64xf32> + %rhs = pto.vmi.load %ub_rhs[%c0] : !pto.ptr -> !pto.vmi.vreg<64xf32> + %mask = pto.vmi.create_mask %c48 : index -> !pto.vmi.mask<64xpred> + %sum = pto.vmi.addf %x, 
%rhs + : !pto.vmi.vreg<64xf32>, !pto.vmi.vreg<64xf32> + -> !pto.vmi.vreg<64xf32> + %passthrough = pto.vmi.select %mask, %sum, %x + : !pto.vmi.mask<64xpred>, !pto.vmi.vreg<64xf32>, + !pto.vmi.vreg<64xf32> -> !pto.vmi.vreg<64xf32> + pto.vmi.store %passthrough, %ub_dense[%c0] + : !pto.vmi.vreg<64xf32>, !pto.ptr + pto.vmi.masked_store %sum, %ub_masked[%c0], %mask + : !pto.vmi.vreg<64xf32>, !pto.ptr, !pto.vmi.mask<64xpred> + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dense, %dense_gm, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_masked, %masked_gm, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/mask-select-store/launch.cpp b/test/vpto/cases/vmi/mask-select-store/launch.cpp new file mode 100644 index 0000000000..d75d0da804 --- /dev/null +++ b/test/vpto/cases/vmi/mask-select-store/launch.cpp @@ -0,0 +1,42 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_mask_select_store_kernel(__gm__ float *src, __gm__ float *rhs, + __gm__ float *dense, __gm__ float *masked); + +void LaunchVmi_mask_select_store_kernel(float *src, float *rhs, float *dense, + float *masked, void *stream) { + vmi_mask_select_store_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ float *)rhs, (__gm__ float *)dense, + (__gm__ float *)masked); +} diff --git a/test/vpto/cases/vmi/mask-select-store/main.cpp b/test/vpto/cases/vmi/mask-select-store/main.cpp new file mode 100644 index 0000000000..07648040d0 --- /dev/null +++ b/test/vpto/cases/vmi/mask-select-store/main.cpp @@ -0,0 +1,99 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_mask_select_store_kernel(float *src, float *rhs, float *dense, + float *masked, void *stream); + +int main() { + constexpr size_t kElems = 64; + size_t srcBytes = kElems * sizeof(float); + size_t rhsBytes = kElems * sizeof(float); + size_t denseBytes = kElems * sizeof(float); + size_t maskedBytes = kElems * sizeof(float); + float *srcHost = nullptr; + float *rhsHost = nullptr; + float *denseHost = nullptr; + float *maskedHost = nullptr; + float *srcDevice = nullptr; + float *rhsDevice = nullptr; + float *denseDevice = nullptr; + float *maskedDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&rhsHost), rhsBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&denseHost), denseBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&maskedHost), maskedBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&rhsDevice, rhsBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&denseDevice, denseBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&maskedDevice, maskedBytes, 
ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", rhsBytes, rhsHost, rhsBytes); + ReadFile("./v3.bin", denseBytes, denseHost, denseBytes); + ReadFile("./v4.bin", maskedBytes, maskedHost, maskedBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(rhsDevice, rhsBytes, rhsHost, rhsBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(denseDevice, denseBytes, denseHost, denseBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(maskedDevice, maskedBytes, maskedHost, maskedBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_mask_select_store_kernel(srcDevice, rhsDevice, denseDevice, + maskedDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(denseHost, denseBytes, denseDevice, denseBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(maskedHost, maskedBytes, maskedDevice, maskedBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", denseHost, denseBytes); + WriteFile("./v4.bin", maskedHost, maskedBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(rhsDevice); + aclrtFree(denseDevice); + aclrtFree(maskedDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(rhsHost); + aclrtFreeHost(denseHost); + aclrtFreeHost(maskedHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/mask-select-store/ptoas.flags b/test/vpto/cases/vmi/mask-select-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/mask-select-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/masked-load-dense-group-users/compare.py b/test/vpto/cases/vmi/masked-load-dense-group-users/compare.py new file mode 100644 index 0000000000..9f34394fa1 --- /dev/null +++ 
b/test/vpto/cases/vmi/masked-load-dense-group-users/compare.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def check(name: str, golden_name: str) -> None: + golden = np.fromfile(golden_name, dtype=np.float32) + output = np.fromfile(name, dtype=np.float32) + if golden.shape == output.shape and np.allclose(golden, output, atol=1e-4, rtol=1e-4): + return + if golden.shape != output.shape: + print(f"[ERROR] compare failed {name}: shape golden={golden.shape} output={output.shape}") + sys.exit(2) + diff = np.nonzero(~np.isclose(golden, output, atol=1e-4, rtol=1e-4))[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed {name} idx={idx} " + f"golden={golden[idx] if idx >= 0 else 'n/a'} " + f"output={output[idx] if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + +def main() -> None: + check("v2.bin", "golden_v2.bin") + check("v3.bin", "golden_v3.bin") + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/masked-load-dense-group-users/golden.py b/test/vpto/cases/vmi/masked-load-dense-group-users/golden.py new file mode 100644 index 0000000000..41f1b1b714 --- /dev/null +++ b/test/vpto/cases/vmi/masked-load-dense-group-users/golden.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 8 +COLS = 32 +SENTINEL = np.float32(-777.0) + + +def generate(output_dir: Path) -> None: + base = np.linspace(-0.875, 0.625, COLS, dtype=np.float32) + src = np.empty((ROWS, COLS), dtype=np.float32) + for row in range(ROWS): + src[row, :] = base + np.float32(row) * np.float32(0.03125) + copy = np.full((ROWS, COLS), SENTINEL, dtype=np.float32) + sums = np.full(ROWS, SENTINEL, dtype=np.float32) + golden_copy = src.copy() + golden_sum = np.sum(src, axis=1, dtype=np.float32).astype(np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + src.reshape(-1).tofile(output_dir / "v1.bin") + copy.reshape(-1).tofile(output_dir / "v2.bin") + sums.tofile(output_dir / "v3.bin") + golden_copy.reshape(-1).astype(np.float32).tofile(output_dir / "golden_v2.bin") + golden_sum.astype(np.float32).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/masked-load-dense-group-users/kernel.pto b/test/vpto/cases/vmi/masked-load-dense-group-users/kernel.pto new file mode 100644 index 0000000000..e491e30698 --- /dev/null +++ b/test/vpto/cases/vmi/masked-load-dense-group-users/kernel.pto @@ -0,0 
+1,61 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_masked_load_dense_group_users_kernel(%src_gm: !pto.ptr, + %copy_gm: !pto.ptr, + %sum_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %zero = arith.constant 0.000000e+00 : f32 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_copy = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_sum = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %zero_vec = pto.vmi.broadcast %zero : f32 -> !pto.vmi.vreg<256xf32> + %x = pto.vmi.masked_load %ub_src[%c0], %mask, %zero_vec + : !pto.ptr, !pto.vmi.mask<256xpred>, + !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + pto.vmi.store %x, 
%ub_copy[%c0] + : !pto.vmi.vreg<256xf32>, !pto.ptr + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<8xf32> + pto.vmi.group_store %sum, %ub_sum[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xf32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_copy, %copy_gm, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_sum, %sum_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/masked-load-dense-group-users/launch.cpp b/test/vpto/cases/vmi/masked-load-dense-group-users/launch.cpp new file mode 100644 index 0000000000..306dddada0 --- /dev/null +++ b/test/vpto/cases/vmi/masked-load-dense-group-users/launch.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_masked_load_dense_group_users_kernel(__gm__ float *src, __gm__ float *copy, + __gm__ float *sum); + +void LaunchVmi_masked_load_dense_group_users_kernel(float *src, float *copy, + float *sum, void *stream) { + vmi_masked_load_dense_group_users_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ float *)copy, (__gm__ float *)sum); +} diff --git a/test/vpto/cases/vmi/masked-load-dense-group-users/main.cpp b/test/vpto/cases/vmi/masked-load-dense-group-users/main.cpp new file mode 100644 index 0000000000..089794a818 --- /dev/null +++ b/test/vpto/cases/vmi/masked-load-dense-group-users/main.cpp @@ -0,0 +1,97 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#include "acl/acl.h" +#include "test_common.h" +#include <cstdio> +#include <cstdlib> + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_masked_load_dense_group_users_kernel(float *src, float *copy, + float *sum, void *stream); + +int main() { + constexpr size_t kRows = 8; + constexpr size_t kCols = 32; + constexpr size_t kSrcElems = kRows * kCols; + constexpr size_t kSumElems = kRows; + size_t srcBytes = kSrcElems * sizeof(float); + size_t copyBytes = kSrcElems * sizeof(float); + size_t sumBytes = kSumElems * sizeof(float); + float *srcHost = nullptr; + float *copyHost = nullptr; + float *sumHost = nullptr; + float *srcDevice = nullptr; + float *copyDevice = nullptr; + float *sumDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&copyHost), copyBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&sumHost), sumBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&copyDevice, copyBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&sumDevice, sumBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", copyBytes, copyHost, copyBytes); + ReadFile("./v3.bin", sumBytes, sumHost, sumBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, + 
ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(copyDevice, copyBytes, copyHost, copyBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(sumDevice, sumBytes, sumHost, sumBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_masked_load_dense_group_users_kernel(srcDevice, copyDevice, + sumDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(copyHost, copyBytes, copyDevice, copyBytes, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(sumHost, sumBytes, sumDevice, sumBytes, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", copyHost, copyBytes); + WriteFile("./v3.bin", sumHost, sumBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(copyDevice); + aclrtFree(sumDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(copyHost); + aclrtFreeHost(sumHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/masked-load-dense-group-users/ptoas.flags b/test/vpto/cases/vmi/masked-load-dense-group-users/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/masked-load-dense-group-users/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store/compare.py b/test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store/compare.py new file mode 100644 index 0000000000..9f34394fa1 --- /dev/null +++ b/test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store/compare.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. 
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def check(name: str, golden_name: str) -> None: + golden = np.fromfile(golden_name, dtype=np.float32) + output = np.fromfile(name, dtype=np.float32) + if golden.shape == output.shape and np.allclose(golden, output, atol=1e-4, rtol=1e-4): + return + if golden.shape != output.shape: + print(f"[ERROR] compare failed {name}: shape golden={golden.shape} output={output.shape}") + sys.exit(2) + diff = np.nonzero(~np.isclose(golden, output, atol=1e-4, rtol=1e-4))[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed {name} idx={idx} " + f"golden={golden[idx] if idx >= 0 else 'n/a'} " + f"output={output[idx] if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + +def main() -> None: + check("v2.bin", "golden_v2.bin") + check("v3.bin", "golden_v3.bin") + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store/golden.py b/test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store/golden.py new file mode 100644 index 0000000000..df3f6a24dc --- /dev/null +++ b/test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store/golden.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. 
import argparse
from pathlib import Path

import numpy as np

ROWS = 8
COLS = 32
ACTIVE = 25
SENTINEL = np.float32(-777.0)


def generate(output_dir: Path) -> None:
    """Write the input, placeholder and golden ``.bin`` files for this case.

    The source matrix is ``ROWS x COLS`` float32. In every row the first
    ``ACTIVE`` columns hold a ramp shifted by ``row * 0.03125`` and the
    remaining columns hold a second ramp shifted by ``row * 1.75``.

    Files written into ``output_dir`` (created if missing):
        v1.bin         source matrix
        v2.bin         copy buffer pre-filled with ``SENTINEL``
        v3.bin         per-row sum buffer pre-filled with ``SENTINEL``
        golden_v2.bin  source with columns ``ACTIVE:`` replaced by 0.0
        golden_v3.bin  per-row float32 sum over the first ``ACTIVE`` columns
    """
    # Column vector of row numbers so the per-row shift broadcasts over columns.
    row_ids = np.arange(ROWS, dtype=np.float32).reshape(ROWS, 1)

    src = np.empty((ROWS, COLS), dtype=np.float32)
    src[:, :ACTIVE] = (
        np.linspace(-0.875, 0.625, ACTIVE, dtype=np.float32)
        + row_ids * np.float32(0.03125)
    )
    src[:, ACTIVE:] = (
        np.linspace(19.0, 22.5, COLS - ACTIVE, dtype=np.float32)
        + row_ids * np.float32(1.75)
    )

    expected_copy = src.copy()
    expected_copy[:, ACTIVE:] = np.float32(0.0)
    expected_sum = np.sum(src[:, :ACTIVE], axis=1, dtype=np.float32)

    payloads = {
        "v1.bin": src,
        "v2.bin": np.full((ROWS, COLS), SENTINEL, dtype=np.float32),
        "v3.bin": np.full(ROWS, SENTINEL, dtype=np.float32),
        "golden_v2.bin": expected_copy,
        "golden_v3.bin": expected_sum,
    }

    output_dir.mkdir(parents=True, exist_ok=True)
    for filename, data in payloads.items():
        data.reshape(-1).tofile(output_dir / filename)


def main() -> None:
    """Parse ``--output-dir`` (default: current directory) and generate the files."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--output-dir", type=Path, default=Path("."))
    generate(parser.parse_args().output_dir)


if __name__ == "__main__":
    main()
// VMI test kernel: masked load with a per-group tail, dense store of the
// loaded vector, and a grouped add-reduction stored one value per group.
//
// NOTE(review): in this copy the angle-bracket parameters of `!pto.ptr`,
// `#pto.kernel_kind` and `#pto.pipe` are missing (they appear to have been
// stripped during extraction) — restore them from the repository file.
module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} {
  func.func @vmi_masked_load_group_tail_s32_reduce_store_kernel(
      %src_gm: !pto.ptr, %copy_gm: !pto.ptr,
      %sum_gm: !pto.ptr) attributes {pto.kernel} {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c25 = arith.constant 25 : index
    %c0_i64 = arith.constant 0 : i64
    %c1_i64 = arith.constant 1 : i64
    %c32_i64 = arith.constant 32 : i64
    %c1024_i64 = arith.constant 1024 : i64
    %c2048_i64 = arith.constant 2048 : i64
    %c4096_i64 = arith.constant 4096 : i64
    %zero = arith.constant 0.000000e+00 : f32

    // Three local buffers at fixed addresses 0, 2048 and 4096.
    %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr
    %ub_copy = pto.castptr %c2048_i64 : i64 -> !pto.ptr
    %ub_sum = pto.castptr %c4096_i64 : i64 -> !pto.ptr

    // Bring the source in from %src_gm. Presumably 1024 is a byte count
    // (256 x f32) — TODO confirm against the op definition.
    pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c1024_i64
      nburst(%c1_i64, %c1024_i64, %c1024_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64

    // Order the vector work after the MTE2 transfer above.
    pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]
    pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]

    pto.vecscope {
      // Group mask built from 25 with 8 groups of 32 lanes; presumably the
      // first 25 lanes of each group are active — TODO confirm.
      %mask = pto.vmi.create_group_mask %c25 {num_groups = 8, group_size = 32}
        : index -> !pto.vmi.mask<256xpred>
      // Zero vector passed as the third operand of the masked load
      // (presumably the value used for inactive lanes).
      %zero_vec = pto.vmi.broadcast %zero : f32 -> !pto.vmi.vreg<256xf32>
      %x = pto.vmi.masked_load %ub_src[%c0], %mask, %zero_vec
        : !pto.ptr, !pto.vmi.mask<256xpred>,
          !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32>
      // Dense (unmasked) store of the loaded vector.
      pto.vmi.store %x, %ub_copy[%c0]
        : !pto.vmi.vreg<256xf32>, !pto.ptr
      // Reduce each of the 8 groups under the same mask to one f32.
      %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc}
        : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred>
          -> !pto.vmi.vreg<8xf32>
      pto.vmi.group_store %sum, %ub_sum[%c0], %c1 {num_groups = 8}
        : !pto.vmi.vreg<8xf32>, !pto.ptr
    }

    // Order the MTE3 write-back after the vector work, then copy both
    // results out to %copy_gm and %sum_gm.
    pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
    pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
    pto.mte_ub_gm %ub_copy, %copy_gm, %c1024_i64
      nburst(%c1_i64, %c1024_i64, %c1024_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64
    pto.mte_ub_gm %ub_sum, %sum_gm, %c32_i64
      nburst(%c1_i64, %c32_i64, %c32_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64
    pto.barrier #pto.pipe
    return
  }
}
// Device-side launch shim for the masked-load-group-tail-s32-reduce-store
// case: declares the kernel entry point and wraps its launch in a plain
// host-callable function.

// Make sure __VEC_SCOPE__ is defined (as empty) when nothing else defined it.
#ifndef __VEC_SCOPE__
#define __VEC_SCOPE__
#endif
// NOTE(review): the header name of the following #include is missing in this
// copy (it appears to have been stripped during extraction). The struct below
// needs uint16_t — restore the original header from the repository file.
#include
// Fallback definition, only provided when not compiling for the AI core and
// TMRGSORT_HPP has not been defined.
#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP)
struct MrgSortExecutedNumList {
  uint16_t mrgSortList0;
  uint16_t mrgSortList1;
  uint16_t mrgSortList2;
  uint16_t mrgSortList3;
};
#endif
// The ACL header is skipped when building for the CPU simulator.
#ifndef __CPU_SIM
#include "acl/acl.h"
#endif

// Kernel entry point; its definition is not in this file.
extern "C" __global__ [aicore] void
vmi_masked_load_group_tail_s32_reduce_store_kernel(__gm__ float *src, __gm__ float *copy,
                                                  __gm__ float *sum);

/// Launches the kernel with a launch dimension of 1 on `stream`.
///
/// @param src    Source buffer, forwarded as a __gm__ pointer.
/// @param copy   Destination for the copied vector, forwarded as __gm__.
/// @param sum    Destination for the per-group sums, forwarded as __gm__.
/// @param stream Stream handle passed as the third launch argument.
void LaunchVmi_masked_load_group_tail_s32_reduce_store_kernel(float *src, float *copy,
                                                              float *sum, void *stream) {
  vmi_masked_load_group_tail_s32_reduce_store_kernel<<<1, nullptr, stream>>>(
      (__gm__ float *)src, (__gm__ float *)copy, (__gm__ float *)sum);
}
+ +#include "acl/acl.h" +#include "test_common.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_masked_load_group_tail_s32_reduce_store_kernel(float *src, float *copy, + float *sum, void *stream); + +int main() { + constexpr size_t kRows = 8; + constexpr size_t kCols = 32; + constexpr size_t kSrcElems = kRows * kCols; + constexpr size_t kSumElems = kRows; + size_t srcBytes = kSrcElems * sizeof(float); + size_t copyBytes = kSrcElems * sizeof(float); + size_t sumBytes = kSumElems * sizeof(float); + float *srcHost = nullptr; + float *copyHost = nullptr; + float *sumHost = nullptr; + float *srcDevice = nullptr; + float *copyDevice = nullptr; + float *sumDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(©Host), copyBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&sumHost), sumBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)©Device, copyBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&sumDevice, sumBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", copyBytes, copyHost, copyBytes); + ReadFile("./v3.bin", sumBytes, sumHost, sumBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, + 
ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(copyDevice, copyBytes, copyHost, copyBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(sumDevice, sumBytes, sumHost, sumBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_masked_load_group_tail_s32_reduce_store_kernel(srcDevice, copyDevice, + sumDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(copyHost, copyBytes, copyDevice, copyBytes, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(sumHost, sumBytes, sumDevice, sumBytes, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", copyHost, copyBytes); + WriteFile("./v3.bin", sumHost, sumBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(copyDevice); + aclrtFree(sumDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(copyHost); + aclrtFreeHost(sumHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store/ptoas.flags b/test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/private-call-argument-boundary-store/compare.py b/test/vpto/cases/vmi/private-call-argument-boundary-store/compare.py new file mode 100644 index 0000000000..9f34394fa1 --- /dev/null +++ b/test/vpto/cases/vmi/private-call-argument-boundary-store/compare.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. 
import sys

import numpy as np


def check(name: str, golden_name: str) -> None:
    """Compare a raw float32 output file against its golden reference.

    Both files are read as flat ``float32`` arrays. The function returns
    normally when the arrays have the same shape and every element is within
    ``atol=1e-4`` / ``rtol=1e-4`` (``np.isclose`` semantics, with the output
    as the reference operand for the relative term). Otherwise it prints an
    ``[ERROR]`` line to stdout and terminates the process with exit code 2.

    Args:
        name: Path of the file produced by the run under test.
        golden_name: Path of the golden reference file.
    """
    golden = np.fromfile(golden_name, dtype=np.float32)
    output = np.fromfile(name, dtype=np.float32)

    # Shape is checked first so the element-wise comparison below never has
    # to broadcast arrays of different lengths.
    if golden.shape != output.shape:
        print(f"[ERROR] compare failed {name}: shape golden={golden.shape} output={output.shape}")
        sys.exit(2)

    # One tolerance pass serves both the pass/fail decision and the lookup of
    # the first offending element.
    mismatches = np.flatnonzero(~np.isclose(golden, output, atol=1e-4, rtol=1e-4))
    if mismatches.size == 0:
        return

    idx = int(mismatches[0])
    print(
        f"[ERROR] compare failed {name} idx={idx} "
        f"golden={golden[idx]} "
        f"output={output[idx]}"
    )
    sys.exit(2)


def main() -> None:
    """Check both kernel outputs (v2/v3) against their golden files in the CWD."""
    check("v2.bin", "golden_v2.bin")
    check("v3.bin", "golden_v3.bin")
    print("[INFO] compare passed")


if __name__ == "__main__":
    main()
import argparse
from pathlib import Path

import numpy as np

ROWS = 8
COLS = 32
SENTINEL = np.float32(-777.0)


def generate(output_dir: Path) -> None:
    """Write the input, placeholder and golden ``.bin`` files for this case.

    The source matrix is ``ROWS x COLS`` float32; every row is the same ramp
    from -0.875 to 0.625 shifted by ``row * 0.03125``.

    Files written into ``output_dir`` (created if missing):
        v1.bin         source matrix
        v2.bin         copy buffer pre-filled with ``SENTINEL``
        v3.bin         per-row sum buffer pre-filled with ``SENTINEL``
        golden_v2.bin  exact copy of the source
        golden_v3.bin  per-row float32 sum of the source
    """
    ramp = np.linspace(-0.875, 0.625, COLS, dtype=np.float32)
    # Column vector of per-row shifts; broadcasting against the ramp yields
    # the full ROWS x COLS matrix in one expression.
    row_shift = np.arange(ROWS, dtype=np.float32).reshape(ROWS, 1) * np.float32(0.03125)
    src = ramp.reshape(1, COLS) + row_shift

    payloads = {
        "v1.bin": src,
        "v2.bin": np.full((ROWS, COLS), SENTINEL, dtype=np.float32),
        "v3.bin": np.full(ROWS, SENTINEL, dtype=np.float32),
        "golden_v2.bin": src,
        "golden_v3.bin": np.sum(src, axis=1, dtype=np.float32),
    }

    output_dir.mkdir(parents=True, exist_ok=True)
    for filename, data in payloads.items():
        data.reshape(-1).tofile(output_dir / filename)


def main() -> None:
    """Parse ``--output-dir`` (default: current directory) and generate the files."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--output-dir", type=Path, default=Path("."))
    generate(parser.parse_args().output_dir)


if __name__ == "__main__":
    main()
// VMI test kernel: a VMI vector and mask are passed as arguments to a
// private function, which performs the grouped reduction and store.
//
// NOTE(review): in this copy the angle-bracket parameters of `!pto.ptr`,
// `#pto.kernel_kind` and `#pto.pipe` are missing (they appear to have been
// stripped during extraction) — restore them from the repository file.
module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} {
  // Private helper taking VMI-typed arguments: reduces %x per group under
  // %mask (8 groups) and stores the 8 results at %out[%off] with stride
  // operand 1.
  func.func private @consume(%x: !pto.vmi.vreg<256xf32>,
                             %mask: !pto.vmi.mask<256xpred>,
                             %out: !pto.ptr,
                             %off: index) {
    %c1 = arith.constant 1 : index
    %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc}
      : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred>
        -> !pto.vmi.vreg<8xf32>
    pto.vmi.group_store %sum, %out[%off], %c1 {num_groups = 8}
      : !pto.vmi.vreg<8xf32>, !pto.ptr
    return
  }

  func.func @vmi_private_call_argument_boundary_store_kernel(
      %src_gm: !pto.ptr, %copy_gm: !pto.ptr,
      %sum_gm: !pto.ptr) attributes {pto.kernel} {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c32 = arith.constant 32 : index
    %c0_i64 = arith.constant 0 : i64
    %c1_i64 = arith.constant 1 : i64
    %c32_i64 = arith.constant 32 : i64
    %c1024_i64 = arith.constant 1024 : i64
    %c2048_i64 = arith.constant 2048 : i64
    %c4096_i64 = arith.constant 4096 : i64

    // Three local buffers at fixed addresses 0, 2048 and 4096.
    %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr
    %ub_copy = pto.castptr %c2048_i64 : i64 -> !pto.ptr
    %ub_sum = pto.castptr %c4096_i64 : i64 -> !pto.ptr

    // Bring the source in from %src_gm. Presumably 1024 is a byte count
    // (256 x f32) — TODO confirm against the op definition.
    pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c1024_i64
      nburst(%c1_i64, %c1024_i64, %c1024_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64

    // Order the vector work after the MTE2 transfer above.
    pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]
    pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]

    pto.vecscope {
      %x = pto.vmi.load %ub_src[%c0]
        : !pto.ptr -> !pto.vmi.vreg<256xf32>
      // Dense store of the loaded vector; %x is also consumed by the call.
      pto.vmi.store %x, %ub_copy[%c0]
        : !pto.vmi.vreg<256xf32>, !pto.ptr
      // Group mask built from 32 with group_size = 32 (presumably every
      // lane active — TODO confirm).
      %mask = pto.vmi.create_group_mask %c32 {num_groups = 8, group_size = 32}
        : index -> !pto.vmi.mask<256xpred>
      // VMI values cross the call boundary as arguments here.
      func.call @consume(%x, %mask, %ub_sum, %c0)
        : (!pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred>,
           !pto.ptr, index) -> ()
    }

    // Order the MTE3 write-back after the vector work, then copy both
    // results out to %copy_gm and %sum_gm.
    pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
    pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
    pto.mte_ub_gm %ub_copy, %copy_gm, %c1024_i64
      nburst(%c1_i64, %c1024_i64, %c1024_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64
    pto.mte_ub_gm %ub_sum, %sum_gm, %c32_i64
      nburst(%c1_i64, %c32_i64, %c32_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64
    pto.barrier #pto.pipe
    return
  }
}
// Device-side launch shim for the private-call-argument-boundary-store case:
// declares the kernel entry point and wraps its launch in a plain
// host-callable function.

// Make sure __VEC_SCOPE__ is defined (as empty) when nothing else defined it.
#ifndef __VEC_SCOPE__
#define __VEC_SCOPE__
#endif
// NOTE(review): the header name of the following #include is missing in this
// copy (it appears to have been stripped during extraction). The struct below
// needs uint16_t — restore the original header from the repository file.
#include
// Fallback definition, only provided when not compiling for the AI core and
// TMRGSORT_HPP has not been defined.
#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP)
struct MrgSortExecutedNumList {
  uint16_t mrgSortList0;
  uint16_t mrgSortList1;
  uint16_t mrgSortList2;
  uint16_t mrgSortList3;
};
#endif
// The ACL header is skipped when building for the CPU simulator.
#ifndef __CPU_SIM
#include "acl/acl.h"
#endif

// Kernel entry point; its definition is not in this file.
extern "C" __global__ [aicore] void
vmi_private_call_argument_boundary_store_kernel(__gm__ float *src,
                                                __gm__ float *copy,
                                                __gm__ float *sum);

/// Launches the kernel with a launch dimension of 1 on `stream`.
///
/// @param src    Source buffer, forwarded as a __gm__ pointer.
/// @param copy   Destination for the copied vector, forwarded as __gm__.
/// @param sum    Destination for the per-group sums, forwarded as __gm__.
/// @param stream Stream handle passed as the third launch argument.
void LaunchVmi_private_call_argument_boundary_store_kernel(float *src,
                                                           float *copy,
                                                           float *sum,
                                                           void *stream) {
  vmi_private_call_argument_boundary_store_kernel<<<1, nullptr, stream>>>(
      (__gm__ float *)src, (__gm__ float *)copy, (__gm__ float *)sum);
}
+ +#include "acl/acl.h" +#include "test_common.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_private_call_argument_boundary_store_kernel(float *src, + float *copy, + float *sum, + void *stream); + +int main() { + constexpr size_t kRows = 8; + constexpr size_t kCols = 32; + constexpr size_t kSrcElems = kRows * kCols; + constexpr size_t kSumElems = kRows; + size_t srcBytes = kSrcElems * sizeof(float); + size_t copyBytes = kSrcElems * sizeof(float); + size_t sumBytes = kSumElems * sizeof(float); + float *srcHost = nullptr; + float *copyHost = nullptr; + float *sumHost = nullptr; + float *srcDevice = nullptr; + float *copyDevice = nullptr; + float *sumDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(©Host), copyBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&sumHost), sumBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)©Device, copyBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&sumDevice, sumBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", copyBytes, copyHost, copyBytes); + ReadFile("./v3.bin", sumBytes, sumHost, sumBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, + 
ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(copyDevice, copyBytes, copyHost, copyBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(sumDevice, sumBytes, sumHost, sumBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_private_call_argument_boundary_store_kernel(srcDevice, copyDevice, + sumDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(copyHost, copyBytes, copyDevice, copyBytes, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(sumHost, sumBytes, sumDevice, sumBytes, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", copyHost, copyBytes); + WriteFile("./v3.bin", sumHost, sumBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(copyDevice); + aclrtFree(sumDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(copyHost); + aclrtFreeHost(sumHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/private-call-argument-boundary-store/ptoas.flags b/test/vpto/cases/vmi/private-call-argument-boundary-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/private-call-argument-boundary-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/private-call-inline-store/compare.py b/test/vpto/cases/vmi/private-call-inline-store/compare.py new file mode 100644 index 0000000000..9f34394fa1 --- /dev/null +++ b/test/vpto/cases/vmi/private-call-inline-store/compare.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. 
import sys

import numpy as np


def check(name: str, golden_name: str) -> None:
    """Compare a raw float32 output file against its golden reference.

    Both files are read as flat ``float32`` arrays. The function returns
    normally when the arrays have the same shape and every element is within
    ``atol=1e-4`` / ``rtol=1e-4`` (``np.isclose`` semantics, with the output
    as the reference operand for the relative term). Otherwise it prints an
    ``[ERROR]`` line to stdout and terminates the process with exit code 2.

    Args:
        name: Path of the file produced by the run under test.
        golden_name: Path of the golden reference file.
    """
    golden = np.fromfile(golden_name, dtype=np.float32)
    output = np.fromfile(name, dtype=np.float32)

    # Shape is checked first so the element-wise comparison below never has
    # to broadcast arrays of different lengths.
    if golden.shape != output.shape:
        print(f"[ERROR] compare failed {name}: shape golden={golden.shape} output={output.shape}")
        sys.exit(2)

    # One tolerance pass serves both the pass/fail decision and the lookup of
    # the first offending element.
    mismatches = np.flatnonzero(~np.isclose(golden, output, atol=1e-4, rtol=1e-4))
    if mismatches.size == 0:
        return

    idx = int(mismatches[0])
    print(
        f"[ERROR] compare failed {name} idx={idx} "
        f"golden={golden[idx]} "
        f"output={output[idx]}"
    )
    sys.exit(2)


def main() -> None:
    """Check both kernel outputs (v2/v3) against their golden files in the CWD."""
    check("v2.bin", "golden_v2.bin")
    check("v3.bin", "golden_v3.bin")
    print("[INFO] compare passed")


if __name__ == "__main__":
    main()
import argparse
from pathlib import Path

import numpy as np

ROWS = 8
COLS = 32
SENTINEL = np.float32(-777.0)


def generate(output_dir: Path) -> None:
    """Write the input, placeholder and golden ``.bin`` files for this case.

    The source matrix is ``ROWS x COLS`` float32; every row is the same ramp
    from -0.875 to 0.625 shifted by ``row * 0.03125``.

    Files written into ``output_dir`` (created if missing):
        v1.bin         source matrix
        v2.bin         copy buffer pre-filled with ``SENTINEL``
        v3.bin         per-row sum buffer pre-filled with ``SENTINEL``
        golden_v2.bin  exact copy of the source
        golden_v3.bin  per-row float32 sum of the source
    """
    ramp = np.linspace(-0.875, 0.625, COLS, dtype=np.float32)
    # Column vector of per-row shifts; broadcasting against the ramp yields
    # the full ROWS x COLS matrix in one expression.
    row_shift = np.arange(ROWS, dtype=np.float32).reshape(ROWS, 1) * np.float32(0.03125)
    src = ramp.reshape(1, COLS) + row_shift

    payloads = {
        "v1.bin": src,
        "v2.bin": np.full((ROWS, COLS), SENTINEL, dtype=np.float32),
        "v3.bin": np.full(ROWS, SENTINEL, dtype=np.float32),
        "golden_v2.bin": src,
        "golden_v3.bin": np.sum(src, axis=1, dtype=np.float32),
    }

    output_dir.mkdir(parents=True, exist_ok=True)
    for filename, data in payloads.items():
        data.reshape(-1).tofile(output_dir / filename)


def main() -> None:
    """Parse ``--output-dir`` (default: current directory) and generate the files."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--output-dir", type=Path, default=Path("."))
    generate(parser.parse_args().output_dir)


if __name__ == "__main__":
    main()
// VMI test kernel: a private function returns a VMI vector, which the kernel
// then stores densely and reduces per group.
//
// NOTE(review): in this copy the angle-bracket parameters of `!pto.ptr`,
// `#pto.kernel_kind` and `#pto.pipe` are missing (they appear to have been
// stripped during extraction) — restore them from the repository file.
module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} {
  // Private helper with a VMI-typed result: loads 256 x f32 from
  // %src[%off] and returns the vector.
  func.func private @producer(%src: !pto.ptr, %off: index)
      -> !pto.vmi.vreg<256xf32> {
    %x = pto.vmi.load %src[%off]
      : !pto.ptr -> !pto.vmi.vreg<256xf32>
    return %x : !pto.vmi.vreg<256xf32>
  }

  func.func @vmi_private_call_inline_store_kernel(%src_gm: !pto.ptr,
      %copy_gm: !pto.ptr,
      %sum_gm: !pto.ptr) attributes {pto.kernel} {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c32 = arith.constant 32 : index
    %c0_i64 = arith.constant 0 : i64
    %c1_i64 = arith.constant 1 : i64
    %c32_i64 = arith.constant 32 : i64
    %c1024_i64 = arith.constant 1024 : i64
    %c2048_i64 = arith.constant 2048 : i64
    %c4096_i64 = arith.constant 4096 : i64

    // Three local buffers at fixed addresses 0, 2048 and 4096.
    %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr
    %ub_copy = pto.castptr %c2048_i64 : i64 -> !pto.ptr
    %ub_sum = pto.castptr %c4096_i64 : i64 -> !pto.ptr

    // Bring the source in from %src_gm. Presumably 1024 is a byte count
    // (256 x f32) — TODO confirm against the op definition.
    pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c1024_i64
      nburst(%c1_i64, %c1024_i64, %c1024_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64

    // Order the vector work after the MTE2 transfer above.
    pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]
    pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]

    pto.vecscope {
      // The VMI value crosses the call boundary as a result here.
      %x = func.call @producer(%ub_src, %c0)
        : (!pto.ptr, index) -> !pto.vmi.vreg<256xf32>
      pto.vmi.store %x, %ub_copy[%c0]
        : !pto.vmi.vreg<256xf32>, !pto.ptr

      // Group mask built from 32 with group_size = 32 (presumably every
      // lane active — TODO confirm).
      %mask = pto.vmi.create_group_mask %c32 {num_groups = 8, group_size = 32}
        : index -> !pto.vmi.mask<256xpred>
      // Reduce each of the 8 groups to one f32 and store them with stride
      // operand 1.
      %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc}
        : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred>
          -> !pto.vmi.vreg<8xf32>
      pto.vmi.group_store %sum, %ub_sum[%c0], %c1 {num_groups = 8}
        : !pto.vmi.vreg<8xf32>, !pto.ptr
    }

    // Order the MTE3 write-back after the vector work, then copy both
    // results out to %copy_gm and %sum_gm.
    pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
    pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
    pto.mte_ub_gm %ub_copy, %copy_gm, %c1024_i64
      nburst(%c1_i64, %c1024_i64, %c1024_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64
    pto.mte_ub_gm %ub_sum, %sum_gm, %c32_i64
      nburst(%c1_i64, %c32_i64, %c32_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64
    pto.barrier #pto.pipe
    return
  }
}
// Device-side launch shim for the private-call-inline-store case: declares
// the kernel entry point and wraps its launch in a plain host-callable
// function.

// Make sure __VEC_SCOPE__ is defined (as empty) when nothing else defined it.
#ifndef __VEC_SCOPE__
#define __VEC_SCOPE__
#endif
// NOTE(review): the header name of the following #include is missing in this
// copy (it appears to have been stripped during extraction). The struct below
// needs uint16_t — restore the original header from the repository file.
#include
// Fallback definition, only provided when not compiling for the AI core and
// TMRGSORT_HPP has not been defined.
#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP)
struct MrgSortExecutedNumList {
  uint16_t mrgSortList0;
  uint16_t mrgSortList1;
  uint16_t mrgSortList2;
  uint16_t mrgSortList3;
};
#endif
// The ACL header is skipped when building for the CPU simulator.
#ifndef __CPU_SIM
#include "acl/acl.h"
#endif

// Kernel entry point; its definition is not in this file.
extern "C" __global__ [aicore] void
vmi_private_call_inline_store_kernel(__gm__ float *src, __gm__ float *copy,
                                     __gm__ float *sum);

/// Launches the kernel with a launch dimension of 1 on `stream`.
///
/// @param src    Source buffer, forwarded as a __gm__ pointer.
/// @param copy   Destination for the copied vector, forwarded as __gm__.
/// @param sum    Destination for the per-group sums, forwarded as __gm__.
/// @param stream Stream handle passed as the third launch argument.
void LaunchVmi_private_call_inline_store_kernel(float *src, float *copy,
                                                float *sum, void *stream) {
  vmi_private_call_inline_store_kernel<<<1, nullptr, stream>>>(
      (__gm__ float *)src, (__gm__ float *)copy, (__gm__ float *)sum);
}
+ +#include "acl/acl.h" +#include "test_common.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_private_call_inline_store_kernel(float *src, float *copy, + float *sum, void *stream); + +int main() { + constexpr size_t kRows = 8; + constexpr size_t kCols = 32; + constexpr size_t kSrcElems = kRows * kCols; + constexpr size_t kSumElems = kRows; + size_t srcBytes = kSrcElems * sizeof(float); + size_t copyBytes = kSrcElems * sizeof(float); + size_t sumBytes = kSumElems * sizeof(float); + float *srcHost = nullptr; + float *copyHost = nullptr; + float *sumHost = nullptr; + float *srcDevice = nullptr; + float *copyDevice = nullptr; + float *sumDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(©Host), copyBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&sumHost), sumBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)©Device, copyBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&sumDevice, sumBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", copyBytes, copyHost, copyBytes); + ReadFile("./v3.bin", sumBytes, sumHost, sumBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); 
+ ACL_CHECK(aclrtMemcpy(copyDevice, copyBytes, copyHost, copyBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(sumDevice, sumBytes, sumHost, sumBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_private_call_inline_store_kernel(srcDevice, copyDevice, sumDevice, + stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(copyHost, copyBytes, copyDevice, copyBytes, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(sumHost, sumBytes, sumDevice, sumBytes, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", copyHost, copyBytes); + WriteFile("./v3.bin", sumHost, sumBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(copyDevice); + aclrtFree(sumDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(copyHost); + aclrtFreeHost(sumHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/private-call-inline-store/ptoas.flags b/test/vpto/cases/vmi/private-call-inline-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/private-call-inline-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/quant-f32-to-f16-tail/compare.py b/test/vpto/cases/vmi/quant-f32-to-f16-tail/compare.py new file mode 100644 index 0000000000..39f37ccd7c --- /dev/null +++ b/test/vpto/cases/vmi/quant-f32-to-f16-tail/compare.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. 
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.

import sys

import numpy as np


def main() -> None:
    """Compare ``v2.bin`` against ``golden_v2.bin`` as float16 arrays.

    Prints ``[INFO] compare passed`` when the arrays are equal. Otherwise
    prints an ``[ERROR]`` line and exits with status 2.
    """
    golden = np.fromfile("golden_v2.bin", dtype=np.float16)
    output = np.fromfile("v2.bin", dtype=np.float16)

    # Handle a size mismatch on its own: an elementwise comparison of arrays
    # with different lengths either raises or broadcasts, so it cannot
    # produce a meaningful mismatch index.
    if golden.shape != output.shape:
        print(f"[ERROR] compare failed: size mismatch golden={golden.size} output={output.size}")
        sys.exit(2)

    if not np.array_equal(golden, output):
        # Locate the first differing element by raw bit pattern. This can be
        # empty (e.g. NaN payloads that are bit-identical), hence the n/a
        # fallback below.
        diff = np.nonzero(golden.view(np.uint16) != output.view(np.uint16))[0]
        idx = int(diff[0]) if diff.size else -1
        print(f"[ERROR] compare failed idx={idx} golden={golden[idx] if idx >= 0 else 'n/a'} output={output[idx] if idx >= 0 else 'n/a'}")
        sys.exit(2)
    print("[INFO] compare passed")


if __name__ == "__main__":
    main()
import argparse
from pathlib import Path

import numpy as np

ELEMS = 1024
LOGICAL_ELEMS = 1000
SEED = 29
SCALE = np.float32(0.5)
SENTINEL = np.float16(-17.5)


def generate(output_dir: Path, seed: int) -> None:
    """Write the input, pre-filled output and golden files for this case.

    ``v1.bin`` holds ELEMS float32 values drawn uniformly from [-8, 8).
    ``v2.bin`` holds ELEMS float16 sentinels. ``golden_v2.bin`` holds the
    first LOGICAL_ELEMS inputs scaled by SCALE and cast to float16, with the
    remaining elements left at the sentinel.
    """
    source = np.random.default_rng(seed).uniform(-8.0, 8.0, size=ELEMS).astype(np.float32)
    prefill = np.full(ELEMS, SENTINEL, dtype=np.float16)
    expected = prefill.copy()
    expected[:LOGICAL_ELEMS] = (source[:LOGICAL_ELEMS] * SCALE).astype(np.float16)

    output_dir.mkdir(parents=True, exist_ok=True)
    for file_name, payload in (
        ("v1.bin", source),
        ("v2.bin", prefill),
        ("golden_v2.bin", expected),
    ):
        payload.tofile(output_dir / file_name)


def main() -> None:
    """Parse ``--output-dir`` / ``--seed`` and generate the case files."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--output-dir", type=Path, default=Path("."))
    cli.add_argument("--seed", type=int, default=SEED)
    options = cli.parse_args()
    generate(options.output_dir, options.seed)


if __name__ == "__main__":
    main()
// NOTE(review): several type/attribute parameters in this listing (e.g. on
// !pto.ptr, #pto.kernel_kind, #pto.pipe) appear to have been lost in
// extraction; the tokens below are kept exactly as visible.
module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} {
  // f32 -> f16 scaling kernel with a tail: the loop walks 1024 elements in
  // steps of 128 while the mask is derived from a counter that starts at 1000.
  func.func @vmi_quant_f32_to_f16_tail_kernel(%src_gm: !pto.ptr,
                                              %dst_gm: !pto.ptr) attributes {pto.kernel} {
    %c0 = arith.constant 0 : index
    %c128 = arith.constant 128 : index
    %c1024 = arith.constant 1024 : index
    %c1000 = arith.constant 1000 : index
    %c0_i64 = arith.constant 0 : i64
    %c1_i64 = arith.constant 1 : i64
    %c32_i64 = arith.constant 32 : i64
    %c64_i64 = arith.constant 64 : i64
    %c128_i64 = arith.constant 128 : i64
    %c4096_i64 = arith.constant 4096 : i64
    %scale = arith.constant 5.000000e-01 : f32

    // Source buffer is addressed at 0 and destination buffer at 4096.
    %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr
    %ub_dst = pto.castptr %c4096_i64 : i64 -> !pto.ptr

    // Stage both the input and the current destination contents before the
    // vector section; presumably so lanes excluded by the mask keep their
    // prior value on write-back — TODO confirm.
    pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c128_i64
        nburst(%c32_i64, %c128_i64, %c128_i64)
        : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
    pto.mte_gm_ub %dst_gm, %ub_dst, %c0_i64, %c64_i64
        nburst(%c32_i64, %c64_i64, %c64_i64)
        : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64

    pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]
    pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]

    pto.vecscope {
      // %remaining is carried across iterations and reduced by 128 each time;
      // it is the operand of create_mask for that iteration.
      %_:1 = scf.for %offset = %c0 to %c1024 step %c128 iter_args(%remaining = %c1000) -> (index) {
        %mask = pto.vmi.create_mask %remaining : index -> !pto.vmi.mask<128xpred>
        %wide = pto.vmi.load %ub_src[%offset] : !pto.ptr -> !pto.vmi.vreg<128xf32>
        %scale_vec = pto.vmi.broadcast %scale : f32 -> !pto.vmi.vreg<128xf32>
        %scaled = pto.vmi.mulf %wide, %scale_vec
            : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32>
        %packed = pto.vmi.truncf %scaled : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16>
        pto.vmi.masked_store %packed, %ub_dst[%offset], %mask
            : !pto.vmi.vreg<128xf16>, !pto.ptr, !pto.vmi.mask<128xpred>
        %next = arith.subi %remaining, %c128 : index
        scf.yield %next : index
      }
    }

    pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
    pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
    // Write the whole destination buffer back.
    pto.mte_ub_gm %ub_dst, %dst_gm, %c64_i64
        nburst(%c32_i64, %c64_i64, %c64_i64)
        : !pto.ptr, !pto.ptr, i64, i64, i64, i64
    pto.barrier #pto.pipe
    return
  }
}
+ +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_quant_f32_to_f16_tail_kernel(__gm__ float *src, __gm__ half *dst); + +void LaunchVmi_quant_f32_to_f16_tail_kernel(float *src, uint16_t *dst, + void *stream) { + vmi_quant_f32_to_f16_tail_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ half *)dst); +} diff --git a/test/vpto/cases/vmi/quant-f32-to-f16-tail/main.cpp b/test/vpto/cases/vmi/quant-f32-to-f16-tail/main.cpp new file mode 100644 index 0000000000..b03ccbce5c --- /dev/null +++ b/test/vpto/cases/vmi/quant-f32-to-f16-tail/main.cpp @@ -0,0 +1,78 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_quant_f32_to_f16_tail_kernel(float *src, uint16_t *dst, + void *stream); + +int main() { + constexpr size_t kElems = 1024; + size_t srcBytes = kElems * sizeof(float); + size_t dstBytes = kElems * sizeof(uint16_t); + float *srcHost = nullptr; + float *srcDevice = nullptr; + uint16_t *dstHost = nullptr; + uint16_t *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_quant_f32_to_f16_tail_kernel(srcDevice, dstDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + 
aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/quant-f32-to-f16-tail/ptoas.flags b/test/vpto/cases/vmi/quant-f32-to-f16-tail/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/quant-f32-to-f16-tail/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/quant-f32-to-f8-full/compare.py b/test/vpto/cases/vmi/quant-f32-to-f8-full/compare.py new file mode 100644 index 0000000000..68c53a335e --- /dev/null +++ b/test/vpto/cases/vmi/quant-f32-to-f8-full/compare.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
import sys

import numpy as np


def main() -> None:
    """Compare ``v2.bin`` against ``golden_v2.bin`` byte for byte.

    Prints ``[INFO] compare passed`` when the files match. Otherwise prints
    an ``[ERROR]`` line and exits with status 2.
    """
    golden = np.fromfile("golden_v2.bin", dtype=np.uint8)
    output = np.fromfile("v2.bin", dtype=np.uint8)

    # Handle a size mismatch on its own: an elementwise comparison of arrays
    # with different lengths either raises or broadcasts, and there is no
    # valid index to report.
    if golden.shape != output.shape:
        print(f"[ERROR] compare failed: size mismatch golden={golden.size} output={output.size}")
        sys.exit(2)

    diff = np.nonzero(golden != output)[0]
    if diff.size:
        idx = int(diff[0])
        print(f"[ERROR] compare failed idx={idx} golden=0x{int(golden[idx]):02x} output=0x{int(output[idx]):02x}")
        sys.exit(2)
    print("[INFO] compare passed")


if __name__ == "__main__":
    main()
import argparse
from pathlib import Path

import numpy as np

ELEMS = 256
VALUES = np.array([0.0, 1.0, -1.0, 0.5, 2.0, -2.0, 4.0, -4.0], dtype=np.float32)
F8E4M3FN_BYTES = np.array([0x00, 0x38, 0xB8, 0x30, 0x40, 0xC0, 0x48, 0xC8], dtype=np.uint8)


def generate(output_dir: Path) -> None:
    """Write the input, pre-filled output and golden files for this case.

    ``v1.bin`` is VALUES repeated cyclically to ELEMS float32 elements.
    ``golden_v2.bin`` is F8E4M3FN_BYTES repeated the same way, and ``v2.bin``
    is ELEMS bytes of 0xA5.
    """
    # np.resize fills the requested length with repeated copies of the input,
    # which is the same cyclic pattern as tiling and truncating.
    source = np.resize(VALUES, ELEMS).astype(np.float32)
    expected = np.resize(F8E4M3FN_BYTES, ELEMS).astype(np.uint8)
    prefill = np.full(ELEMS, 0xA5, dtype=np.uint8)

    output_dir.mkdir(parents=True, exist_ok=True)
    for file_name, payload in (
        ("v1.bin", source),
        ("v2.bin", prefill),
        ("golden_v2.bin", expected),
    ):
        payload.tofile(output_dir / file_name)


def main() -> None:
    """Parse ``--output-dir`` and generate the case files."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--output-dir", type=Path, default=Path("."))
    options = cli.parse_args()
    generate(options.output_dir)


if __name__ == "__main__":
    main()
// NOTE(review): several type/attribute parameters in this listing (e.g. on
// !pto.ptr, #pto.kernel_kind, #pto.pipe) appear to have been lost in
// extraction; the tokens below are kept exactly as visible.
module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} {
  // Single-shot f32 -> f8E4M3FN conversion of one 256-element vector.
  func.func @vmi_quant_f32_to_f8_full_kernel(%src_gm: !pto.ptr,
                                             %dst_gm: !pto.ptr) attributes {pto.kernel} {
    %c0 = arith.constant 0 : index
    %c0_i64 = arith.constant 0 : i64
    %c1_i64 = arith.constant 1 : i64
    %c256_i64 = arith.constant 256 : i64
    %c1024_i64 = arith.constant 1024 : i64
    %c4096_i64 = arith.constant 4096 : i64

    %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr
    // Two pointers created from the same address (4096): %ub_dst_u8 is the
    // one used by the MTE copies, %ub_dst_f8 the one used by the VMI store.
    %ub_dst_u8 = pto.castptr %c4096_i64 : i64 -> !pto.ptr
    %ub_dst_f8 = pto.castptr %c4096_i64 : i64 -> !pto.ptr

    // Stage the input and the current destination contents.
    pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c1024_i64
        nburst(%c1_i64, %c1024_i64, %c1024_i64)
        : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
    pto.mte_gm_ub %dst_gm, %ub_dst_u8, %c0_i64, %c256_i64
        nburst(%c1_i64, %c256_i64, %c256_i64)
        : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64

    pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]
    pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]

    pto.vecscope {
      %wide = pto.vmi.load %ub_src[%c0] : !pto.ptr -> !pto.vmi.vreg<256xf32>
      %packed = pto.vmi.truncf %wide : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E4M3FN>
      pto.vmi.store %packed, %ub_dst_f8[%c0] : !pto.vmi.vreg<256xf8E4M3FN>, !pto.ptr
    }

    pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
    pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
    // Write the destination buffer back through the u8 pointer.
    pto.mte_ub_gm %ub_dst_u8, %dst_gm, %c256_i64
        nburst(%c1_i64, %c256_i64, %c256_i64)
        : !pto.ptr, !pto.ptr, i64, i64, i64, i64
    pto.barrier #pto.pipe
    return
  }
}
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_quant_f32_to_f8_full_kernel(__gm__ float *src, __gm__ uint8_t *dst); + +void LaunchVmi_quant_f32_to_f8_full_kernel(float *src, uint8_t *dst, + void *stream) { + vmi_quant_f32_to_f8_full_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ uint8_t *)dst); +} diff --git a/test/vpto/cases/vmi/quant-f32-to-f8-full/main.cpp b/test/vpto/cases/vmi/quant-f32-to-f8-full/main.cpp new file mode 100644 index 0000000000..6e3aae53f2 --- /dev/null +++ b/test/vpto/cases/vmi/quant-f32-to-f8-full/main.cpp @@ -0,0 +1,79 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_quant_f32_to_f8_full_kernel(float *src, uint8_t *dst, + void *stream); + +int main() { + constexpr size_t kSrcElems = 256; + constexpr size_t kDstElems = 256; + size_t srcBytes = kSrcElems * sizeof(float); + size_t dstBytes = kDstElems * sizeof(uint8_t); + float *srcHost = nullptr; + float *srcDevice = nullptr; + uint8_t *dstHost = nullptr; + uint8_t *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, 
ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_quant_f32_to_f8_full_kernel(srcDevice, dstDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/quant-f32-to-f8-full/ptoas.flags b/test/vpto/cases/vmi/quant-f32-to-f8-full/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/quant-f32-to-f8-full/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/quant-f32-to-f8-tail/compare.py b/test/vpto/cases/vmi/quant-f32-to-f8-tail/compare.py new file mode 100644 index 0000000000..68c53a335e --- /dev/null +++ b/test/vpto/cases/vmi/quant-f32-to-f8-tail/compare.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
# See LICENSE in the root of the software repository for the full text of the License.

import sys

import numpy as np


def main() -> None:
    """Compare ``v2.bin`` against ``golden_v2.bin`` byte for byte.

    Prints ``[INFO] compare passed`` when the files match. Otherwise prints
    an ``[ERROR]`` line and exits with status 2.
    """
    golden = np.fromfile("golden_v2.bin", dtype=np.uint8)
    output = np.fromfile("v2.bin", dtype=np.uint8)

    # Handle a size mismatch on its own: an elementwise comparison of arrays
    # with different lengths either raises or broadcasts, and there is no
    # valid index to report.
    if golden.shape != output.shape:
        print(f"[ERROR] compare failed: size mismatch golden={golden.size} output={output.size}")
        sys.exit(2)

    diff = np.nonzero(golden != output)[0]
    if diff.size:
        idx = int(diff[0])
        print(f"[ERROR] compare failed idx={idx} golden=0x{int(golden[idx]):02x} output=0x{int(output[idx]):02x}")
        sys.exit(2)
    print("[INFO] compare passed")


if __name__ == "__main__":
    main()
import argparse
from pathlib import Path

import numpy as np

ELEMS = 1024
LOGICAL_ELEMS = 1000
VALUES = np.array([0.0, 1.0, -1.0, 0.5, 2.0, -2.0, 4.0, -4.0], dtype=np.float32)
F8E4M3FN_BYTES = np.array([0x00, 0x38, 0xB8, 0x30, 0x40, 0xC0, 0x48, 0xC8], dtype=np.uint8)
SENTINEL = np.uint8(0xA5)


def generate(output_dir: Path) -> None:
    """Write the input, pre-filled output and golden files for this case.

    ``v1.bin`` is VALUES repeated cyclically to ELEMS float32 elements and
    ``v2.bin`` is ELEMS sentinel bytes. ``golden_v2.bin`` carries the matching
    F8E4M3FN_BYTES pattern in its first LOGICAL_ELEMS bytes and the sentinel
    in the rest.
    """
    # np.resize fills the requested length with repeated copies of the input,
    # which is the same cyclic pattern as tiling and truncating.
    source = np.resize(VALUES, ELEMS).astype(np.float32)
    encoded = np.resize(F8E4M3FN_BYTES, ELEMS).astype(np.uint8)
    prefill = np.full(ELEMS, SENTINEL, dtype=np.uint8)
    expected = prefill.copy()
    expected[:LOGICAL_ELEMS] = encoded[:LOGICAL_ELEMS]

    output_dir.mkdir(parents=True, exist_ok=True)
    for file_name, payload in (
        ("v1.bin", source),
        ("v2.bin", prefill),
        ("golden_v2.bin", expected),
    ):
        payload.tofile(output_dir / file_name)


def main() -> None:
    """Parse ``--output-dir`` and generate the case files."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--output-dir", type=Path, default=Path("."))
    options = cli.parse_args()
    generate(options.output_dir)


if __name__ == "__main__":
    main()
// NOTE(review): several type/attribute parameters in this listing (e.g. on
// !pto.ptr, #pto.kernel_kind, #pto.pipe) appear to have been lost in
// extraction; the tokens below are kept exactly as visible.
module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} {
  // f32 -> f8E4M3FN conversion with a tail: the loop walks 1024 elements in
  // steps of 256 while the mask is derived from a counter that starts at 1000.
  func.func @vmi_quant_f32_to_f8_tail_kernel(%src_gm: !pto.ptr,
                                             %dst_gm: !pto.ptr) attributes {pto.kernel} {
    %c0 = arith.constant 0 : index
    %c256 = arith.constant 256 : index
    %c1024 = arith.constant 1024 : index
    %c1000 = arith.constant 1000 : index
    %c0_i64 = arith.constant 0 : i64
    %c1_i64 = arith.constant 1 : i64
    %c256_i64 = arith.constant 256 : i64
    %c1024_i64 = arith.constant 1024 : i64
    %c4096_i64 = arith.constant 4096 : i64

    %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr
    // Two pointers created from the same address (4096): %ub_dst_u8 is the
    // one used by the MTE copies, %ub_dst_f8 the one used by the VMI store.
    %ub_dst_u8 = pto.castptr %c4096_i64 : i64 -> !pto.ptr
    %ub_dst_f8 = pto.castptr %c4096_i64 : i64 -> !pto.ptr

    // Stage both the input and the current destination contents before the
    // vector section; presumably so lanes excluded by the mask keep their
    // prior value on write-back — TODO confirm.
    pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c4096_i64
        nburst(%c1_i64, %c4096_i64, %c4096_i64)
        : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
    pto.mte_gm_ub %dst_gm, %ub_dst_u8, %c0_i64, %c1024_i64
        nburst(%c1_i64, %c1024_i64, %c1024_i64)
        : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64

    pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]
    pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]

    pto.vecscope {
      // %remaining is carried across iterations and reduced by 256 each time;
      // it is the operand of create_mask for that iteration.
      %_:1 = scf.for %offset = %c0 to %c1024 step %c256 iter_args(%remaining = %c1000) -> (index) {
        %mask = pto.vmi.create_mask %remaining : index -> !pto.vmi.mask<256xpred>
        %wide = pto.vmi.load %ub_src[%offset] : !pto.ptr -> !pto.vmi.vreg<256xf32>
        %packed = pto.vmi.truncf %wide : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E4M3FN>
        pto.vmi.masked_store %packed, %ub_dst_f8[%offset], %mask
            : !pto.vmi.vreg<256xf8E4M3FN>, !pto.ptr, !pto.vmi.mask<256xpred>
        %next = arith.subi %remaining, %c256 : index
        scf.yield %next : index
      }
    }

    pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
    pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
    // Write the whole destination buffer back through the u8 pointer.
    pto.mte_ub_gm %ub_dst_u8, %dst_gm, %c1024_i64
        nburst(%c1_i64, %c1024_i64, %c1024_i64)
        : !pto.ptr, !pto.ptr, i64, i64, i64, i64
    pto.barrier #pto.pipe
    return
  }
}
a/test/vpto/cases/vmi/quant-f32-to-f8-tail/launch.cpp b/test/vpto/cases/vmi/quant-f32-to-f8-tail/launch.cpp new file mode 100644 index 0000000000..cf40a3fc57 --- /dev/null +++ b/test/vpto/cases/vmi/quant-f32-to-f8-tail/launch.cpp @@ -0,0 +1,40 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_quant_f32_to_f8_tail_kernel(__gm__ float *src, __gm__ uint8_t *dst); + +void LaunchVmi_quant_f32_to_f8_tail_kernel(float *src, uint8_t *dst, + void *stream) { + vmi_quant_f32_to_f8_tail_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ uint8_t *)dst); +} diff --git a/test/vpto/cases/vmi/quant-f32-to-f8-tail/main.cpp 
b/test/vpto/cases/vmi/quant-f32-to-f8-tail/main.cpp new file mode 100644 index 0000000000..5f5bda8502 --- /dev/null +++ b/test/vpto/cases/vmi/quant-f32-to-f8-tail/main.cpp @@ -0,0 +1,78 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_quant_f32_to_f8_tail_kernel(float *src, uint8_t *dst, + void *stream); + +int main() { + constexpr size_t kElems = 1024; + size_t srcBytes = kElems * sizeof(float); + size_t dstBytes = kElems * sizeof(uint8_t); + float *srcHost = nullptr; + float *srcDevice = nullptr; + uint8_t *dstHost = nullptr; + uint8_t *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + 
ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_quant_f32_to_f8_tail_kernel(srcDevice, dstDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/quant-f32-to-f8-tail/ptoas.flags b/test/vpto/cases/vmi/quant-f32-to-f8-tail/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/quant-f32-to-f8-tail/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/reduce-f16-f8-mul-store/compare.py b/test/vpto/cases/vmi/reduce-f16-f8-mul-store/compare.py new file mode 100644 index 0000000000..5030420250 --- /dev/null +++ b/test/vpto/cases/vmi/reduce-f16-f8-mul-store/compare.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. 
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v3.bin", dtype=np.float32) + output = np.fromfile("v3.bin", dtype=np.float32) + if golden.shape != output.shape or not np.allclose(golden, output, atol=1e-5, rtol=1e-5): + diff = np.nonzero(~np.isclose(golden, output, atol=1e-5, rtol=1e-5))[0] + idx = int(diff[0]) if diff.size else -1 + print(f"[ERROR] compare failed idx={idx} golden={golden[idx] if idx >= 0 else 'n/a'} output={output[idx] if idx >= 0 else 'n/a'}") + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/reduce-f16-f8-mul-store/golden.py b/test/vpto/cases/vmi/reduce-f16-f8-mul-store/golden.py new file mode 100644 index 0000000000..ee2be3c731 --- /dev/null +++ b/test/vpto/cases/vmi/reduce-f16-f8-mul-store/golden.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+ +import argparse +from pathlib import Path + +import numpy as np + +ELEMS = 256 +F16_VALUE = np.float16(0.125) +VALUES = np.array([0.0, 1.0, -1.0, 0.5, 2.0, -2.0, 4.0, -4.0], dtype=np.float32) +F8E4M3FN_BYTES = np.array([0x00, 0x38, 0xB8, 0x30, 0x40, 0xC0, 0x48, 0xC8], dtype=np.uint8) +SENTINEL = np.float32(-123.25) + + +def generate(output_dir: Path) -> None: + repeats = (ELEMS + len(VALUES) - 1) // len(VALUES) + src_f16 = np.full(ELEMS, F16_VALUE, dtype=np.float16) + src_f8 = np.tile(F8E4M3FN_BYTES, repeats)[:ELEMS].astype(np.uint8) + decoded_f8 = np.tile(VALUES, repeats)[:ELEMS].astype(np.float32) + reduction = np.sum(src_f16.astype(np.float32), dtype=np.float32) + dst = np.full(ELEMS, SENTINEL, dtype=np.float32) + golden = decoded_f8 * reduction + + output_dir.mkdir(parents=True, exist_ok=True) + src_f16.tofile(output_dir / "v1.bin") + src_f8.tofile(output_dir / "v2.bin") + dst.tofile(output_dir / "v3.bin") + golden.astype(np.float32, copy=False).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/reduce-f16-f8-mul-store/kernel.pto b/test/vpto/cases/vmi/reduce-f16-f8-mul-store/kernel.pto new file mode 100644 index 0000000000..ae307ef525 --- /dev/null +++ b/test/vpto/cases/vmi/reduce-f16-f8-mul-store/kernel.pto @@ -0,0 +1,66 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_reduce_f16_f8_mul_store_kernel(%src_f16_gm: !pto.ptr, + %src_f8_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %zero = arith.constant 0.000000e+00 : f32 + %c256 = arith.constant 256 : index + + %ub_f16 = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_f8_u8 = pto.castptr %c1024_i64 : i64 -> !pto.ptr + %ub_f8 = pto.castptr %c1024_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_f16_gm, %ub_f16, %c0_i64, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %src_f8_gm, %ub_f8_u8, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %dst_gm, %ub_dst, %c0_i64, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %src_f16 = pto.vmi.load %ub_f16[%c0] : !pto.ptr -> !pto.vmi.vreg<256xf16> + %src_f16_f32 = pto.vmi.extf %src_f16 : !pto.vmi.vreg<256xf16> -> !pto.vmi.vreg<256xf32> + %init = pto.vmi.broadcast %zero : f32 -> !pto.vmi.vreg<1xf32> + %sum = pto.vmi.reduce_addf %src_f16_f32, %init, 
%mask {reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<1xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<1xf32> + %sum_vec = pto.vmi.broadcast %sum + : !pto.vmi.vreg<1xf32> -> !pto.vmi.vreg<256xf32> + %src_f8 = pto.vmi.load %ub_f8[%c0] : !pto.ptr -> !pto.vmi.vreg<256xf8E4M3FN> + %src_f8_f32 = pto.vmi.extf %src_f8 : !pto.vmi.vreg<256xf8E4M3FN> -> !pto.vmi.vreg<256xf32> + %out = pto.vmi.mulf %sum_vec, %src_f8_f32 + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + pto.vmi.store %out, %ub_dst[%c0] : !pto.vmi.vreg<256xf32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/reduce-f16-f8-mul-store/launch.cpp b/test/vpto/cases/vmi/reduce-f16-f8-mul-store/launch.cpp new file mode 100644 index 0000000000..b882f9e0e2 --- /dev/null +++ b/test/vpto/cases/vmi/reduce-f16-f8-mul-store/launch.cpp @@ -0,0 +1,43 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_reduce_f16_f8_mul_store_kernel(__gm__ half *src_f16, + __gm__ uint8_t *src_f8, + __gm__ float *dst); + +void LaunchVmi_reduce_f16_f8_mul_store_kernel(uint16_t *src_f16, + uint8_t *src_f8, float *dst, + void *stream) { + vmi_reduce_f16_f8_mul_store_kernel<<<1, nullptr, stream>>>( + (__gm__ half *)src_f16, (__gm__ uint8_t *)src_f8, (__gm__ float *)dst); +} diff --git a/test/vpto/cases/vmi/reduce-f16-f8-mul-store/main.cpp b/test/vpto/cases/vmi/reduce-f16-f8-mul-store/main.cpp new file mode 100644 index 0000000000..e48cd97661 --- /dev/null +++ b/test/vpto/cases/vmi/reduce-f16-f8-mul-store/main.cpp @@ -0,0 +1,88 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_reduce_f16_f8_mul_store_kernel(uint16_t *src_f16, + uint8_t *src_f8, float *dst, + void *stream); + +int main() { + constexpr size_t kElems = 256; + size_t srcF16Bytes = kElems * sizeof(uint16_t); + size_t srcF8Bytes = kElems * sizeof(uint8_t); + size_t dstBytes = kElems * sizeof(float); + uint16_t *srcF16Host = nullptr; + uint16_t *srcF16Device = nullptr; + uint8_t *srcF8Host = nullptr; + uint8_t *srcF8Device = nullptr; + float *dstHost = nullptr; + float *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcF16Host), srcF16Bytes)); + ACL_CHECK(aclrtMallocHost((void **)(&srcF8Host), srcF8Bytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcF16Device, srcF16Bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&srcF8Device, srcF8Bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcF16Bytes, srcF16Host, srcF16Bytes); + ReadFile("./v2.bin", srcF8Bytes, srcF8Host, srcF8Bytes); + ReadFile("./v3.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcF16Device, 
srcF16Bytes, srcF16Host, srcF16Bytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(srcF8Device, srcF8Bytes, srcF8Host, srcF8Bytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_reduce_f16_f8_mul_store_kernel(srcF16Device, srcF8Device, dstDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcF16Device); + aclrtFree(srcF8Device); + aclrtFree(dstDevice); + aclrtFreeHost(srcF16Host); + aclrtFreeHost(srcF8Host); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/reduce-f16-f8-mul-store/ptoas.flags b/test/vpto/cases/vmi/reduce-f16-f8-mul-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/reduce-f16-f8-mul-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/scf-for-loop-carried-store/compare.py b/test/vpto/cases/vmi/scf-for-loop-carried-store/compare.py new file mode 100644 index 0000000000..28299087e5 --- /dev/null +++ b/test/vpto/cases/vmi/scf-for-loop-carried-store/compare.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. 
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v2.bin", dtype=np.float32) + output = np.fromfile("v2.bin", dtype=np.float32) + if golden.shape != output.shape or not np.allclose(golden, output, atol=1e-4, rtol=1e-4): + diff = np.nonzero(~np.isclose(golden, output, atol=1e-4, rtol=1e-4))[0] + idx = int(diff[0]) if diff.size else -1 + print(f"[ERROR] compare failed idx={idx} golden={golden[idx] if idx >= 0 else 'n/a'} output={output[idx] if idx >= 0 else 'n/a'}") + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/scf-for-loop-carried-store/golden.py b/test/vpto/cases/vmi/scf-for-loop-carried-store/golden.py new file mode 100644 index 0000000000..bc9c97fdee --- /dev/null +++ b/test/vpto/cases/vmi/scf-for-loop-carried-store/golden.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+ +import argparse +from pathlib import Path + +import numpy as np + +ELEMS = 128 +SEED = 37 +SENTINEL = np.float32(-123.25) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + src = rng.uniform(-4.0, 4.0, size=ELEMS).astype(np.float16) + dst = np.full(ELEMS, SENTINEL, dtype=np.float32) + golden = src.astype(np.float32) * np.float32(4.0) + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + dst.tofile(output_dir / "v2.bin") + golden.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/scf-for-loop-carried-store/kernel.pto b/test/vpto/cases/vmi/scf-for-loop-carried-store/kernel.pto new file mode 100644 index 0000000000..3398ef3318 --- /dev/null +++ b/test/vpto/cases/vmi/scf-for-loop-carried-store/kernel.pto @@ -0,0 +1,53 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_scf_for_loop_carried_store_kernel(%src_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %packed = pto.vmi.load %ub_src[%c0] : !pto.ptr -> !pto.vmi.vreg<128xf16> + %init = pto.vmi.extf %packed : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %result = scf.for %i = %c0 to %c2 step %c1 + iter_args(%acc = %init) -> (!pto.vmi.vreg<128xf32>) { + %next = pto.vmi.addf %acc, %acc + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + scf.yield %next : !pto.vmi.vreg<128xf32> + } + pto.vmi.store %result, %ub_dst[%c0] + : !pto.vmi.vreg<128xf32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/scf-for-loop-carried-store/launch.cpp b/test/vpto/cases/vmi/scf-for-loop-carried-store/launch.cpp new file mode 100644 index 0000000000..b0902d1207 --- /dev/null +++ b/test/vpto/cases/vmi/scf-for-loop-carried-store/launch.cpp @@ -0,0 +1,32 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_scf_for_loop_carried_store_kernel(__gm__ uint16_t *src, __gm__ float *dst); + +void LaunchVmi_scf_for_loop_carried_store_kernel(uint16_t *src, float *dst, + void *stream) { + vmi_scf_for_loop_carried_store_kernel<<<1, nullptr, stream>>>( + (__gm__ uint16_t *)src, (__gm__ float *)dst); +} diff --git a/test/vpto/cases/vmi/scf-for-loop-carried-store/main.cpp b/test/vpto/cases/vmi/scf-for-loop-carried-store/main.cpp new file mode 100644 index 0000000000..f45b070260 --- /dev/null +++ b/test/vpto/cases/vmi/scf-for-loop-carried-store/main.cpp @@ -0,0 +1,78 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_scf_for_loop_carried_store_kernel(uint16_t *src, float *dst, + void *stream); + +int main() { + constexpr size_t kElems = 128; + size_t srcBytes = kElems * sizeof(uint16_t); + size_t dstBytes = kElems * sizeof(float); + uint16_t *srcHost = nullptr; + uint16_t *srcDevice = nullptr; + float *dstHost = nullptr; + float *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + 
LaunchVmi_scf_for_loop_carried_store_kernel(srcDevice, dstDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/scf-for-loop-carried-store/ptoas.flags b/test/vpto/cases/vmi/scf-for-loop-carried-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/scf-for-loop-carried-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/compare.py b/test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/compare.py new file mode 100644 index 0000000000..c964405de5 --- /dev/null +++ b/test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/compare.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+ +import sys + +import numpy as np + + +def check(name: str, atol: float, rtol: float) -> bool: + golden = np.fromfile(f"golden_{name}.bin", dtype=np.float32) + output = np.fromfile(f"{name}.bin", dtype=np.float32) + if golden.shape == output.shape and np.allclose(golden, output, atol=atol, rtol=rtol): + return True + close = np.isclose(golden, output, atol=atol, rtol=rtol) + diff = np.nonzero(~close)[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed {name} idx={idx} " + f"golden={golden[idx] if idx >= 0 else 'n/a'} " + f"output={output[idx] if idx >= 0 else 'n/a'}" + ) + return False + + +def main() -> None: + if not check("v2", 1e-4, 1e-4) or not check("v3", 0.0, 0.0): + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/golden.py b/test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/golden.py new file mode 100644 index 0000000000..b41d0e8681 --- /dev/null +++ b/test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/golden.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+ +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 8 +GROUP_SIZE = 16 +ELEMS = ROWS * GROUP_SIZE +SENTINEL = np.float32(-777.0) + + +def generate(output_dir: Path) -> None: + src = np.empty((ROWS, GROUP_SIZE), dtype=np.float16) + base = np.linspace(-0.625, 0.875, GROUP_SIZE, dtype=np.float16) + for row in range(ROWS): + src[row, :] = base + np.float16(row * 0.125) + + dense = np.full(ELEMS, SENTINEL, dtype=np.float32) + sum_out = np.full(ROWS, SENTINEL, dtype=np.float32) + golden_dense = src.astype(np.float32).reshape(-1) + golden_sum = np.empty(ROWS, dtype=np.float32) + for row in range(ROWS): + golden_sum[row] = np.sum(src[row, :].astype(np.float32), dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + src.reshape(-1).tofile(output_dir / "v1.bin") + sum_out.tofile(output_dir / "v2.bin") + dense.tofile(output_dir / "v3.bin") + golden_sum.tofile(output_dir / "golden_v2.bin") + golden_dense.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/kernel.pto b/test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/kernel.pto new file mode 100644 index 0000000000..9b926ac640 --- /dev/null +++ b/test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/kernel.pto @@ -0,0 +1,67 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_widen_f16_to_f32_store_reduce_kernel(%src_gm: !pto.ptr, + %sum_gm: !pto.ptr, + %dense_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_sum = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_dense = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %sum_gm, %ub_sum, %c0_i64, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %dense_gm, %ub_dense, %c0_i64, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %x16 = pto.vmi.load %ub_src[%c0] + : !pto.ptr -> !pto.vmi.vreg<128xf16> + %x32 = pto.vmi.extf %x16 + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %mask = pto.vmi.create_mask %c128 : index -> !pto.vmi.mask<128xpred> + %sum = pto.vmi.group_reduce_addf %x32, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<8xf32> + 
pto.vmi.group_store %sum, %ub_sum[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<8xf32>, !pto.ptr + pto.vmi.store %x32, %ub_dense[%c0] + : !pto.vmi.vreg<128xf32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_sum, %sum_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_dense, %dense_gm, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/launch.cpp b/test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/launch.cpp new file mode 100644 index 0000000000..b0ee12da2b --- /dev/null +++ b/test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/launch.cpp @@ -0,0 +1,42 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_widen_f16_to_f32_store_reduce_kernel(__gm__ half *src, __gm__ float *sum, + __gm__ float *dense); + +void LaunchVmi_widen_f16_to_f32_store_reduce_kernel(uint16_t *src, float *sum, + float *dense, + void *stream) { + vmi_widen_f16_to_f32_store_reduce_kernel<<<1, nullptr, stream>>>( + (__gm__ half *)src, (__gm__ float *)sum, (__gm__ float *)dense); +} diff --git a/test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/main.cpp b/test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/main.cpp new file mode 100644 index 0000000000..96a4a102f8 --- /dev/null +++ b/test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/main.cpp @@ -0,0 +1,92 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_widen_f16_to_f32_store_reduce_kernel(uint16_t *src, float *sum, + float *dense, void *stream); + +int main() { + constexpr size_t kSrcElems = 128; + constexpr size_t kSumElems = 8; + constexpr size_t kDenseElems = 128; + size_t srcBytes = kSrcElems * sizeof(uint16_t); + size_t sumBytes = kSumElems * sizeof(float); + size_t denseBytes = kDenseElems * sizeof(float); + uint16_t *srcHost = nullptr; + float *sumHost = nullptr; + float *denseHost = nullptr; + uint16_t *srcDevice = nullptr; + float *sumDevice = nullptr; + float *denseDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&sumHost), sumBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&denseHost), denseBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&sumDevice, sumBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&denseDevice, denseBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", sumBytes, sumHost, sumBytes); + ReadFile("./v3.bin", denseBytes, denseHost, 
denseBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(sumDevice, sumBytes, sumHost, sumBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(denseDevice, denseBytes, denseHost, denseBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_widen_f16_to_f32_store_reduce_kernel(srcDevice, sumDevice, + denseDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(sumHost, sumBytes, sumDevice, sumBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(denseHost, denseBytes, denseDevice, denseBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", sumHost, sumBytes); + WriteFile("./v3.bin", denseHost, denseBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(sumDevice); + aclrtFree(denseDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(sumHost); + aclrtFreeHost(denseHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/ptoas.flags b/test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt index 5bde9442e7..3732f72313 100644 --- a/tools/CMakeLists.txt +++ b/tools/CMakeLists.txt @@ -13,3 +13,4 @@ add_subdirectory(ptobc) add_subdirectory(ptoas) +add_subdirectory(pto-test-opt) diff --git a/tools/pto-test-opt/CMakeLists.txt b/tools/pto-test-opt/CMakeLists.txt new file mode 100644 index 0000000000..8f72f0383d --- /dev/null +++ b/tools/pto-test-opt/CMakeLists.txt @@ -0,0 +1,35 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +set(LLVM_LINK_COMPONENTS + Support +) + +add_llvm_executable(pto-test-opt + pto-test-opt.cpp +) + +target_link_libraries(pto-test-opt PRIVATE + PTOIR + PTOTransforms + MLIRMlirOptMain + MLIRIR + MLIRParser + MLIRPass + MLIRSupport + MLIRFuncDialect + MLIRArithDialect + MLIRMemRefDialect + MLIRSCFDialect + MLIRControlFlowDialect +) + +add_dependencies(pto-test-opt + PTOOpsIncGen + PTOPassesIncGen +) diff --git a/tools/pto-test-opt/pto-test-opt.cpp b/tools/pto-test-opt/pto-test-opt.cpp new file mode 100644 index 0000000000..6ec1dc70ef --- /dev/null +++ b/tools/pto-test-opt/pto-test-opt.cpp @@ -0,0 +1,35 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +//===- pto-test-opt.cpp - PTO lit pass runner -----------------------------===// +//===----------------------------------------------------------------------===// + +#include "PTO/IR/PTO.h" +#include "PTO/Transforms/Passes.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/InitAllPasses.h" +#include "mlir/Tools/mlir-opt/MlirOptMain.h" + +int main(int argc, char **argv) { + mlir::DialectRegistry registry; + registry.insert(); + + mlir::registerAllPasses(); + mlir::pto::registerPTOPasses(); + + return failed(mlir::MlirOptMain(argc, argv, "PTO lit pass runner\n", + registry)); +} diff --git a/tools/ptoas/ptoas.cpp b/tools/ptoas/ptoas.cpp index 5bb6821677..bc4a401234 100644 --- a/tools/ptoas/ptoas.cpp +++ b/tools/ptoas/ptoas.cpp @@ -8,6 +8,7 @@ #include "ptoas.h" #include "PTO/IR/PTO.h" +#include "PTO/IR/VMIUtils.h" #include "PTO/Transforms/VPTOLLVMEmitter.h" #include "PTO/Transforms/Passes.h" #include "PTO/Transforms/BufferizableOpInterfaceImpl.h" @@ -17,6 +18,8 @@ #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/SymbolTable.h" #include "mlir/InitAllDialects.h" #include "mlir/InitAllPasses.h" #include "mlir/Parser/Parser.h" @@ -33,6 +36,7 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Target/Cpp/CppEmitter.h" +#include "mlir/Transforms/Passes.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/ToolOutputFile.h" #include "llvm/Support/FileSystem.h" // [Fix] Required for OF_None @@ -433,6 +437,12 @@ static llvm::cl::opt disableInferLayout( llvm::cl::desc("Disable PTO layout inference pass (static-only)"), llvm::cl::init(false)); +static 
llvm::cl::opt enableVMI( + "enable-vmi", + llvm::cl::desc("Run the experimental VMI-to-VPTO semantic pipeline " + "(requires --pto-backend=vpto or pto.backend = \"vpto\")"), + llvm::cl::init(false)); + static llvm::cl::opt emitAddPtrTrace( "emit-addptr-trace", llvm::cl::desc("Emit addptr trace comments in generated C++ output"), @@ -1468,6 +1478,7 @@ static void prepareVPTOForEmission(PassManager &pm) { kernelModulePM.addPass(createCSEPass()); kernelModulePM.addPass(pto::createVPTOPtrNormalizePass()); kernelModulePM.addPass(pto::createVPTOPtrCastCleanupPass()); + kernelModulePM.addPass(pto::createVPTONormalizeEquivalentVcvtPass()); kernelModulePM.addPass(createReconcileUnrealizedCastsPass()); kernelModulePM.addNestedPass( createVPTOExpandWrapperOpsPass()); @@ -1585,6 +1596,188 @@ static LogicalResult runVPTOBackendPipeline(OwningOpRef &module, return success(); } +static bool containsVMIType(Type type) { + if (isa(type)) + return true; + if (auto functionType = dyn_cast(type)) { + return llvm::any_of(functionType.getInputs(), containsVMIType) || + llvm::any_of(functionType.getResults(), containsVMIType); + } + if (auto shapedType = dyn_cast(type)) + return containsVMIType(shapedType.getElementType()); + return false; +} + +static LogicalResult verifyNoPublicVMISignature(ModuleOp module) { + WalkResult result = module.walk([&](func::FuncOp func) { + if (!func.isPublic() || !containsVMIType(func.getFunctionType())) + return WalkResult::advance(); + func.emitError() + << pto::kVMIDiagLayoutContractPrefix + << "public VMI typed function requires an explicit external ABI " + "materialization plan"; + return WalkResult::interrupt(); + }); + return failure(result.wasInterrupted()); +} + +static bool containsVMIPhysicalType(Type type) { + if (isa(type)) + return true; + if (auto functionType = dyn_cast(type)) { + return llvm::any_of(functionType.getInputs(), containsVMIPhysicalType) || + llvm::any_of(functionType.getResults(), containsVMIPhysicalType); + } + return false; 
+} + +static bool isPrivatePhysicalVMIHelper(func::FuncOp func) { + return !func.isPublic() && !func.isExternal() && + func.getBody().hasOneBlock() && + containsVMIPhysicalType(func.getFunctionType()); +} + +static LogicalResult inlinePrivatePhysicalVMIHelperCall(func::CallOp call, + func::FuncOp callee) { + if (callee.isExternal()) + return call.emitOpError("callee must have a body before inlining"); + if (!callee.getBody().hasOneBlock()) + return call.emitOpError("callee must be single-block before inlining"); + + Block &entry = callee.getBody().front(); + if (entry.getNumArguments() != call.getNumOperands()) + return call.emitOpError("callee argument count mismatch during inlining"); + + auto returnOp = dyn_cast(entry.getTerminator()); + if (!returnOp) + return call.emitOpError("callee must terminate with func.return"); + if (returnOp.getNumOperands() != call.getNumResults()) + return call.emitOpError("callee return/result arity mismatch during inlining"); + + OpBuilder builder(call); + IRMapping mapping; + for (auto [arg, operand] : llvm::zip(entry.getArguments(), call.getOperands())) + mapping.map(arg, operand); + + for (Operation &op : entry.without_terminator()) { + Operation *newOp = builder.clone(op, mapping); + for (auto [oldResult, newResult] : + llvm::zip(op.getResults(), newOp->getResults())) + mapping.map(oldResult, newResult); + } + + for (auto [callResult, returnOperand] : + llvm::zip(call.getResults(), returnOp.getOperands())) + callResult.replaceAllUsesWith(mapping.lookup(returnOperand)); + + call.erase(); + return success(); +} + +static LogicalResult inlinePrivatePhysicalVMIHelpersInModule(ModuleOp module) { + bool madeProgress = true; + while (madeProgress) { + madeProgress = false; + + SmallVector calls; + module.walk([&](func::CallOp call) { calls.push_back(call); }); + + for (func::CallOp call : calls) { + if (!call || !call->getBlock()) + continue; + + func::FuncOp caller = call->getParentOfType(); + auto calleeAttr = call.getCalleeAttr(); 
+ if (!caller || !calleeAttr) + continue; + + func::FuncOp callee = + SymbolTable::lookupNearestSymbolFrom( + call, calleeAttr.getAttr()); + if (!callee || !isPrivatePhysicalVMIHelper(callee)) + continue; + if (callee == caller) + return call.emitOpError("recursive private VMI helper call cannot be " + "inlined before VPTO emission"); + + if (failed(inlinePrivatePhysicalVMIHelperCall(call, callee))) + return failure(); + madeProgress = true; + } + } + + SymbolTable symbolTable(module); + SmallVector deadFuncs; + for (func::FuncOp func : module.getOps()) { + if (!isPrivatePhysicalVMIHelper(func)) + continue; + auto uses = symbolTable.getSymbolUses(func, module); + if (uses && uses->empty()) + deadFuncs.push_back(func); + } + for (func::FuncOp func : deadFuncs) + func.erase(); + + return success(); +} + +static LogicalResult inlinePrivatePhysicalVMIHelpers(ModuleOp module) { + if (failed(inlinePrivatePhysicalVMIHelpersInModule(module))) + return failure(); + WalkResult result = module.walk([&](ModuleOp nestedModule) { + if (failed(inlinePrivatePhysicalVMIHelpersInModule(nestedModule))) + return WalkResult::interrupt(); + return WalkResult::advance(); + }); + return failure(result.wasInterrupted()); +} + +static LogicalResult runVMISemanticPipeline(OwningOpRef &module) { + if (failed(verifyNoPublicVMISignature(module.get()))) + return failure(); + + PassManager pm(module->getContext()); + pm.enableVerifier(); + pm.addPass(pto::createPTOValidateVMIIRPass()); + pm.addPass(createCanonicalizerPass()); + pm.addPass(createCSEPass()); + pm.addPass(pto::createVMIPreAssignmentCombinePass()); + pm.addPass(createCanonicalizerPass()); + pm.addPass(createCSEPass()); + pm.addPass(pto::createVMILegalizeArithSelectPass()); + pm.addPass(pto::createVMILayoutAssignmentPass()); + pm.addPass(createCanonicalizerPass()); + pm.addPass(createCSEPass()); + pm.addPass(pto::createVMILayoutRematerializePass()); + pm.addPass(createCanonicalizerPass()); + pm.addPass(createCSEPass()); + 
pm.addPass(pto::createVMILayoutFoldPass()); + pm.addPass(createCanonicalizerPass()); + pm.addPass(createCSEPass()); + pm.addPass(pto::createVMILayoutSinkMaterializationPass()); + pm.addPass(createCanonicalizerPass()); + pm.addPass(createCSEPass()); + pm.addPass(pto::createVMILegalizeArithSelectPass()); + pm.addPass(pto::createPTOValidateVMILayoutIRPass()); + pm.addPass(pto::createVMIToVPTOPass()); + pm.addPass(pto::createVPTONormalizeEquivalentVcvtPass()); + pm.addPass(createLoopInvariantCodeMotionPass()); + pm.addPass(createCanonicalizerPass()); + pm.addPass(createCSEPass()); + if (failed(applyConfiguredPassManagerCLOptions(pm, + "VMI-to-VPTO pipeline"))) + return failure(); + if (failed(pm.run(module.get()))) { + llvm::errs() << "Error: VMI-to-VPTO pipeline failed.\n"; + return failure(); + } + if (failed(inlinePrivatePhysicalVMIHelpers(module.get()))) { + llvm::errs() << "Error: failed to inline private VMI physical helpers.\n"; + return failure(); + } + return success(); +} + int mlir::pto::compilePTOASModule( OwningOpRef &module, PTOASContext &context, PTOBackend effectiveBackend, PTOASCompileResult &result, @@ -1600,6 +1793,11 @@ int mlir::pto::compilePTOASModule( "--pto-backend=vpto or pto.backend = \"vpto\".\n"; return 1; } + if (enableVMI && effectiveBackend != PTOBackend::VPTO) { + llvm::errs() << "Error: --enable-vmi requires --pto-backend=vpto or " + "pto.backend = \"vpto\".\n"; + return 1; + } PTOBuildLevel effectiveLevel = defaultBuildLevel(); if (!parseBuildLevel(ptoBuildLevel, effectiveLevel)) { @@ -1718,6 +1916,11 @@ int mlir::pto::compilePTOASModule( const bool hasTileOpsToExpand = hasUnexpandedTileOps(*module); const bool hasTilelangHelpers = hasTilelangInlineHelpers(*module); + if (enableVMI) { + if (failed(runVMISemanticPipeline(module))) + return 1; + } + if (effectiveBackend == PTOBackend::VPTO && !hasTileOpsToExpand) { if (ptoPrintSeamIR || !ptoSeamIRFile.empty()) { llvm::errs() << "Error: shared pre-backend seam IR is unavailable when "