diff --git a/ptodsl/README.md b/ptodsl/README.md index c2e034ac5..ea25c280d 100644 --- a/ptodsl/README.md +++ b/ptodsl/README.md @@ -152,6 +152,24 @@ Direct run on a real NPU: python3 ptodsl/examples/flash_attention_softmax_launch.py ``` +### `rms_norm/rmsnorm_alloc_buffer_simt.py` + +Compile-only RMSNorm example for explicit-mode SIMT kernels. It exercises +SIMT-local `pto.alloc_buffer(...)`, hand-authored dynamic UB scratch offsets, +contiguous `scalar.load` / `scalar.store`, `pto.vec`, +`pto.simt_allreduce_sum(...)`, explicit pipe `set_flag` / `wait_flag` sync, +and a runtime token loop that lowers to `scf.for`. + +```bash +python3 ptodsl/examples/rms_norm/rmsnorm_alloc_buffer_simt.py --variant x128 > /tmp/rmsnorm_x128.mlir +python3 ptodsl/examples/rms_norm/rmsnorm_alloc_buffer_simt.py --variant x64 > /tmp/rmsnorm_x64.mlir +``` + +Expected: MLIR containing `@rmsnorm_4096_alloc_buffer_simt_context_kernel`, +`scf.for`, `vector<4xf32>` for both `x128` and `x64`, and inline +`pto.redux_add` / `pto.syncthreads` allreduce ops. The main token loop should also contain dynamic +`pto.set_flag_dyn` / `pto.wait_flag_dyn` operations for the ping-pong events. 
+ ### Launch artifacts - `~/.cache/ptodsl/` — JIT-compiled kernel `.so` cache @@ -167,6 +185,7 @@ python3 ptodsl/tests/test_jit_compile.py python3 ptodsl/tests/test_jit_diagnostics.py python3 ptodsl/tests/test_subkernel_diagnostics.py python3 ptodsl/tests/test_flash_attention_demo_compile.py +python3 ptodsl/tests/test_rmsnorm_example_compile.py python3 ptodsl/tests/test_ptoas_frontend_verify.py python3 ptodsl/tests/test_docs_as_test.py ``` @@ -178,6 +197,7 @@ ptodsl_jit_compile: PASS ptodsl_jit_diagnostics: PASS ptodsl_subkernel_diagnostics: PASS ptodsl_flash_attention_demo_compile: PASS +ptodsl_rmsnorm_example_compile: PASS ptodsl_ptoas_frontend_verify: PASS ptodsl_docs_as_test: PASS ``` diff --git a/ptodsl/docs/user_guide/01-introduction.md b/ptodsl/docs/user_guide/01-introduction.md index 12eba5ce5..d7d29e885 100644 --- a/ptodsl/docs/user_guide/01-introduction.md +++ b/ptodsl/docs/user_guide/01-introduction.md @@ -257,7 +257,9 @@ These are hardware-bound compute sub-kernels, each mapped to a specific NPU comp Each can be invoked as a named decorated function (`@pto.cube` / `@pto.simd` / `@pto.simt`) or inline as a context manager -(`with pto.cube():`, `with pto.simd():`, `with pto.simt():`). +(`with pto.cube():`, `with pto.simd():`, `with pto.simt():`). Inline SIMT +scopes can also spell launch dimensions directly with +`with pto.simt(dim_x, dim_y, dim_z):`. The boundary contract is strict: vreg values do not escape a simd kernel, cube-local state does not leak into UB, and data crosses layer boundaries only through UB-backed tiles or typed UB pointers. diff --git a/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md b/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md index 7bb02647c..b588e7dad 100644 --- a/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md +++ b/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md @@ -736,8 +736,9 @@ two ways: 1. 
**As decorated functions** — reusable, named sub-kernels called from `@pto.jit` entries and modules. -2. **As context managers** (`with pto.cube():`, etc.) — inline blocks for - one-off snippets (see Section 3.8). +2. **As context managers** (`with pto.cube():`, `with pto.simd():`, + `with pto.simt():`, and `with pto.simt(dim_x, dim_y, dim_z):`) — inline + blocks for one-off snippets (see Section 3.8). Named sub-kernel decorators use the same default AST rewrite model as `@pto.jit`: supported Python `if` and `for range(...)` statements lower to @@ -997,10 +998,13 @@ Specific SIMT micro-op APIs are documented in Chapter 13. ## 3.8 Inline context manager syntax -In addition to the decorator form, each sub-kernel unit provides a context -manager: `with pto.cube():`, `with pto.simd():`, and `with pto.simt():`. These -open one-off anonymous sub-kernel bodies without requiring a separate named -Python function. Inline scopes are supported in top-level `@pto.jit` bodies. +In addition to the decorator form, each sub-kernel unit provides an inline +context manager form: `with pto.cube():`, `with pto.simd():`, +`with pto.simt():`, and `with pto.simt(dim_x, dim_y, dim_z):`. These open +one-off anonymous sub-kernel bodies without requiring a separate named Python +function. Inline scopes are supported in top-level `@pto.jit` bodies. The +dimensioned SIMT form uses the same inline body style while making the caller +emit an explicit `pto.simt_launch`. ### Syntax @@ -1022,6 +1026,12 @@ with pto.simt(): scalar.store(o_next, o_next_tile[row, col]) ``` +```python +with pto.simt(128, 1, 1): + tid = pto.get_tid_x() + scalar.store(tid, scratch_ub, scalar.index_cast(tid)) +``` + ```python with pto.cube(): @@ -1041,6 +1051,9 @@ with pto.cube(): / `pto.section.cube` bodies inside the outlined helper. - `with pto.simt():` preserves its scalar body inside one outlined `pto.simt_entry` helper, and the caller emits `pto.store_vfsimt_info`. 
+- `with pto.simt(dim_x, dim_y, dim_z):` uses the same inline outlining and + automatic capture rules, but emits a caller-side explicit SIMT launch with + the authored dimensions. - Values defined inside the inline sub-kernel cannot escape the block directly. Use Tiles, typed pointers, or other mutable references to communicate results back to the caller. diff --git a/ptodsl/docs/user_guide/04-type-system-and-buffer.md b/ptodsl/docs/user_guide/04-type-system-and-buffer.md index 63dd8368a..3c0506bb7 100644 --- a/ptodsl/docs/user_guide/04-type-system-and-buffer.md +++ b/ptodsl/docs/user_guide/04-type-system-and-buffer.md @@ -175,7 +175,31 @@ ptr_ub = pto.ptr(pto.f16, pto.MemorySpace.UB) | `MemorySpace.ACC` | Cube L0C accumulator buffer | | `MemorySpace.BIAS` | Cube bias table buffer | -## 4.5 TensorView +## 4.5 Explicit scratch buffers + +Allocate SIMT lane-local scratch storage for pointer-style load and store +operations inside a SIMT helper. + +```text +pto.alloc_buffer(shape, dtype) +``` + + +```python +scratch = pto.alloc_buffer((32,), pto.f32) +``` + +| Parameter | Description | +|-----------|-------------| +| `shape` | Static positive integer shape. Pass an `int`, `tuple[int, ...]`, or `list[int]`. | +| `dtype` | Element type of the returned buffer, such as `pto.f32` or `pto.i32`. | + +The returned pointer names a local allocation in the SIMT helper invocation +that allocates it. Use this for per-workitem temporary fragments, scalar +scratch values, or staged values that are accessed through pointer-style loads +and stores. + +## 4.6 TensorView `TensorView` is a descriptor for a tensor in Global Memory. Create one inside a `@pto.jit` body with `make_tensor_view`: @@ -205,7 +229,7 @@ def kernel( Strides support non-contiguous tensors. Pass `strides=A.strides` from the source tensor for the default row-major layout, or supply explicit strides for sub-views. Use `tv.as_ptr()` to obtain a typed GM pointer for use with MTE Ops in explicit-mode orchestration. 
-## 4.6 PartitionTensorView +## 4.7 PartitionTensorView `partition_view` creates a sub-view of a TensorView at a given offset and size. It describes *which part* of the GM tensor a `tile.load` or `tile.store` should operate on: @@ -216,7 +240,7 @@ part = pto.partition_view(tv, offsets=[row_offset, 0], sizes=[BLOCK, dim]) The result is a `PartitionTensorView` — a lightweight descriptor, not a data buffer. It carries the partition's shape, strides, and element type (inherited from the source TensorView). Use `part.as_ptr()` to obtain a typed GM pointer for MTE Ops in explicit-mode orchestration. -## 4.7 Tile +## 4.8 Tile A `Tile` is an on-chip buffer allocated in UB or cube-local memory. Allocate tiles with `alloc_tile`: diff --git a/ptodsl/docs/user_guide/06-scalar-and-pointer-ops.md b/ptodsl/docs/user_guide/06-scalar-and-pointer-ops.md index 7544591c0..4db3e5acb 100644 --- a/ptodsl/docs/user_guide/06-scalar-and-pointer-ops.md +++ b/ptodsl/docs/user_guide/06-scalar-and-pointer-ops.md @@ -35,7 +35,9 @@ When in doubt, ask: *can this value change between launches of the same compiled ## 6.2 Scalar access: load and store -`scalar.load` reads a single scalar element from a typed pointer or tile location. `scalar.store` writes a scalar back. These are the canonical scalar memory ops for SIMT authoring. The offset is counted in elements, not bytes. +`scalar.load` reads one scalar element from a typed pointer or tile location. +`scalar.store` writes one scalar element back. These are the canonical scalar +memory ops for SIMT authoring. Offsets are counted in elements, not bytes. #### `scalar.load(ptr: PtrType, offset: Index) -> ScalarType` @@ -101,6 +103,67 @@ scalar.store(value, tile[row, col]) scalar.store(value, ptr, offset) ``` +### Contiguous vector access + +Use `contiguous=N` when a single work-item should read or write `N` adjacent +elements as one vector value. `N` must be a positive Python integer greater than +`1`. 
+ +#### `scalar.load(ptr: PtrType, offset: Index, *, contiguous: int) -> VecValue` + +**Description**: Loads `contiguous` adjacent elements from a typed pointer. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `ptr` | `PtrType` | Typed source pointer | +| `offset` | `Index` | First element to load | +| `contiguous` | Positive Python `int` greater than `1` | Number of adjacent elements to load | + +**Returns**: + +| Return Value | Type | Description | +|--------------|------|-------------| +| `value` | `pto.vec(T, N)` | Vector value with `N == contiguous` and element type `T` | + +**Example**: + +```python +x4 = scalar.load(ptr, offset, contiguous=4) +``` + +For a `pto.ptr(pto.f32, "ub")`, this produces a DSL vector value with type +`pto.vec(pto.f32, 4)`. + +--- + +#### `scalar.store(value: VecValue, ptr: PtrType, offset: Index, *, contiguous: int | None = None) -> None` + +**Description**: Stores a vector value to adjacent elements of a typed pointer. +The store width is taken from the vector lane count. If `contiguous` is +provided, it must match that lane count. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `value` | `pto.vec(T, N)` | Vector value to write | +| `ptr` | `PtrType` | Typed destination pointer | +| `offset` | `Index` | First element to store | +| `contiguous` | `int` or `None` | Optional width check; when provided, it must equal `N` | + +**Example**: + +```python +scalar.store(x4, ptr, offset) +scalar.store(x4, ptr, offset, contiguous=4) # optional width check +``` + +`scalar.store(scalar_value, ptr, offset, contiguous=N)` is rejected because +scalar values are not implicitly broadcast for vector stores. To build an +explicit broadcast vector, use `pto.vec(...)`; see Section 8.4. + ### Scalar value adaptation `scalar.store` adapts the authored `value` to the destination element type. 
diff --git a/ptodsl/docs/user_guide/08-compute-operations.md b/ptodsl/docs/user_guide/08-compute-operations.md index ac4e32a3a..5c14332ca 100644 --- a/ptodsl/docs/user_guide/08-compute-operations.md +++ b/ptodsl/docs/user_guide/08-compute-operations.md @@ -1864,3 +1864,43 @@ The `mte_l1_l0a`/`mte_l1_l0b` stage operands from the authored source tiles into | `pto.mad_mx_bias(lhs, rhs, dst, bias, m, n, k, **clauses)` | MX-format bias-init matmul | MX variants require MX-enabled dtypes (f8) and pre-loaded scale payloads. For most users, the standard `mad`, `mad_acc`, and `mad_bias` are the primary interface. + +--- + +## 8.4 Builtin vector values + +Builtin vector values are small fixed-lane vectors used by contiguous scalar +accesses and element-wise vector expressions. They are distinct from the +`VRegType` values used inside `@pto.simd` kernels. + +#### `pto.vec(dtype, lanes, *, init=None)` + +**Description**: Names a builtin vector type. When `init` is provided, +constructs a vector value. A scalar initializer is broadcast to every lane. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `dtype` | PTO dtype | Element type, such as `pto.f32` | +| `lanes` | Positive Python `int` | Number of lanes | +| `init` | Scalar value, vector value, or `None` | Optional initializer; scalar values are broadcast to all lanes | + +**Returns**: + +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | Vector type or `pto.vec(dtype, lanes)` value | Without `init`, returns a vector type descriptor; with `init`, returns a vector value | + +**Example**: + + +```python +x4 = scalar.load(ptr, offset, contiguous=4) +rstd4 = pto.vec(pto.f32, 4, init=rstd) +y4 = x4 * rstd4 +scalar.store(y4, ptr, offset) +``` + +Use this form when a scalar value must participate in element-wise arithmetic +with a vector value returned by `scalar.load(..., contiguous=N)`. 
diff --git a/ptodsl/docs/user_guide/13-simt-micro-ops.md b/ptodsl/docs/user_guide/13-simt-micro-ops.md index 928cffeba..18eb19601 100644 --- a/ptodsl/docs/user_guide/13-simt-micro-ops.md +++ b/ptodsl/docs/user_guide/13-simt-micro-ops.md @@ -10,8 +10,8 @@ scalar values loaded from tiles. #### `pto.store_vfsimt_info(dim_z, dim_y, dim_x) -> None` **Description**: Emits the low-level VPTO launch descriptor operation. Most -code should use `body[dim_x, dim_y, dim_z](...)` or `pto.simt_launch(...)` -instead. +code should use `body[dim_x, dim_y, dim_z](...)`, `pto.simt_launch(...)`, or +the inline form `with pto.simt(dim_x, dim_y, dim_z):` instead. **Parameters**: diff --git a/ptodsl/examples/rms_norm/rmsnorm_alloc_buffer_simt.py b/ptodsl/examples/rms_norm/rmsnorm_alloc_buffer_simt.py new file mode 100644 index 000000000..eae64dde0 --- /dev/null +++ b/ptodsl/examples/rms_norm/rmsnorm_alloc_buffer_simt.py @@ -0,0 +1,236 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +""" +RMSNorm compile-only PTODSL example for issue 483. 
+ +The example exercises the PTODSL surfaces needed by the RMSNorm SimtVF kernel: + +- ``pto.alloc_buffer(...)`` for lane-local SIMT fragment storage +- hand-authored dynamic UB scratch layout via ``pto.castptr`` / ``pto.addptr`` +- contiguous scalar ``load`` / ``store`` vector accesses +- ``pto.simt_allreduce_sum(...)`` for cross-workitem sum reduction +- W stays in UB after the GM->UB preload and is read directly by the token SIMT body +- runtime ``range(...)`` for the token loop so the AST rewrite emits ``scf.for`` +- Python ``range(...)`` loops inside SIMT helpers to emit compact runtime loops + +Run this file directly to print the emitted MLIR for one specialization. +""" + +import argparse +from pathlib import Path +import sys + + +if __package__ in {None, ""}: + here = Path(__file__).resolve() + for candidate in here.parents: + if (candidate / "ptodsl" / "__init__.py").exists(): + sys.path.insert(0, str(candidate)) + break + else: + raise RuntimeError( + "Unable to locate the PTODSL Python package root from rmsnorm_alloc_buffer_simt.py" + ) + + +from ptodsl import pto, scalar + + +@pto.simt +def rmsnorm_simt_token_body( + x_ub, + y_ub, + rstd_ub, + reduce_scratch, + w_ub, + eps: pto.f32, + pingpong: pto.i32, + *, + threads: pto.const_expr = 128, + rounds: pto.const_expr = 16, + lanes: pto.const_expr = 2, + hidden_size: pto.const_expr = 4096, +): + tx = pto.get_tid_x() + frag_elems: pto.const_expr = rounds * lanes + x_frag = pto.alloc_buffer((frag_elems,), pto.f32) + sum_sq = pto.alloc_buffer((1,), pto.f32) + + for r in range(0, rounds): + lane_offset = r * threads * lanes + tx * lanes + x_offset = pingpong * hidden_size + lane_offset + frag_offset = r * lanes + + x_vec = scalar.load(x_ub, x_offset, contiguous=lanes) + scalar.store(x_vec, x_frag, frag_offset) + + scalar.store(pto.const(0.0, dtype=pto.f32), sum_sq, 0) + + for i in range(0, frag_elems): + local_sum = scalar.load(sum_sq, 0) + x = scalar.load(x_frag, i) + local_sum = local_sum + x * x + 
scalar.store(local_sum, sum_sq, 0) + + local_sum = scalar.load(sum_sq, 0) + + sum_sq = pto.simt_allreduce_sum( + local_sum, + threads=threads, + scale=1, + thread_offset=0, + scratch=reduce_scratch, + ) + + rstd = 1.0 / pto.sqrt(sum_sq / hidden_size + eps) + + scalar.store(rstd, rstd_ub, pingpong * 8) + + for r in range(0, rounds): + round_offset = r * threads * lanes + thread_offset = tx * lanes + lane_base = round_offset + thread_offset + y_offset = pingpong * hidden_size + lane_base + frag_offset = r * lanes + + x_vec = scalar.load(x_frag, frag_offset, contiguous=lanes) + w_vec = scalar.load(w_ub, lane_base, contiguous=lanes) + rstd_vec = pto.vec(pto.f32, lanes, init=rstd) + y_vec = x_vec * rstd_vec * w_vec + scalar.store(y_vec, y_ub, y_offset) + + +@pto.jit(target="a5", mode="explicit", dyn_shared_memory_buf=82496) +def rmsnorm_4096_alloc_buffer_simt_context_kernel( + X: pto.ptr(pto.f32, "gm"), + Y: pto.ptr(pto.f32, "gm"), + W: pto.ptr(pto.f32, "gm"), + RSTD: pto.ptr(pto.f32, "gm"), + eps: pto.f32, + *, + threads: pto.const_expr = 128, + rounds: pto.const_expr = 8, + lanes: pto.const_expr = 4, + hidden_size: pto.const_expr = 4096, + n_cores: pto.const_expr = 64, + tokens_per_core: pto.const_expr = 64, + f32_bytes: pto.const_expr = 4, +): + assert threads * rounds * lanes == hidden_size, ( + "threads * rounds * lanes must equal hidden_size for RMSNorm SIMT partitioning" + ) + + core_id = pto.get_block_idx() + + ub_base = pto.castptr(pto.const(0, dtype=pto.ui64), pto.ptr(pto.f32, "ub")) + w_ub = pto.addptr(ub_base, 0) + reduce_scratch = pto.addptr(ub_base, hidden_size) + x_ub = pto.addptr(ub_base, hidden_size + 128) + y_ub = pto.addptr(ub_base, hidden_size + 128 + 2 * hidden_size) + rstd_ub = pto.addptr(ub_base, hidden_size + 128 + 4 * hidden_size) + + pto.mte_gm_ub( + W, + w_ub, + 0, + hidden_size * f32_bytes, + nburst=(1, hidden_size * f32_bytes, hidden_size * f32_bytes), + ) + pto.set_flag("MTE2", "V", event_id=3) + pto.wait_flag("MTE2", "V", event_id=3) + + 
pto.set_flag("V", "MTE2", event_id=0) + pto.set_flag("MTE3", "V", event_id=0) + pto.set_flag("V", "MTE2", event_id=1) + pto.set_flag("MTE3", "V", event_id=1) + + for local_token in range(0, tokens_per_core): + token_id = local_token * n_cores + core_id + pingpong = local_token % 2 + + pto.wait_flag("V", "MTE2", event_id=pingpong) + pto.mte_gm_ub( + pto.addptr(X, token_id * hidden_size), + pto.addptr(x_ub, pingpong * hidden_size), + 0, + hidden_size * f32_bytes, + nburst=(1, hidden_size * f32_bytes, hidden_size * f32_bytes), + ) + pto.set_flag("MTE2", "V", event_id=pingpong) + + pto.wait_flag("MTE2", "V", event_id=pingpong) + pto.wait_flag("MTE3", "V", event_id=pingpong) + rmsnorm_simt_token_body[threads, 1, 1]( + x_ub, + y_ub, + rstd_ub, + reduce_scratch, + w_ub, + eps, + pingpong, + threads=threads, + rounds=rounds, + lanes=lanes, + hidden_size=hidden_size, + ) + pto.set_flag("V", "MTE2", event_id=pingpong) + pto.set_flag("V", "MTE3", event_id=pingpong) + + pto.wait_flag("V", "MTE3", event_id=pingpong) + pto.mte_ub_gm( + pto.addptr(y_ub, pingpong * hidden_size), + pto.addptr(Y, token_id * hidden_size), + hidden_size * f32_bytes, + nburst=(1, hidden_size * f32_bytes, hidden_size * f32_bytes), + ) + + pto.mte_ub_gm( + pto.addptr(rstd_ub, pingpong * 8), + pto.addptr(RSTD, token_id), + f32_bytes, + nburst=(1, f32_bytes, f32_bytes), + ) + pto.set_flag("MTE3", "V", event_id=pingpong) + + pto.wait_flag("V", "MTE2", event_id=0) + pto.wait_flag("V", "MTE2", event_id=1) + pto.wait_flag("MTE3", "V", event_id=0) + pto.wait_flag("MTE3", "V", event_id=1) + + +def build_x128(): + return rmsnorm_4096_alloc_buffer_simt_context_kernel.compile( + threads=128, + rounds=8, + lanes=4, + tokens_per_core=64, + ) + + +def build_x64(): + return rmsnorm_4096_alloc_buffer_simt_context_kernel.compile( + threads=64, + rounds=16, + lanes=4, + tokens_per_core=64, + ) + + +def main() -> None: + parser = argparse.ArgumentParser(description="Emit RMSNorm PTODSL MLIR") + 
parser.add_argument("--variant", choices=("x128", "x64"), default="x128") + args = parser.parse_args() + + compiled = build_x128() if args.variant == "x128" else build_x64() + compiled.verify() + print(compiled.mlir_text()) + + +if __name__ == "__main__": + main() diff --git a/ptodsl/examples/rms_norm/rmsnorm_alloc_buffer_simt_launch_common.py b/ptodsl/examples/rms_norm/rmsnorm_alloc_buffer_simt_launch_common.py new file mode 100644 index 000000000..e7ca26dcc --- /dev/null +++ b/ptodsl/examples/rms_norm/rmsnorm_alloc_buffer_simt_launch_common.py @@ -0,0 +1,120 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+ +"""Shared host-side setup for the RMSNorm alloc_buffer/SIMT launch examples.""" + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +import sys + +import numpy as np + + +if __package__ in {None, ""}: + here = Path(__file__).resolve() + for candidate in here.parents: + if (candidate / "ptodsl" / "__init__.py").exists(): + sys.path.insert(0, str(candidate)) + break + else: + raise RuntimeError( + "Unable to locate the PTODSL Python package root from " + "rmsnorm_alloc_buffer_simt_launch_common.py" + ) + + +from rmsnorm_alloc_buffer_simt import rmsnorm_4096_alloc_buffer_simt_context_kernel + + +_DEVICE = "npu:0" +_HIDDEN_SIZE = 4096 +_THREADS = 128 +_ROUNDS = 8 +_LANES = 4 +_EPS = np.float32(1.0e-6) +_Y_GUARD_ELEMS = 1024 +_RSTD_GUARD_ELEMS = 64 +_SENTINEL = np.float32(123456.0) + + +@dataclass(frozen=True) +class Case: + name: str + n_cores: int + tokens_per_core: int + seed: int + rtol: float = 1.0e-4 + y_atol: float = 1.0e-4 + rstd_atol: float = 1.0e-5 + + @property + def tokens(self) -> int: + return self.n_cores * self.tokens_per_core + + +CASES = [ + Case("one_core_one_token", n_cores=1, tokens_per_core=1, seed=0x483001), + Case("one_core_four_tokens", n_cores=1, tokens_per_core=4, seed=0x483004), + Case("four_cores_two_tokens_each", n_cores=4, tokens_per_core=2, seed=0x483402), +] + +FULL_CASE = Case("full_64_cores_64_tokens_each", n_cores=64, tokens_per_core=64, seed=0x483640) + + +def init_runtime(): + import torch + import torch_npu # noqa: F401 + + torch.npu.config.allow_internal_format = False + torch_npu.npu.set_compile_mode(jit_compile=False) + torch.npu.set_device(_DEVICE) + return torch + + +def npu_stream(torch): + return torch.npu.current_stream()._as_parameter_ # noqa: SLF001 + + +def make_inputs(case: Case) -> tuple[np.ndarray, np.ndarray]: + rng = np.random.RandomState(case.seed) + x = rng.uniform(-0.75, 0.75, size=(case.tokens, _HIDDEN_SIZE)).astype(np.float32) + w = rng.uniform(0.5, 1.5, 
size=(_HIDDEN_SIZE,)).astype(np.float32) + + # Make token/core addressing mistakes obvious in the output comparison. + token_offsets = (np.arange(case.tokens, dtype=np.float32)[:, None] * np.float32(0.001)) + x = (x + token_offsets).astype(np.float32) + return x, w + + +def rmsnorm_reference(x: np.ndarray, w: np.ndarray, eps: np.float32) -> tuple[np.ndarray, np.ndarray]: + sum_sq = np.sum(x * x, axis=1, dtype=np.float32) + rstd = (np.float32(1.0) / np.sqrt(sum_sq / np.float32(x.shape[1]) + eps)).astype(np.float32) + y = (x * rstd[:, None] * w[None, :]).astype(np.float32) + return y, rstd + + +def compile_kernel(case: Case): + return rmsnorm_4096_alloc_buffer_simt_context_kernel.compile( + threads=_THREADS, + rounds=_ROUNDS, + lanes=_LANES, + hidden_size=_HIDDEN_SIZE, + n_cores=case.n_cores, + tokens_per_core=case.tokens_per_core, + ) + + +def assert_guard_unchanged(name: str, guard: np.ndarray) -> None: + if not np.all(guard == _SENTINEL): + bad = np.nonzero(guard != _SENTINEL)[0] + first = int(bad[0]) + raise AssertionError( + f"{name} guard overwritten at guard index {first}: got {guard[first]!r}, expected {_SENTINEL!r}" + ) diff --git a/ptodsl/examples/rms_norm/rmsnorm_alloc_buffer_simt_manual_launch.py b/ptodsl/examples/rms_norm/rmsnorm_alloc_buffer_simt_manual_launch.py new file mode 100644 index 000000000..eeb1a8719 --- /dev/null +++ b/ptodsl/examples/rms_norm/rmsnorm_alloc_buffer_simt_manual_launch.py @@ -0,0 +1,292 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. 
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +""" +Launch and validate the RMSNorm alloc_buffer/SIMT example with a hand-written +host wrapper that passes dynamic UB bytes explicitly. + +This is intentionally a bypass of PTODSL's ``compiled[grid, stream](...)`` +runtime launch path. The PTODSL kernel is still compiled to MLIR, then this +script builds a custom ``launch.cpp`` containing: + + kernel<<>>(...) + +Use it to validate the kernel in environments where the generated PTODSL +runtime wrapper does not yet consume ``dyn_shared_memory_buf``. +""" + +from __future__ import annotations + +import argparse +import ctypes +import hashlib +from pathlib import Path +import sys +import time + +import numpy as np + + +if __package__ in {None, ""}: + here = Path(__file__).resolve() + for candidate in here.parents: + if (candidate / "ptodsl" / "__init__.py").exists(): + sys.path.insert(0, str(candidate)) + break + else: + raise RuntimeError( + "Unable to locate the PTODSL Python package root from " + "rmsnorm_alloc_buffer_simt_manual_launch.py" + ) + + +from ptodsl._runtime.native_build import ( # noqa: E402 + _compile_launch_cpp, + _effective_insert_sync, + _link_shared_library, + _run_ptoas, +) + +import rmsnorm_alloc_buffer_simt_launch_common as launch_common # noqa: E402 +from rmsnorm_alloc_buffer_simt_launch_common import ( # noqa: E402 + _DEVICE, + _EPS, + _HIDDEN_SIZE, + _RSTD_GUARD_ELEMS, + _SENTINEL, + _Y_GUARD_ELEMS, + CASES, + FULL_CASE, + Case, + assert_guard_unchanged, + compile_kernel, + init_runtime, + make_inputs, + npu_stream, + rmsnorm_reference, +) + + +_DYN_SHARED_BYTES = 82496 + + +def _manual_launch_cpp(*, ir_function_name: str, launch_symbol: str, dyn_shared_bytes: int) -> str: + return 
f"""#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +extern "C" __global__ AICORE void {ir_function_name}( + __gm__ float *X, + __gm__ float *Y, + __gm__ float *W, + __gm__ float *RSTD, + float eps); + +extern "C" void {launch_symbol}( + uint32_t grid, + void *stream, + float *X, + float *Y, + float *W, + float *RSTD, + float eps) {{ + constexpr uint32_t dynSharedBytes = {int(dyn_shared_bytes)}; + {ir_function_name}<<>>( + (__gm__ float *)X, + (__gm__ float *)Y, + (__gm__ float *)W, + (__gm__ float *)RSTD, + eps); +}} +""" + + +def _manual_cache_dir(compiled, launch_cpp_text: str) -> Path: + payload = "\n".join([ + compiled.mlir_text(), + launch_cpp_text, + repr(compiled.specialization_key), + ]).encode("utf-8") + digest = hashlib.sha256(payload).hexdigest()[:16] + return Path.home() / ".cache" / "ptodsl" / f"{compiled._py_name}_manual_dynub_{digest}" + + +def build_manual_library(compiled, *, dyn_shared_bytes: int = _DYN_SHARED_BYTES) -> tuple[Path, str]: + module_spec = compiled._module_spec + ir_function_name = module_spec.function_name + launch_symbol = f"ptodsl_manual_launch_{ir_function_name}" + + declared_dyn_shared = getattr(module_spec, "dyn_shared_memory_buf", None) + if declared_dyn_shared != dyn_shared_bytes: + raise RuntimeError( + f"expected @pto.jit dyn_shared_memory_buf={dyn_shared_bytes}, " + f"got {declared_dyn_shared!r}" + ) + + launch_cpp_text = _manual_launch_cpp( + ir_function_name=ir_function_name, + launch_symbol=launch_symbol, + dyn_shared_bytes=dyn_shared_bytes, + ) + cache_dir = _manual_cache_dir(compiled, launch_cpp_text) + mlir_path = cache_dir / "kernel.mlir" + kernel_object = cache_dir / "kernel.o" + launch_cpp = cache_dir / "manual_launch.cpp" + launch_object = cache_dir / "manual_launch.o" + shared_library = cache_dir / f"lib{ir_function_name}_manual_dynub.so" + + if shared_library.is_file(): + return shared_library, launch_symbol + + cache_dir.mkdir(parents=True, exist_ok=True) + 
mlir_path.write_text(compiled.mlir_text(), encoding="utf-8") + launch_cpp.write_text(launch_cpp_text, encoding="utf-8") + + _run_ptoas( + mlir_path, + kernel_object, + target_arch=module_spec.target_arch, + insert_sync=_effective_insert_sync( + mode=module_spec.mode, + insert_sync=module_spec.insert_sync, + ), + ) + _compile_launch_cpp( + launch_cpp, + launch_object, + kernel_kind=module_spec.kernel_kind, + export_macro=f"{ir_function_name}_EXPORTS", + ) + _link_shared_library( + launch_object, + kernel_object, + shared_library, + kernel_kind=module_spec.kernel_kind, + ) + return shared_library, launch_symbol + + +def _manual_launch(compiled, *, grid: int, stream, x_ptr: int, y_ptr: int, w_ptr: int, rstd_ptr: int, eps: float): + lib_path, launch_symbol = build_manual_library(compiled) + lib = ctypes.CDLL(str(lib_path)) + launch = getattr(lib, launch_symbol) + launch.argtypes = [ + ctypes.c_uint32, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_float, + ] + launch.restype = None + launch( + ctypes.c_uint32(grid), + ctypes.c_void_p(int(getattr(stream, "value", stream))), + ctypes.c_void_p(x_ptr), + ctypes.c_void_p(y_ptr), + ctypes.c_void_p(w_ptr), + ctypes.c_void_p(rstd_ptr), + ctypes.c_float(eps), + ) + + +def run_case_manual(case: Case, torch) -> None: + x, w = make_inputs(case) + y_ref, rstd_ref = rmsnorm_reference(x, w, _EPS) + + x_t = torch.from_numpy(x).to(_DEVICE) + w_t = torch.from_numpy(w).to(_DEVICE) + + y_storage = torch.full( + (case.tokens * _HIDDEN_SIZE + _Y_GUARD_ELEMS,), + float(_SENTINEL), + dtype=torch.float32, + device=_DEVICE, + ) + rstd_storage = torch.full( + (case.tokens + _RSTD_GUARD_ELEMS,), + float(_SENTINEL), + dtype=torch.float32, + device=_DEVICE, + ) + + stream = npu_stream(torch) + + t0 = time.perf_counter() + compiled = compile_kernel(case) + compile_s = time.perf_counter() - t0 + + t0 = time.perf_counter() + _manual_launch( + compiled, + grid=case.n_cores, + stream=stream, + 
x_ptr=x_t.data_ptr(), + y_ptr=y_storage.data_ptr(), + w_ptr=w_t.data_ptr(), + rstd_ptr=rstd_storage.data_ptr(), + eps=float(_EPS), + ) + torch.npu.synchronize() + launch_s = time.perf_counter() - t0 + + y_out = y_storage[: case.tokens * _HIDDEN_SIZE].cpu().numpy().reshape(case.tokens, _HIDDEN_SIZE) + rstd_out = rstd_storage[: case.tokens].cpu().numpy() + y_guard = y_storage[case.tokens * _HIDDEN_SIZE :].cpu().numpy() + rstd_guard = rstd_storage[case.tokens :].cpu().numpy() + + np.testing.assert_allclose(rstd_out, rstd_ref, rtol=case.rtol, atol=case.rstd_atol) + np.testing.assert_allclose(y_out, y_ref, rtol=case.rtol, atol=case.y_atol) + assert_guard_unchanged("Y", y_guard) + assert_guard_unchanged("RSTD", rstd_guard) + + y_diff = float(np.max(np.abs(y_out - y_ref))) if y_out.size else 0.0 + rstd_diff = float(np.max(np.abs(rstd_out - rstd_ref))) if rstd_out.size else 0.0 + simt_config = getattr(case, "simt_config", "threads=128 rounds=8 lanes=4") + print( + f"PASS {case.name} manual-dynub " + f"grid={case.n_cores} tokens={case.tokens} {simt_config} " + f"dynSharedBytes={_DYN_SHARED_BYTES} " + f"compile={compile_s:.3f}s launch={launch_s:.3f}s " + f"max|Y|={y_diff:.3e} max|RSTD|={rstd_diff:.3e}" + ) + + +def main(argv=None) -> int: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--device", default=_DEVICE, help="torch NPU device, default: npu:0") + parser.add_argument( + "--case", + choices=[case.name for case in CASES] + [FULL_CASE.name, "all"], + default="all", + ) + parser.add_argument("--include-full", action="store_true", help="include the 64-core x 64-token full case") + args = parser.parse_args(argv) + + launch_common._DEVICE = args.device + globals()["_DEVICE"] = args.device + + selected = list(CASES) + if args.include_full: + selected.append(FULL_CASE) + if args.case != "all": + all_cases = {case.name: case for case in selected + [FULL_CASE]} + selected = [all_cases[args.case]] + + torch = init_runtime() + for case in selected: 
+ run_case_manual(case, torch) + print("All RMSNorm manual dynamic-UB cases passed.") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/ptodsl/ptodsl/_allreduce.py b/ptodsl/ptodsl/_allreduce.py new file mode 100644 index 000000000..3c6637bdb --- /dev/null +++ b/ptodsl/ptodsl/_allreduce.py @@ -0,0 +1,355 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +""" +SIMT cross-workitem all-reduce. + +All-reduce ops are emitted **inline** at the current insertion point. +Three reducer variants: ``simt_allreduce_sum``, ``simt_allreduce_max``, ``simt_allreduce_min``. + +Dispatch tree (compile-time, since *threads* / *scale* are Python ints):: + + threads <= scale → identity + threads ≤ 32, pow2(threads), pow2(scale) → warp_reduce + threads ≤ 32 → ub_reduce + threads > 32, pow2(threads), scale≤32, pow2(scale) → cross_warp_reduce + otherwise → ub_reduce (fallback) +""" + +from __future__ import annotations + +from . 
import scalar +from ._control_flow import if_, for_ +from ._ops import const as _const, get_laneid, get_tid_x, redux_add, redux_max, redux_min, shuffle_bfly, syncthreads +from ._surface_values import unwrap_surface_value +from ._types import _resolve, float16 as _f16_dtype, float32 as _f32_dtype + +from mlir.ir import F16Type, F32Type + + +# ── helpers ──────────────────────────────────────────────────────────────────── + +def _is_pow2(n: int) -> bool: + """Compile-time power-of-two check.""" + return n > 0 and (n & (n - 1)) == 0 + + +# ── reducer dispatch tables ──────────────────────────────────────────────────── + +_REDUCER_IDENTITY = { + "sum": {"f32": 0.0, "f16": 0.0}, + "max": {"f32": float("-inf"), "f16": float("-inf")}, + "min": {"f32": float("inf"), "f16": float("inf")}, +} + +_REDUCER_COMBINE = { + "sum": lambda a, b: a + b, + "max": scalar.max, + "min": scalar.min, +} + +_REDUCER_REDUX = { + "sum": redux_add, + "max": redux_max, + "min": redux_min, +} + + +# ── butterfly ────────────────────────────────────────────────────────────────── + +def _emit_butterfly(v, *, threads: int, scale: int, reducer: str): + """Unrolled butterfly shuffle reduce.""" + combine = _REDUCER_COMBINE[reducer] + cur = threads + while cur > scale: + offset = cur // 2 + v = combine(v, shuffle_bfly(v, offset)) + cur //= 2 + return v + + +# ── warp_hw_reduce ──────────────────────────────────────────────────────────── + +def _emit_warp_hw_reduce(x, *, threads: int, lane_in_warp, dtype: str, reducer: str): + """Warp-level hardware reduce with group masking.""" + redux_fn = _REDUCER_REDUX[reducer] + groups = 32 // threads + + if groups == 1: + return redux_fn(x) + + c_identity = _const( + _REDUCER_IDENTITY[reducer][dtype], + dtype=_resolve(_f32_dtype if dtype == "f32" else _f16_dtype), + ) + my_group = lane_in_warp // threads + + for g in range(groups): + in_group = my_group == g + masked = scalar.select(in_group, x, c_identity) + reduced = redux_fn(masked) + x = 
scalar.select(in_group, reduced, x) + return x + + +# ── warp_reduce ─────────────────────────────────────────────────────────────── + +def _emit_warp_reduce(x, *, + dtype, threads, scale, thread_offset, reducer): + """Single-warp all-reduce.""" + extent = threads // scale + if extent <= 1: + return x + + if thread_offset: + lane_in_warp = (get_tid_x() - thread_offset) & 31 + else: + lane_in_warp = get_laneid() + + if extent >= 16 and scale == 1: + return _emit_warp_hw_reduce( + x, threads=threads, + lane_in_warp=lane_in_warp, dtype=dtype, reducer=reducer, + ) + return _emit_butterfly(x, threads=threads, scale=scale, reducer=reducer) + + +# ── cross_warp_reduce ───────────────────────────────────────────────────────── + +def _emit_cross_warp_reduce(x, scratch, *, + dtype, threads, scale, thread_offset, reducer): + """Cross-warp all-reduce (threads > 32).""" + num_warps = threads // 32 + c_identity = _const( + _REDUCER_IDENTITY[reducer][dtype], + dtype=_resolve(_f32_dtype if dtype == "f32" else _f16_dtype), + ) + combine = _REDUCER_COMBINE[reducer] + redux_fn = _REDUCER_REDUX[reducer] + + # ── thread indexing ────────────────────────────────────────────────── + tid_x = get_tid_x() + if thread_offset: + tx = tid_x - thread_offset + wid = tx // 32 + lid = tx & 31 + else: + tx = tid_x + wid = tx // 32 + lid = get_laneid() + + # ── per-warp reduce ────────────────────────────────────────────────── + if scale == 1: + warp_val = redux_fn(x) + else: + warp_val = _emit_butterfly(x, threads=32, scale=scale, reducer=reducer) + + # ── warp leaders write partial results ─────────────────────────────── + is_writer = lid < scale + with if_(is_writer) as br: + with br.then_: + slot = wid * scale + lid + scalar.store(warp_val, scratch, scalar.index_cast(slot)) + + syncthreads() + + # ── leader warp reduces partial sums ───────────────────────────────── + is_leader_warp = tx < 32 + with if_(is_leader_warp) as br: + with br.then_: + if scale == 1: + need_load = lid < num_warps + with 
if_(need_load) as inner_br: + with inner_br.then_: + tmp = scalar.load(scratch, scalar.index_cast(lid)) + inner_br.assign(loaded=tmp) + with inner_br.else_: + inner_br.assign(loaded=c_identity) + loaded = inner_br.loaded + stage4_result = redux_fn(loaded) + elif scale * num_warps <= 32: + total = scale * num_warps + need_load = lid < total + with if_(need_load) as inner_br: + with inner_br.then_: + tmp = scalar.load(scratch, scalar.index_cast(lid)) + inner_br.assign(loaded=tmp) + with inner_br.else_: + inner_br.assign(loaded=c_identity) + loaded = inner_br.loaded + stage4_result = _emit_butterfly( + loaded, threads=total, scale=scale, reducer=reducer, + ) + else: + is_reducer = lid < scale + reduced = c_identity + my_slot = lid % scale + for w in range(num_warps): + idx_val = w * scale + my_slot + loaded_v = scalar.load(scratch, scalar.index_cast(idx_val)) + reduced = combine(reduced, loaded_v) + stage4_result = scalar.select(is_reducer, reduced, c_identity) + + br.assign(stage4_result=stage4_result) + with br.else_: + br.assign(stage4_result=c_identity) + + partial_reduced = br.stage4_result + + # ── global leader writes result ────────────────────────────────────── + is_global_leader = tx < scale + with if_(is_global_leader) as br5: + with br5.then_: + scalar.store(partial_reduced, scratch, scalar.index_cast(tx)) + + # ── broadcast ──────────────────────────────────────────────────────── + syncthreads() + result = scalar.load(scratch, scalar.index_cast(tx % scale)) + syncthreads() + + return result + + +# ── ub_reduce ───────────────────────────────────────────────────────────────── + +def _emit_ub_reduce(x, scratch, *, + dtype, threads, scale, thread_offset, reducer): + """UB-scratch all-reduce (fallback for non-pow2 or general case).""" + combine = _REDUCER_COMBINE[reducer] + + # ── thread indexing ────────────────────────────────────────────────── + tid_x = get_tid_x() + tx = (tid_x - thread_offset) if thread_offset else tid_x + group = tx // threads + lane = 
tx % threads + + # ── each lane writes x → scratch[tx] ───────────────────────────────── + scalar.store(x, scratch, scalar.index_cast(tx)) + syncthreads() + + # ── reducers sequentially combine ──────────────────────────────────── + is_reducer = lane < scale + with if_(is_reducer) as br: + with br.then_: + group_offset = group * threads + first_elem = group_offset + lane + acc = scalar.load(scratch, scalar.index_cast(first_elem)) + + carry_loop = for_(scale, threads, step=scale).carry(acc=acc) + with carry_loop: + prev = carry_loop.acc + elem = first_elem + carry_loop.iv + loaded = scalar.load(scratch, elem) + carry_loop.update(acc=combine(prev, loaded)) + acc = carry_loop.final("acc") + + br.assign(flag=acc) + with br.else_: + br.assign(flag=x) + + flag = br.flag + syncthreads() + + # ── per-class leader writes back ───────────────────────────────────── + is_leader = lane < scale + with if_(is_leader) as br5: + with br5.then_: + scalar.store(flag, scratch, scalar.index_cast(group * threads + lane)) + + # ── broadcast ──────────────────────────────────────────────────────── + syncthreads() + result = scalar.load(scratch, scalar.index_cast(group * threads + (tx % scale))) + syncthreads() + + return result + + +# ── public API ──────────────────────────────────────────────────────────────── + +def _check_params(*, threads, scale, thread_offset): + """Validate allreduce parameters (compile-time checks).""" + for name, val in (("threads", threads), ("scale", scale), + ("thread_offset", thread_offset)): + if not isinstance(val, int): + raise ValueError( + f"all_reduce: '{name}' must be a Python int, " + f"got {type(val).__name__}" + ) + if threads < 1: + raise ValueError(f"all_reduce: threads must be >= 1, got {threads}") + if scale < 1: + raise ValueError(f"all_reduce: scale must be >= 1, got {scale}") + if thread_offset < 0: + raise ValueError( + f"all_reduce: thread_offset must be >= 0, got {thread_offset}" + ) + if threads % scale != 0: + raise ValueError( + 
f"all_reduce requires threads % scale == 0; " + f"got threads={threads}, scale={scale}" + ) + + +def _simt_allreduce(value, *, threads, scale, thread_offset, scratch, reducer): + """Unified allreduce dispatch tree.""" + _check_params(threads=threads, scale=scale, thread_offset=thread_offset) + + if threads <= scale: + return value + + raw_value = unwrap_surface_value(value) + if raw_value.type == F32Type.get(): + dtype = "f32" + elif raw_value.type == F16Type.get(): + dtype = "f16" + else: + raise NotImplementedError(f"all_reduce: unsupported dtype {raw_value.type}") + + args = dict(dtype=dtype, threads=threads, scale=scale, + thread_offset=thread_offset, reducer=reducer) + + if threads <= 32 and _is_pow2(threads) and _is_pow2(scale): + return _emit_warp_reduce(value, **args) + + if scratch is None: + raise ValueError( + f"all_reduce {reducer}/{dtype}/t{threads}/s{scale}/o{thread_offset} " + "requires a UB scratch buffer" + ) + + if threads <= 32: + return _emit_ub_reduce(value, scratch, **args) + + if scale <= 32 and _is_pow2(threads) and _is_pow2(scale): + return _emit_cross_warp_reduce(value, scratch, **args) + + return _emit_ub_reduce(value, scratch, **args) + + +def simt_allreduce_sum(value, *, threads, scale=1, thread_offset=0, scratch=None): + """Sum reduce across SIMT work-items.""" + return _simt_allreduce(value, threads=threads, scale=scale, + thread_offset=thread_offset, scratch=scratch, reducer="sum") + + +def simt_allreduce_max(value, *, threads, scale=1, thread_offset=0, scratch=None): + """Max reduce across SIMT work-items.""" + return _simt_allreduce(value, threads=threads, scale=scale, + thread_offset=thread_offset, scratch=scratch, reducer="max") + + +def simt_allreduce_min(value, *, threads, scale=1, thread_offset=0, scratch=None): + """Min reduce across SIMT work-items.""" + return _simt_allreduce(value, threads=threads, scale=scale, + thread_offset=thread_offset, scratch=scratch, reducer="min") + + +__all__ = [ + "simt_allreduce_sum", + 
"simt_allreduce_max", + "simt_allreduce_min", +] diff --git a/ptodsl/ptodsl/_bootstrap.py b/ptodsl/ptodsl/_bootstrap.py index 958494fd0..639f78326 100644 --- a/ptodsl/ptodsl/_bootstrap.py +++ b/ptodsl/ptodsl/_bootstrap.py @@ -61,6 +61,10 @@ def _bootstrap_python_paths() -> None: _bootstrap_python_paths() from mlir.dialects import pto as _pto_dialect # noqa: E402 +try: + from mlir.dialects import llvm as _llvm_dialect # noqa: E402 +except Exception: # pragma: no cover - depends on the installed MLIR package. + _llvm_dialect = None from mlir.ir import Context, Location # noqa: E402 @@ -68,6 +72,8 @@ def make_context() -> Context: """Create a fresh MLIR Context with the PTO dialect loaded.""" ctx = Context() _pto_dialect.register_dialect(ctx, load=True) + if _llvm_dialect is not None and hasattr(_llvm_dialect, "register_dialect"): + _llvm_dialect.register_dialect(ctx, load=True) return ctx diff --git a/ptodsl/ptodsl/_builtin_vector.py b/ptodsl/ptodsl/_builtin_vector.py new file mode 100644 index 000000000..decc8be75 --- /dev/null +++ b/ptodsl/ptodsl/_builtin_vector.py @@ -0,0 +1,49 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+"""Builtin MLIR vector helpers for PTODSL scalar-contiguous access.""" + +from ._bootstrap import make_context # ensure MLIR is on sys.path # noqa: F401 +from ._scalar_coercion import coerce_scalar_to_type +from ._surface_values import VecValue, unwrap_surface_value +from ._types import _resolve, _validate_vec_lanes, vec_type + +from mlir.dialects import arith +from mlir.dialects import llvm +from mlir.ir import IntegerType, VectorType + + +def vec(dtype, lanes: int, *, init=None): + """Create a builtin vector type descriptor or broadcast vector value.""" + lanes = _validate_vec_lanes(lanes, context="pto.vec(...)") + descriptor = vec_type(dtype, lanes) + if init is None: + return descriptor + return _broadcast_vec_value(descriptor, init) + + +def _broadcast_vec_value(descriptor, init): + vector_type = _resolve(descriptor) + element_type = VectorType(vector_type).element_type + raw_init = unwrap_surface_value(init) + + if hasattr(raw_init, "type") and VectorType.isinstance(raw_init.type): + vec_value = VecValue(raw_init) + if vec_value.type != vector_type: + raise TypeError(f"pto.vec(..., init=vector) expected {vector_type}, got {vec_value.type}") + return vec_value + + scalar_value = coerce_scalar_to_type(init, element_type, context="pto.vec(..., init=...)") + current = llvm.UndefOp(vector_type).res + i32 = IntegerType.get_signless(32) + for lane in range(descriptor.lanes): + lane_index = arith.ConstantOp(i32, lane).result + current = llvm.InsertElementOp(current, scalar_value, lane_index).res + return VecValue(current) + + +__all__ = ["vec"] diff --git a/ptodsl/ptodsl/_jit.py b/ptodsl/ptodsl/_jit.py index 5c23bc97f..dbfe9756c 100644 --- a/ptodsl/ptodsl/_jit.py +++ b/ptodsl/ptodsl/_jit.py @@ -76,6 +76,16 @@ def _normalize_backend(backend: str, *, fn=None) -> str: return backend +def _normalize_dyn_shared_memory_buf(value): + if value is None: + return None + if isinstance(value, bool) or not isinstance(value, int): + raise TypeError("@pto.jit dyn_shared_memory_buf 
must be a non-negative integer byte count") + if value < 0: + raise ValueError("@pto.jit dyn_shared_memory_buf must be non-negative") + return value + + def _module_attr_map(module): attrs = module.operation.attributes return {name: str(attrs[name]) for name in _MODULE_ATTRS if name in attrs} @@ -167,6 +177,7 @@ def jit( entry: bool = True, mode: str = "auto", insert_sync: bool | None = None, + dyn_shared_memory_buf: int | None = None, ast_rewrite: bool | None = None, frontend_options: Mapping | None = None, ): @@ -187,6 +198,9 @@ def jit( insert_sync: ``True``/``False`` to explicitly control PTOAS sync insertion for launch builds. ``None`` keeps the mode-based default behavior. + dyn_shared_memory_buf: + Dynamic UB scratch byte count to attach to the entry function + and pass to native launch code. ast_rewrite: ``True`` enables AST rewriting of Python ``if`` / ``for range(...)`` into device-side PTODSL control flow. @@ -208,6 +222,7 @@ def jit( ast_rewrite=ast_rewrite, frontend_options=frontend_options, ) + normalized_dyn_shared_memory_buf = _normalize_dyn_shared_memory_buf(dyn_shared_memory_buf) def decorator(fn): fn_name = name or fn.__name__ @@ -229,6 +244,7 @@ def decorator(fn): entry=entry, mode=normalized_mode, insert_sync=insert_sync, + dyn_shared_memory_buf=normalized_dyn_shared_memory_buf, module_style=ModuleStyle.BACKEND_PARTITIONED, source_file=source_file, source_line=getattr(fn.__code__, "co_firstlineno", None), diff --git a/ptodsl/ptodsl/_ops.py b/ptodsl/ptodsl/_ops.py index 53dd3e225..c232dccec 100644 --- a/ptodsl/ptodsl/_ops.py +++ b/ptodsl/ptodsl/_ops.py @@ -41,6 +41,7 @@ ) from ._runtime_scalar_ops import emit_runtime_binary_op from ._surface_values import ( + AllocatedBufferValue, MaskResultValue, PartitionTensorViewValue, TensorViewValue, @@ -65,6 +66,7 @@ mask_type, part_tensor_view_type, part_tensor_view_type_from_dims, + ptr, tensor_view_type, tensor_view_type_from_dims, vreg_type, @@ -82,7 +84,9 @@ IndexType, IntegerType, MemRefType, + 
Operation, Type, + TypeAttr, ) # Pipe name shorthands → canonical PIPE_* names @@ -2315,6 +2319,89 @@ def _tile_transfer_partition(tv, tile, *, offsets=None, sizes=None, context: str return partition_view(tv, offsets=normalized_offsets, sizes=normalized_sizes) +def alloc_buffer(shape, dtype, **kwargs): + """ + Allocate SIMT lane-local scratch storage and return an address-like value. + + The allocation emits an LLVM stack allocation in the surrounding SIMT + helper. UB scratch uses explicit ``pto.castptr`` / ``pto.addptr`` pointer + authoring and ``@pto.jit(dyn_shared_memory_buf=...)`` launch metadata. + """ + if kwargs: + unexpected = ", ".join(sorted(kwargs)) + raise TypeError( + f"pto.alloc_buffer(...) does not accept keyword argument(s): {unexpected}. " + "It only allocates SIMT local buffers; author UB scratch explicitly with " + "pto.castptr/pto.addptr and @pto.jit(dyn_shared_memory_buf=...)." + ) + _require_explicit_mode("pto.alloc_buffer(...)") + element_type = _resolve(dtype) + element_count = _static_alloc_buffer_element_count(shape) + elem_bytes = _element_bytewidth(element_type) + byte_size = element_count * elem_bytes + + return _alloc_local_buffer( + shape, + dtype, + element_type, + element_count, + byte_size, + ) + + +def _static_alloc_buffer_element_count(shape): + if isinstance(shape, int): + dims = (shape,) + elif isinstance(shape, (list, tuple)): + dims = tuple(shape) + else: + raise TypeError("pto.alloc_buffer(shape, ...) expects an int or a tuple/list of static dimensions") + if not dims: + raise ValueError("pto.alloc_buffer(shape, ...) expects at least one dimension") + count = 1 + for dim in dims: + raw_dim = unwrap_surface_value(dim) + if isinstance(raw_dim, bool): + raise TypeError("pto.alloc_buffer(shape, ...) does not accept bool dimensions") + if not isinstance(raw_dim, int): + raise TypeError( + "pto.alloc_buffer(shape, ...) 
requires static integer dimensions; " + f"got {getattr(raw_dim, 'type', type(raw_dim).__name__)}" + ) + if raw_dim <= 0: + raise ValueError(f"pto.alloc_buffer(shape, ...) dimensions must be positive, got {raw_dim}") + count *= raw_dim + return count + + +def _alloc_local_buffer(shape, dtype, element_type, element_count, byte_size): + i32 = IntegerType.get_signless(32) + count = _materialize_integer_literal(i32, element_count) + llvm_ptr_type = Type.parse("!llvm.ptr") + alloca = Operation.create( + "llvm.alloca", + results=[llvm_ptr_type], + operands=[count], + attributes={ + "elem_type": TypeAttr.get(element_type), + }, + ).results[0] + return AllocatedBufferValue( + alloca, + shape=_normalize_alloc_buffer_shape_metadata(shape), + dtype=dtype, + element_type=element_type, + element_count=element_count, + byte_size=byte_size, + ) + + +def _normalize_alloc_buffer_shape_metadata(shape): + if isinstance(shape, int): + return (shape,) + return tuple(unwrap_surface_value(dim) for dim in shape) + + def alloc_tile( tile_type=None, *, @@ -5464,7 +5551,7 @@ def import_reserved_buffer(name, *, peer_func): "vaxpy", "vaddrelu", "vsubrelu", "vsel", "make_tensor_view", "partition_view", - "alloc_tile", + "alloc_buffer", "alloc_tile", "tload", "tstore", "tmov", "tinsert", "tmatmul", "tmatmul_acc", "tmatmul_mx", "tmatmul_mx_acc", "tmatmul_mx_bias", "tgemv_mx", "tgemv_mx_acc", "tgemv_mx_bias", diff --git a/ptodsl/ptodsl/_scalar_adaptation.py b/ptodsl/ptodsl/_scalar_adaptation.py index 09ec59828..6f17f41bd 100644 --- a/ptodsl/ptodsl/_scalar_adaptation.py +++ b/ptodsl/ptodsl/_scalar_adaptation.py @@ -18,7 +18,7 @@ ) from mlir.dialects import arith -from mlir.ir import BF16Type, F16Type, F32Type, FloatAttr, IndexType, IntegerType +from mlir.ir import BF16Type, F16Type, F32Type, FloatAttr, IndexType, IntegerType, VectorType def classify_runtime_scalar_type(type_obj): @@ -28,6 +28,10 @@ def classify_runtime_scalar_type(type_obj): return "integer" if any(cls.isinstance(type_obj) for cls 
in (BF16Type, F16Type, F32Type)): return "float" + if VectorType.isinstance(type_obj): + elem_type = VectorType(type_obj).element_type + if any(cls.isinstance(elem_type) for cls in (BF16Type, F16Type, F32Type)): + return "float" raise TypeError(f"runtime scalar operators only support index/int/float values, got {type_obj}") diff --git a/ptodsl/ptodsl/_subkernels.py b/ptodsl/ptodsl/_subkernels.py index 177717707..0f080b500 100644 --- a/ptodsl/ptodsl/_subkernels.py +++ b/ptodsl/ptodsl/_subkernels.py @@ -405,6 +405,7 @@ def __init__( ast_rewrite: bool = True, simt_max_threads: int | None = None, simt_max_regs: int | None = None, + simt_inline_dims: tuple | None = None, ): self._role = role self._name = name @@ -412,9 +413,12 @@ def __init__( self._ast_rewrite = ast_rewrite self._simt_max_threads = simt_max_threads self._simt_max_regs = simt_max_regs + self._simt_inline_dims = simt_inline_dims self._session_cm = None def __call__(self, fn): + if self._simt_inline_dims is not None: + raise TypeError("pto.simt(dim_x, dim_y, dim_z) is only supported as an inline context manager") return SubkernelTemplate( SubkernelSpec( role=self._role, @@ -446,6 +450,7 @@ def __enter__(self): self._role.value, symbol_name, self._target, + simt_launch_dims=self._simt_inline_dims, ) self._session_cm.__enter__() return None @@ -465,6 +470,7 @@ def _subkernel_decorator( ast_rewrite: bool = True, simt_max_threads: int | None = None, simt_max_regs: int | None = None, + simt_inline_dims: tuple | None = None, ): return _SubkernelSurface( role, @@ -473,6 +479,7 @@ def _subkernel_decorator( ast_rewrite=ast_rewrite, simt_max_threads=simt_max_threads, simt_max_regs=simt_max_regs, + simt_inline_dims=simt_inline_dims, ) @@ -485,6 +492,7 @@ def _decorate_subkernel( ast_rewrite: bool = True, simt_max_threads: int | None = None, simt_max_regs: int | None = None, + simt_inline_dims: tuple | None = None, ): if fn is not None: return _subkernel_decorator( @@ -494,6 +502,7 @@ def _decorate_subkernel( 
ast_rewrite=ast_rewrite, simt_max_threads=simt_max_threads, simt_max_regs=simt_max_regs, + simt_inline_dims=simt_inline_dims, )(fn) return _subkernel_decorator( role, @@ -502,6 +511,7 @@ def _decorate_subkernel( ast_rewrite=ast_rewrite, simt_max_threads=simt_max_threads, simt_max_regs=simt_max_regs, + simt_inline_dims=simt_inline_dims, ) @@ -527,7 +537,7 @@ def _validate_simt_resource_attr(name: str, value: int | None) -> int | None: def simt( fn=None, - *, + *dims, name: str | None = None, target: str = "a5", ast_rewrite: bool = True, @@ -536,6 +546,14 @@ def simt( ): max_threads = _validate_simt_resource_attr("max_threads", max_threads) max_regs = _validate_simt_resource_attr("max_regs", max_regs) + simt_inline_dims = None + if fn is not None and not callable(fn): + dims = (fn, *dims) + fn = None + if dims: + if len(dims) != 3: + raise TypeError("pto.simt(dim_x, dim_y, dim_z) expects exactly three launch dimensions") + simt_inline_dims = tuple(dims) return _decorate_subkernel( KernelRole.SIMT, fn, @@ -544,6 +562,7 @@ def simt( ast_rewrite=ast_rewrite, simt_max_threads=max_threads, simt_max_regs=max_regs, + simt_inline_dims=simt_inline_dims, ) diff --git a/ptodsl/ptodsl/_surface_values.py b/ptodsl/ptodsl/_surface_values.py index 50b9fa123..e829ad5cf 100644 --- a/ptodsl/ptodsl/_surface_values.py +++ b/ptodsl/ptodsl/_surface_values.py @@ -13,7 +13,12 @@ from dataclasses import dataclass from ._diagnostics import native_python_control_flow_error -from ._runtime_scalar_ops import emit_runtime_binary_op, emit_runtime_bitwise_op, emit_runtime_compare +from ._runtime_scalar_ops import ( + emit_runtime_binary_op, + emit_runtime_bitwise_op, + emit_runtime_compare, + normalize_runtime_binary_operands, +) from ._scalar_adaptation import coerce_runtime_index_value from ._surface_types import PartitionTensorView, TensorView, Tile from ._types import _normalize_address_space, _resolve, ptr @@ -21,7 +26,7 @@ from mlir.dialects import arith from mlir.dialects import memref from 
mlir.dialects import pto as _pto -from mlir.ir import IndexType, IntegerAttr, MemRefType, ShapedType, StridedLayoutAttr, Type +from mlir.ir import IndexType, IntegerAttr, IntegerType, MemRefType, ShapedType, StridedLayoutAttr, Type, VectorType def _validate_surface_value_access(value): @@ -172,6 +177,8 @@ def wrap_surface_value( return AddressValue(value) except Exception: pass + if VectorType.isinstance(type_obj): + return VecValue(value) return RuntimeValue(value) @@ -285,6 +292,37 @@ def __rxor__(self, other): return wrap_surface_value(emit_runtime_bitwise_op("xor", unwrap_surface_value(other), self.value)) +class VecValue(_SurfaceValue): + """Author-facing builtin vector value backed by an MLIR vector SSA value.""" + + def __init__(self, value): + if not VectorType.isinstance(value.type): + raise TypeError(f"VecValue expects an MLIR vector value, got {value.type}") + super().__init__(value) + vec_type = VectorType(value.type) + if vec_type.rank != 1: + raise TypeError(f"PTODSL builtin vectors must be rank-1, got {value.type}") + self.lanes = int(vec_type.shape[0]) + self.element_type = vec_type.element_type + + def __mul__(self, other): + return _emit_vec_binary_op("mul", self, other) + + def __rmul__(self, other): + return _emit_vec_binary_op("mul", other, self) + + +def _emit_vec_binary_op(op_name: str, lhs, rhs): + lhs_raw = unwrap_surface_value(lhs) + rhs_raw = unwrap_surface_value(rhs) + if not (VectorType.isinstance(lhs_raw.type) and VectorType.isinstance(rhs_raw.type)): + raise TypeError("PTODSL VecValue arithmetic expects compatible vector operands") + lhs_raw, rhs_raw, kind = normalize_runtime_binary_operands(lhs_raw, rhs_raw) + if kind != "float": + raise TypeError(f"PTODSL VecValue operator '{op_name}' currently supports only floating-point vectors") + return VecValue(emit_runtime_binary_op(op_name, lhs_raw, rhs_raw)) + + class MaskResultValue(_SurfaceValue): """Mask value that also supports `(mask, remained)` unpacking.""" @@ -307,6 +345,37 @@ def 
__radd__(self, offset): return AddressOffsetValue(self, offset) +class AllocatedBufferValue(AddressValue): + """Address returned by ``pto.alloc_buffer`` with allocation metadata.""" + + def __init__( + self, + value, + *, + shape, + dtype, + element_type, + element_count, + byte_size, + ): + super().__init__(value) + self.shape = tuple(shape) + self.dtype = dtype + self.element_type = element_type + self.element_count = element_count + self.byte_size = byte_size + + @property + def surface_metadata(self): + return { + "shape": self.shape, + "dtype": self.dtype, + "element_type": self.element_type, + "element_count": self.element_count, + "byte_size": self.byte_size, + } + + @dataclass(frozen=True) class AddressOffsetValue: """Address view plus an element offset, used by scalar.load/store sugar.""" @@ -991,12 +1060,14 @@ def _coerce_index_value(value): __all__ = [ + "AllocatedBufferValue", "AddressOffsetValue", "AddressValue", "MaskResultValue", "PartitionSpec", "PartitionTensorViewValue", "RuntimeValue", + "VecValue", "TileElementRef", "TileSliceValue", "TensorViewValue", diff --git a/ptodsl/ptodsl/_tracing/module_builder.py b/ptodsl/ptodsl/_tracing/module_builder.py index f15f62e39..715f45b08 100644 --- a/ptodsl/ptodsl/_tracing/module_builder.py +++ b/ptodsl/ptodsl/_tracing/module_builder.py @@ -35,6 +35,7 @@ class KernelModuleSpec: entry: bool = True mode: str = "auto" insert_sync: bool | None = None + dyn_shared_memory_buf: int | None = None module_style: ModuleStyle = ModuleStyle.NESTED source_file: str | None = None source_line: int | None = None diff --git a/ptodsl/ptodsl/_tracing/session.py b/ptodsl/ptodsl/_tracing/session.py index c2390f1d6..0738afcc9 100644 --- a/ptodsl/ptodsl/_tracing/session.py +++ b/ptodsl/ptodsl/_tracing/session.py @@ -119,6 +119,7 @@ class InlineSubkernelOutlineFrame: owner_symbol_name: str wrapper_op: object body_block: object + simt_launch_dims: tuple | None = None class TraceSession: @@ -289,7 +290,14 @@ def 
enter_subkernel_body(self, role: str, symbol_name: str, target: str): raise RuntimeError("PTODSL trace-session subkernel stack corruption detected") @contextmanager - def enter_inline_subkernel(self, role: str, symbol_name: str, target: str): + def enter_inline_subkernel( + self, + role: str, + symbol_name: str, + target: str, + *, + simt_launch_dims: tuple | None = None, + ): """Capture one inline subkernel scope and outline it into a helper on exit.""" frame = SubkernelTraceFrame( role=role, @@ -303,6 +311,7 @@ def enter_inline_subkernel(self, role: str, symbol_name: str, target: str): owner_symbol_name=self.current_function_owner_symbol_name, wrapper_op=wrapper_op, body_block=body_block, + simt_launch_dims=simt_launch_dims, ) self._subkernel_stack.append(frame) try: @@ -413,9 +422,17 @@ def _outline_inline_subkernel(self, outline_frame: InlineSubkernelOutlineFrame) ) with InsertionPoint(outline_frame.wrapper_op.operation): - if role == "simt": - self._emit_simt_helper_launch_metadata() - func.CallOp(helper_fn, list(captures)) + if role == "simt" and outline_frame.simt_launch_dims is not None: + dim_x, dim_y, dim_z = _coerce_simt_launch_dims(outline_frame.simt_launch_dims) + Operation.create( + "pto.simt_launch", + attributes={"callee": FlatSymbolRefAttr.get(_symbol_name(helper_fn))}, + operands=[dim_x, dim_y, dim_z, *captures], + ) + else: + if role == "simt": + self._emit_simt_helper_launch_metadata() + func.CallOp(helper_fn, list(captures)) entry_block = helper_fn.add_entry_block() with InsertionPoint(entry_block): @@ -833,6 +850,13 @@ def validate_final_state(self) -> None: raise RuntimeError("PTODSL trace-session exited with an open subkernel lowering frame") if self._carry_loop_stack: raise RuntimeError("PTODSL trace-session exited with an open loop-carry lowering frame") + dyn_shared_memory_buf = getattr(self.module_spec, "dyn_shared_memory_buf", None) + if dyn_shared_memory_buf: + i64 = IntegerType.get_signless(64) + 
self.entry_function.attributes["dyn_shared_memory_buf"] = IntegerAttr.get( + i64, + dyn_shared_memory_buf, + ) def _coerce_simt_launch_dims(dims): diff --git a/ptodsl/ptodsl/_types.py b/ptodsl/ptodsl/_types.py index 289c9f936..4aa7fc3b8 100644 --- a/ptodsl/ptodsl/_types.py +++ b/ptodsl/ptodsl/_types.py @@ -35,6 +35,7 @@ def softmax(arg0: pto.ptr(pto.float32, "GM"), ...): IntegerType, ShapedType, Type, + VectorType, ) # ── Address-space name → AddressSpace enum ─────────────────────────────────── @@ -152,6 +153,35 @@ def __repr__(self): return f"" +class _VecDescriptor(_DType): + def __init__(self, elem, lanes: int): + self._elem = elem + self._lanes = _validate_vec_lanes(lanes, context="pto.vec(...)") + + def resolve(self) -> Type: + elem = _ensure_non_storage_only_dtype(self._elem, context="pto.vec(...)") + return VectorType.get([self._lanes], elem) + + @property + def lanes(self) -> int: + return self._lanes + + @property + def elem(self): + return self._elem + + def __repr__(self): + return f"" + + +def _validate_vec_lanes(lanes: int, *, context: str) -> int: + if isinstance(lanes, bool) or not isinstance(lanes, int): + raise TypeError(f"{context} expects lanes to be a positive Python integer") + if lanes <= 0: + raise ValueError(f"{context} expects lanes to be positive") + return lanes + + def _resolve(dtype) -> Type: """Coerce a ``_DType`` descriptor or a concrete ``mlir.ir.Type`` to a Type.""" if isinstance(dtype, _DType): @@ -396,6 +426,11 @@ def vreg_type(lanes: int, elem) -> _VRegDescriptor: return _VRegDescriptor(lanes, elem) +def vec_type(elem, lanes: int) -> _VecDescriptor: + """Return a lazy descriptor for builtin ``vector`` values.""" + return _VecDescriptor(elem, lanes) + + def mask_type(bits: str = "b32") -> _MaskDescriptor: """Return a lazy descriptor for ``!pto.mask``.""" return _MaskDescriptor(bits) @@ -477,7 +512,7 @@ def part_tensor_view_type_from_dims(dims, elem) -> Type: "si8", "si16", "si32", "si64", "ui8", "ui16", "ui32", "ui64", "index", - 
"ptr", "vreg_type", "mask_type", + "ptr", "vreg_type", "vec_type", "mask_type", "tile_buf_type", "tensor_view_type", "tensor_view_type_from_dims", "part_tensor_view_type", "part_tensor_view_type_from_dims", ] diff --git a/ptodsl/ptodsl/pto.py b/ptodsl/ptodsl/pto.py index 9b6caffa4..b499ea105 100644 --- a/ptodsl/ptodsl/pto.py +++ b/ptodsl/ptodsl/pto.py @@ -30,9 +30,10 @@ si8, si16, si32, si64, ui8, ui16, ui32, ui64, index, - ptr, vreg_type, mask_type, + ptr, vreg_type, vec_type, mask_type, _resolve, ) +from ._builtin_vector import vec # noqa: F401 from ._surface_types import ( # noqa: F401 const_expr, BarrierType, @@ -103,7 +104,7 @@ vaxpy, vaddrelu, vsubrelu, vsel, make_tensor_view, partition_view, - alloc_tile, + alloc_buffer, alloc_tile, tsort32, tmrgsort, tgather, mte_load, mte_store, mte_gm_ub, mte_ub_gm, mte_ub_ub, mte_ub_l1, mte_gm_l1, mte_l1_ub, mte_gm_l1_frac, mte_l1_bt, mte_l1_fb, mem_bar, @@ -143,6 +144,9 @@ LoopHandle, BranchHandle, ) +# ── All-reduce ───────────────────────────────────────────────────────────────── +from ._allreduce import simt_allreduce_max, simt_allreduce_min, simt_allreduce_sum # noqa: F401 + # ── Decorator ───────────────────────────────────────────────────────────────── from ._jit import jit, KernelHandle, merge_jit_modules # noqa: F401 from ._subkernels import cube, simd, simt # noqa: F401 diff --git a/ptodsl/ptodsl/scalar.py b/ptodsl/ptodsl/scalar.py index 962f9266a..ce9ed453b 100644 --- a/ptodsl/ptodsl/scalar.py +++ b/ptodsl/ptodsl/scalar.py @@ -24,12 +24,20 @@ emit_runtime_max, emit_runtime_min, ) -from ._surface_values import resolve_address_access, unwrap_surface_value, wrap_surface_value +from ._surface_values import ( + AddressOffsetValue, + AllocatedBufferValue, + VecValue, + resolve_address_access, + unwrap_surface_value, + wrap_surface_value, +) from ._types import _resolve from mlir.dialects import arith +from mlir.dialects import llvm from mlir.dialects import math -from mlir.ir import IndexType, MemRefType, Operation 
+from mlir.ir import IndexType, IntegerType, MemRefType, Operation, VectorType from mlir.dialects import pto as _pto @@ -120,10 +128,16 @@ def abs(value): return wrap_surface_value(emit_runtime_abs(unwrap_surface_value(value))) -def load(ptr_or_ref, offset=None): - """Load one scalar element from a PTODSL address view or tile element.""" +def load(ptr_or_ref, offset=None, *, contiguous=None): + """Load one scalar element or a contiguous builtin vector from a PTODSL address view.""" + width = _normalize_contiguous(contiguous, context="scalar.load(...)") buffer_value, index_value = resolve_address_access(ptr_or_ref, offset) - result_type = _infer_buffer_element_type(buffer_value.type) + allocated_buffer = _allocated_buffer_target(ptr_or_ref) + result_type = _infer_buffer_element_type(buffer_value.type, allocated_buffer=allocated_buffer) + if width > 1: + return VecValue(_emit_contiguous_load(buffer_value, index_value, result_type, width)) + if _is_local_allocated_buffer(allocated_buffer): + return wrap_surface_value(_emit_llvm_load(buffer_value, index_value, result_type)) return wrap_surface_value(Operation.create( "pto.load", results=[result_type], @@ -131,23 +145,168 @@ def load(ptr_or_ref, offset=None): ).results[0]) -def store(value, ptr_or_ref, offset=None): - """Store one scalar element to a PTODSL address view or tile element.""" +def store(value, ptr_or_ref, offset=None, *, contiguous=None): + """Store one scalar element or a builtin vector to a PTODSL address view.""" buffer_value, index_value = resolve_address_access(ptr_or_ref, offset) - elem_type = _infer_buffer_element_type(buffer_value.type) + allocated_buffer = _allocated_buffer_target(ptr_or_ref) + elem_type = _infer_buffer_element_type(buffer_value.type, allocated_buffer=allocated_buffer) + raw_value = unwrap_surface_value(value) + if hasattr(raw_value, "type") and VectorType.isinstance(raw_value.type): + vec_value = value if isinstance(value, VecValue) else VecValue(raw_value) + width = 
_normalize_contiguous(contiguous, context="scalar.store(...)", default=vec_value.lanes) + if width != vec_value.lanes: + raise ValueError( + f"scalar.store(..., contiguous={width}) does not match vector lane count {vec_value.lanes}" + ) + if vec_value.element_type != elem_type: + raise TypeError( + "scalar.store(vector, ...) element type must match the destination pointer element type: " + f"got {vec_value.element_type}, expected {elem_type}" + ) + _emit_contiguous_store(raw_value, buffer_value, index_value) + return + + width = _normalize_contiguous(contiguous, context="scalar.store(...)") + if width > 1: + raise TypeError("scalar.store(scalar, ..., contiguous=N) is not supported; pass a vector value") + if _is_local_allocated_buffer(allocated_buffer): + _emit_llvm_store( + coerce_scalar_to_type(value, elem_type, context="scalar.store(...)"), + buffer_value, + index_value, + elem_type, + ) + return Operation.create( "pto.store", operands=[buffer_value, index_value, coerce_scalar_to_type(value, elem_type, context="scalar.store(...)")], ) -def _infer_buffer_element_type(buffer_type): +def _normalize_contiguous(contiguous, *, context: str, default: int = 1) -> int: + if contiguous is None: + return default + if isinstance(contiguous, bool) or not isinstance(contiguous, int): + raise TypeError(f"{context} expects contiguous to be a positive Python integer") + if contiguous <= 0: + raise ValueError(f"{context} expects contiguous to be positive") + return contiguous + + +def _allocated_buffer_target(target): + if isinstance(target, AllocatedBufferValue): + return target + if isinstance(target, AddressOffsetValue) and isinstance(target.base, AllocatedBufferValue): + return target.base + return None + + +def _is_local_allocated_buffer(allocated_buffer) -> bool: + return allocated_buffer is not None + + +def _infer_buffer_element_type(buffer_type, *, allocated_buffer=None): + if allocated_buffer is not None: + return allocated_buffer.element_type try: return 
_pto.PtrType(buffer_type).element_type except Exception: return MemRefType(buffer_type).element_type +def _emit_contiguous_load(buffer_value, index_value, elem_type, width: int): + vector_type = VectorType.get([width], elem_type) + ptr_value = _emit_llvm_byte_pointer(buffer_value, index_value, elem_type) + return llvm.LoadOp(vector_type, ptr_value).res + + +def _emit_llvm_load(buffer_value, index_value, elem_type): + ptr_value = _emit_llvm_byte_pointer(buffer_value, index_value, elem_type) + return llvm.LoadOp(elem_type, ptr_value).res + + +def _emit_contiguous_store(vector_value, buffer_value, index_value): + elem_type = VectorType(vector_value.type).element_type + ptr_value = _emit_llvm_byte_pointer(buffer_value, index_value, elem_type) + llvm.StoreOp(vector_value, ptr_value) + + +def _emit_llvm_store(value, buffer_value, index_value, elem_type): + ptr_value = _emit_llvm_byte_pointer(buffer_value, index_value, elem_type) + llvm.StoreOp(value, ptr_value) + + +def _emit_llvm_byte_pointer(buffer_value, index_value, elem_type): + byte_offset = _emit_byte_offset(index_value, elem_type) + llvm_ptr_type = _as_llvm_ptr_type(buffer_value.type) + if llvm_ptr_type is not None: + llvm_base = buffer_value + else: + pto_ptr_type = _as_pto_ptr_type(buffer_value.type) + i64 = IntegerType.get_signless(64) + addr_as_i64 = _pto.CastPtrOp(i64, buffer_value).result + llvm_ptr_type = llvm.PointerType.get(_pto_ptr_llvm_address_space(pto_ptr_type)) + llvm_base = llvm.IntToPtrOp(llvm_ptr_type, addr_as_i64).res + return llvm.GEPOp( + llvm_ptr_type, + llvm_base, + [byte_offset], + [-2147483648], + IntegerType.get_signless(8), + ).res + + +def _as_llvm_ptr_type(type_obj): + try: + return llvm.PointerType(type_obj) + except Exception: + return None + + +def _emit_byte_offset(index_value, elem_type): + bytewidth = _element_bytewidth(elem_type) + bytewidth_const = arith.ConstantOp(IndexType.get(), bytewidth).result + byte_index = arith.MulIOp(index_value, bytewidth_const).result + return 
arith.IndexCastOp(IntegerType.get_signless(64), byte_index).result + + +def _as_pto_ptr_type(type_obj): + try: + return _pto.PtrType(type_obj) + except Exception as exc: + raise TypeError( + "contiguous scalar.load/store currently expects a PTO pointer-backed address" + ) from exc + + +def _pto_ptr_llvm_address_space(ptr_type) -> int: + memory_space = getattr(ptr_type, "memory_space", None) + value = getattr(memory_space, "value", None) + if value is not None: + return int(value) + text = str(ptr_type) + if ", ub>" in text or ", vec>" in text: + return 6 + if ", gm>" in text or text.endswith(">"): + return 1 + raise TypeError(f"unable to infer LLVM address space for pointer type {ptr_type}") + + +def _element_bytewidth(elem_type): + if str(elem_type) == "f32": + return 4 + if str(elem_type) in {"f16", "bf16"}: + return 2 + if IntegerType.isinstance(elem_type): + width = IntegerType(elem_type).width + if width % 8 != 0: + raise TypeError(f"unsupported sub-byte integer element type {elem_type}") + return width // 8 + if str(elem_type).startswith("f8") or str(elem_type).startswith("!pto."): + return 1 + raise TypeError(f"unsupported element type {elem_type}") + + __all__ = [ "muli", "addi", "subi", "index_cast", diff --git a/ptodsl/tests/test_allreduce.py b/ptodsl/tests/test_allreduce.py new file mode 100644 index 000000000..3b8ca120b --- /dev/null +++ b/ptodsl/tests/test_allreduce.py @@ -0,0 +1,694 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+# See LICENSE in the root of the software repository for the full text of the License. + +from pathlib import Path +import sys + +sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "ptodsl")) + +from ptodsl import pto + + +def expect(condition: bool, message: str) -> None: + if not condition: + raise AssertionError(message) + + +def main(): + from ptodsl._allreduce import simt_allreduce_sum, simt_allreduce_max, simt_allreduce_min + + # ══════════════════════════════════════════════════════════════════════════ + # Path 0: identity (threads <= scale) + # ══════════════════════════════════════════════════════════════════════════ + expect( + simt_allreduce_sum(1.0, threads=1, scale=1) == 1.0, + "identity: threads == scale", + ) + expect( + simt_allreduce_sum(1.0, threads=2, scale=2) == 1.0, + "identity: threads == scale (alt)", + ) + + # ══════════════════════════════════════════════════════════════════════════ + # validation errors + # ══════════════════════════════════════════════════════════════════════════ + + # threads % scale != 0 (validation now runs before identity shortcut) + try: + simt_allreduce_sum(1.0, threads=3, scale=2) + raise AssertionError("expected ValueError for threads % scale != 0") + except ValueError: + pass + + + # threads < 1 + try: + simt_allreduce_sum(1.0, threads=0, scale=1) + raise AssertionError("expected ValueError for threads < 1") + except ValueError: + pass + + # validation runs before identity: bad params not bypassed by threads<=scale + try: + simt_allreduce_sum(1.0, threads=1, scale=2) + raise AssertionError("expected ValueError for threads%scale!=0 (before identity)") + except ValueError: + pass + + # i32 dtype rejected — need a real JIT kernel so we get an MLIR i32 value + @pto.jit(target="a5") + def kernel_i32(scratch_gm: pto.ptr(pto.f32, "gm")): + with pto.simt(): + x = pto.const(1, dtype=pto.i32) + _result = pto.simt_allreduce_sum(x, threads=32, scale=1) + + try: + kernel_i32.compile() + raise 
AssertionError("expected NotImplementedError for i32") + except NotImplementedError: + pass + + # ══════════════════════════════════════════════════════════════════════════ + # Path 1a: warp_reduce — hardware redux, groups == 1 (threads=32) + # ══════════════════════════════════════════════════════════════════════════ + + @pto.jit(target="a5") + def kernel_warp(scratch_gm: pto.ptr(pto.f32, "gm")): + zero_u64 = pto.const(0, dtype=pto.ui64) + _ub_scratch = pto.castptr(zero_u64, pto.ptr(pto.f32, "ub")) + with pto.simt(): + x = pto.const(1.0, dtype=pto.f32) + _result = pto.simt_allreduce_sum(x, threads=32, scale=1) + + compiled_warp = kernel_warp.compile() + mlir_warp = compiled_warp.mlir_text() + expect("pto.redux_add" in mlir_warp, + "IR: redux_add in warp_reduce helper") + expect("pto.syncthreads" not in mlir_warp, + "IR: warp_reduce has no syncthreads") + expect("pto.shuffle_bfly" not in mlir_warp, + "IR: warp_reduce (groups=1) has no shuffle_bfly") + compiled_warp.verify() + + # ══════════════════════════════════════════════════════════════════════════ + # Path 1b: warp_reduce — hardware redux, groups > 1 (threads=16, scale=1) + # ══════════════════════════════════════════════════════════════════════════ + + @pto.jit(target="a5") + def kernel_warp_t16(scratch_gm: pto.ptr(pto.f32, "gm")): + zero_u64 = pto.const(0, dtype=pto.ui64) + _ub_scratch = pto.castptr(zero_u64, pto.ptr(pto.f32, "ub")) + with pto.simt(): + x = pto.const(1.0, dtype=pto.f32) + _result = pto.simt_allreduce_sum(x, threads=16, scale=1) + + compiled_warp_t16 = kernel_warp_t16.compile() + mlir_warp_t16 = compiled_warp_t16.mlir_text() + expect("pto.redux_add" in mlir_warp_t16, + "IR: redux_add for groups>1") + expect("arith.select" in mlir_warp_t16, + "IR: arith.select for group masking") + expect("pto.syncthreads" not in mlir_warp_t16, + "IR: warp_reduce (groups=2) has no syncthreads") + compiled_warp_t16.verify() + + # ══════════════════════════════════════════════════════════════════════════ + # 
Path 1c: warp_reduce — butterfly shuffle (threads=8, scale=1) + # ══════════════════════════════════════════════════════════════════════════ + + @pto.jit(target="a5") + def kernel_warp_t8(scratch_gm: pto.ptr(pto.f32, "gm")): + zero_u64 = pto.const(0, dtype=pto.ui64) + _ub_scratch = pto.castptr(zero_u64, pto.ptr(pto.f32, "ub")) + with pto.simt(): + x = pto.const(1.0, dtype=pto.f32) + _result = pto.simt_allreduce_sum(x, threads=8, scale=1) + + compiled_warp_t8 = kernel_warp_t8.compile() + mlir_warp_t8 = compiled_warp_t8.mlir_text() + expect("pto.shuffle_bfly" in mlir_warp_t8, + "IR: shuffle_bfly for butterfly path") + expect("pto.redux_add" not in mlir_warp_t8, + "IR: butterfly has no hardware redux") + expect("pto.syncthreads" not in mlir_warp_t8, + "IR: butterfly has no syncthreads") + compiled_warp_t8.verify() + + # ══════════════════════════════════════════════════════════════════════════ + # Path 1d: warp_reduce — butterfly with scale > 1 (threads=32, scale=2) + # ══════════════════════════════════════════════════════════════════════════ + + @pto.jit(target="a5") + def kernel_warp_s2(scratch_gm: pto.ptr(pto.f32, "gm")): + zero_u64 = pto.const(0, dtype=pto.ui64) + _ub_scratch = pto.castptr(zero_u64, pto.ptr(pto.f32, "ub")) + with pto.simt(): + x = pto.const(1.0, dtype=pto.f32) + _result = pto.simt_allreduce_sum(x, threads=32, scale=2) + + compiled_warp_s2 = kernel_warp_s2.compile() + mlir_warp_s2 = compiled_warp_s2.mlir_text() + expect("pto.shuffle_bfly" in mlir_warp_s2, + "IR: shuffle_bfly for butterfly (scale>1)") + expect("pto.redux_add" not in mlir_warp_s2, + "IR: butterfly (scale>1) has no hardware redux") + compiled_warp_s2.verify() + + # ── warp_reduce: sum, f32, t=16, s=1, o=4 (non-zero thread_offset) ──────── + @pto.jit(target="a5") + def kernel_warp_o4(scratch_gm: pto.ptr(pto.f32, "gm")): + zero_u64 = pto.const(0, dtype=pto.ui64) + _ub_scratch = pto.castptr(zero_u64, pto.ptr(pto.f32, "ub")) + with pto.simt(): + x = pto.const(1.0, dtype=pto.f32) + 
_result = pto.simt_allreduce_sum(x, threads=16, scale=1, thread_offset=4) + + compiled_warp_o4 = kernel_warp_o4.compile() + mlir_warp_o4 = compiled_warp_o4.mlir_text() + expect("pto.get_tid_x" in mlir_warp_o4, + "IR: warp_reduce o=4 uses get_tid_x (not raw get_laneid)") + expect("arith.subi" in mlir_warp_o4, + "IR: warp_reduce o=4 uses subi for tx = tid_x - offset") + expect("arith.andi" in mlir_warp_o4, + "IR: warp_reduce o=4 uses andi to extract lane_in_warp") + compiled_warp_o4.verify() + + # ══════════════════════════════════════════════════════════════════════════ + # Path 2: ub_reduce — threads ≤ 32, non-power-of-2 (threads=6, scale=1) + # ══════════════════════════════════════════════════════════════════════════ + + @pto.jit(target="a5") + def kernel_ub6(scratch_gm: pto.ptr(pto.f32, "gm")): + zero_u64 = pto.const(0, dtype=pto.ui64) + ub_scratch = pto.castptr(zero_u64, pto.ptr(pto.f32, "ub")) + with pto.simt(): + x = pto.const(1.0, dtype=pto.f32) + _result = pto.simt_allreduce_sum(x, scratch=ub_scratch, threads=6, scale=1) + + compiled_ub6 = kernel_ub6.compile() + mlir_ub6 = compiled_ub6.mlir_text() + expect("pto.syncthreads" in mlir_ub6, + "IR: ub_reduce has syncthreads") + expect("pto.store" in mlir_ub6, + "IR: ub_reduce has store (write to scratch)") + expect("pto.load" in mlir_ub6, + "IR: ub_reduce has load (read from scratch)") + syncthreads_count = mlir_ub6.count("pto.syncthreads") + expect(syncthreads_count == 4, + f"IR: ub_reduce has 4 syncthreads, got {syncthreads_count}") + compiled_ub6.verify() + + # ── ub_reduce: sum, f32, t=6, s=2 (scale > 1, non-pow2 threads) ───────── + @pto.jit(target="a5") + def kernel_ub6s2(scratch_gm: pto.ptr(pto.f32, "gm")): + zero_u64 = pto.const(0, dtype=pto.ui64) + ub_scratch = pto.castptr(zero_u64, pto.ptr(pto.f32, "ub")) + with pto.simt(): + x = pto.const(1.0, dtype=pto.f32) + _result = pto.simt_allreduce_sum(x, scratch=ub_scratch, threads=6, scale=2) + + compiled_ub6s2 = kernel_ub6s2.compile() + mlir_ub6s2 = 
compiled_ub6s2.mlir_text() + expect("pto.syncthreads" in mlir_ub6s2, + "IR: ub_reduce t=6 s=2 has syncthreads") + expect("pto.store" in mlir_ub6s2, + "IR: ub_reduce t=6 s=2 has store") + expect("pto.load" in mlir_ub6s2, + "IR: ub_reduce t=6 s=2 has load") + expect("scf.for" in mlir_ub6s2, + "IR: ub_reduce t=6 s=2 has scf.for (sequential reduce loop)") + expect("pto.redux_add" not in mlir_ub6s2, + "IR: ub_reduce t=6 s=2 has no hardware redux") + expect("pto.shuffle_bfly" not in mlir_ub6s2, + "IR: ub_reduce t=6 s=2 has no butterfly shuffle") + # scale>1 fixes: reducer uses lane < scale (ult), not lane_mod == 0 + expect("arith.cmpi slt" in mlir_ub6s2 or "arith.cmpi ult" in mlir_ub6s2, + "IR: ub_reduce t=6 s=2 reducer uses lane < scale") + compiled_ub6s2.verify() + + # ── ub_reduce: sum, f32, t=6, s=1, o=4 (non-zero thread_offset) ───────── + @pto.jit(target="a5") + def kernel_ub_o4(scratch_gm: pto.ptr(pto.f32, "gm")): + zero_u64 = pto.const(0, dtype=pto.ui64) + ub_scratch = pto.castptr(zero_u64, pto.ptr(pto.f32, "ub")) + with pto.simt(): + x = pto.const(1.0, dtype=pto.f32) + _result = pto.simt_allreduce_sum(x, scratch=ub_scratch, threads=6, scale=1, + thread_offset=4) + + compiled_ub_o4 = kernel_ub_o4.compile() + mlir_ub_o4 = compiled_ub_o4.mlir_text() + expect("arith.subi" in mlir_ub_o4, + "IR: ub_reduce o=4 uses subi for tx = tid_x - offset") + compiled_ub_o4.verify() + + # ══════════════════════════════════════════════════════════════════════════ + # Path 3a: cross_warp_reduce — sum, f32, t=128, s=1, o=0 (baseline) + # ══════════════════════════════════════════════════════════════════════════ + + @pto.jit(target="a5") + def kernel_128(scratch_gm: pto.ptr(pto.f32, "gm")): + zero_u64 = pto.const(0, dtype=pto.ui64) + ub_scratch = pto.castptr(zero_u64, pto.ptr(pto.f32, "ub")) + with pto.simt(): + x = pto.const(1.0, dtype=pto.f32) + _result = pto.simt_allreduce_sum(x, scratch=ub_scratch, threads=128, scale=1) + + compiled = kernel_128.compile() + mlir = 
compiled.mlir_text() + + expect("pto.simt_entry" in mlir, + "IR: helper carries pto.simt_entry") + + for op_name in ( + "pto.redux_add", "pto.syncthreads", "pto.store", "pto.load", + "pto.get_tid_x", "pto.get_laneid", "arith.shrui", "scf.if", + ): + expect(op_name in mlir, f"IR: expected '{op_name}' in helper body") + + syncthreads_count = mlir.count("pto.syncthreads") + expect(syncthreads_count == 3, + f"IR: expected 3 syncthreads, got {syncthreads_count}") + + compiled.verify() + + # ── cross_warp: sum, f32, t=64 (2 warps) ──────────────────────────────── + @pto.jit(target="a5") + def kernel_64(scratch_gm: pto.ptr(pto.f32, "gm")): + zero_u64 = pto.const(0, dtype=pto.ui64) + ub_scratch = pto.castptr(zero_u64, pto.ptr(pto.f32, "ub")) + with pto.simt(): + x = pto.const(1.0, dtype=pto.f32) + _result = pto.simt_allreduce_sum(x, scratch=ub_scratch, threads=64, scale=1) + + compiled_64 = kernel_64.compile() + mlir_64 = compiled_64.mlir_text() + compiled_64.verify() + + # ── cross_warp: sum, f32, t=256 (8 warps) ─────────────────────────────── + @pto.jit(target="a5") + def kernel_256(scratch_gm: pto.ptr(pto.f32, "gm")): + zero_u64 = pto.const(0, dtype=pto.ui64) + ub_scratch = pto.castptr(zero_u64, pto.ptr(pto.f32, "ub")) + with pto.simt(): + x = pto.const(1.0, dtype=pto.f32) + _result = pto.simt_allreduce_sum(x, scratch=ub_scratch, threads=256, scale=1) + + compiled_256 = kernel_256.compile() + mlir_256 = compiled_256.mlir_text() + compiled_256.verify() + + # ══════════════════════════════════════════════════════════════════════════ + # Path 3b: cross_warp_reduce — scale > 1, scale*num_warps ≤ 32 + # (threads=128, scale=2, num_warps=4, total=8 ≤ 32) + # ══════════════════════════════════════════════════════════════════════════ + + @pto.jit(target="a5") + def kernel_cw_s2(scratch_gm: pto.ptr(pto.f32, "gm")): + zero_u64 = pto.const(0, dtype=pto.ui64) + ub_scratch = pto.castptr(zero_u64, pto.ptr(pto.f32, "ub")) + with pto.simt(): + x = pto.const(1.0, dtype=pto.f32) + 
_result = pto.simt_allreduce_sum(x, scratch=ub_scratch, threads=128, scale=2) + + compiled_cw_s2 = kernel_cw_s2.compile() + mlir_cw_s2 = compiled_cw_s2.mlir_text() + expect("pto.shuffle_bfly" in mlir_cw_s2, + "IR: cross_warp s=2 has shuffle_bfly (butterfly for per-warp + leader)") + expect("pto.syncthreads" in mlir_cw_s2, + "IR: cross_warp s=2 has syncthreads") + # scale > 1: per-warp uses butterfly, not hardware redux + compiled_cw_s2.verify() + + # ══════════════════════════════════════════════════════════════════════════ + # Path 3c: cross_warp_reduce — scale > 1, scale*num_warps > 32 (manual, sum) + # (threads=128, scale=16, num_warps=4, total=64 > 32) + # ══════════════════════════════════════════════════════════════════════════ + + @pto.jit(target="a5") + def kernel_cw_s16(scratch_gm: pto.ptr(pto.f32, "gm")): + zero_u64 = pto.const(0, dtype=pto.ui64) + ub_scratch = pto.castptr(zero_u64, pto.ptr(pto.f32, "ub")) + with pto.simt(): + x = pto.const(1.0, dtype=pto.f32) + _result = pto.simt_allreduce_sum(x, scratch=ub_scratch, threads=128, scale=16) + + compiled_cw_s16 = kernel_cw_s16.compile() + mlir_cw_s16 = compiled_cw_s16.mlir_text() + expect("pto.syncthreads" in mlir_cw_s16, + "IR: cross_warp s=16 has syncthreads") + compiled_cw_s16.verify() + + # ── cross_warp: sum, f32, t=128, s=1, o=4 (non-zero thread_offset) ───── + @pto.jit(target="a5") + def kernel_cw_o4(scratch_gm: pto.ptr(pto.f32, "gm")): + zero_u64 = pto.const(0, dtype=pto.ui64) + ub_scratch = pto.castptr(zero_u64, pto.ptr(pto.f32, "ub")) + with pto.simt(): + x = pto.const(1.0, dtype=pto.f32) + _result = pto.simt_allreduce_sum(x, scratch=ub_scratch, threads=128, scale=1, + thread_offset=4) + + compiled_cw_o4 = kernel_cw_o4.compile() + mlir_cw_o4 = compiled_cw_o4.mlir_text() + expect("pto.get_tid_x" in mlir_cw_o4, + "IR: cross_warp o=4 uses get_tid_x") + expect("arith.subi" in mlir_cw_o4, + "IR: cross_warp o=4 uses subi for tx = tid_x - offset") + compiled_cw_o4.verify() + + # 
══════════════════════════════════════════════════════════════════════════ + # Path 4: ub_reduce fallback — threads > 32, non-power-of-2 + # (threads=48, scale=1) + # ══════════════════════════════════════════════════════════════════════════ + + @pto.jit(target="a5") + def kernel_ub48(scratch_gm: pto.ptr(pto.f32, "gm")): + zero_u64 = pto.const(0, dtype=pto.ui64) + ub_scratch = pto.castptr(zero_u64, pto.ptr(pto.f32, "ub")) + with pto.simt(): + x = pto.const(1.0, dtype=pto.f32) + _result = pto.simt_allreduce_sum(x, scratch=ub_scratch, threads=48, scale=1) + + compiled_ub48 = kernel_ub48.compile() + mlir_ub48 = compiled_ub48.mlir_text() + expect("pto.syncthreads" in mlir_ub48, + "IR: ub_reduce fallback has syncthreads") + expect("pto.store" in mlir_ub48, + "IR: ub_reduce fallback has store") + expect("pto.load" in mlir_ub48, + "IR: ub_reduce fallback has load") + compiled_ub48.verify() + + # ══════════════════════════════════════════════════════════════════════════ + # ══════════════════════════════════════════════════════════════════════════ + + @pto.jit(target="a5") + def kernel_reuse(scratch_gm: pto.ptr(pto.f32, "gm")): + zero_u64 = pto.const(0, dtype=pto.ui64) + ub_scratch = pto.castptr(zero_u64, pto.ptr(pto.f32, "ub")) + with pto.simt(): + x1 = pto.const(1.0, dtype=pto.f32) + _r1 = pto.simt_allreduce_sum(x1, scratch=ub_scratch, threads=128, scale=1) + x2 = pto.const(2.0, dtype=pto.f32) + _r2 = pto.simt_allreduce_sum(x2, scratch=ub_scratch, threads=128, scale=1) + + compiled2 = kernel_reuse.compile() + mlir2 = compiled2.mlir_text() + + compiled2.verify() + + + # ══════════════════════════════════════════════════════════════════════════ + # scratch required for ub_reduce and cross_warp paths + # ══════════════════════════════════════════════════════════════════════════ + + # cross_warp requires scratch — use a real JIT kernel so the error + # originates from _dispatch_allreduce_helper, not from a bare Python float. 
+ @pto.jit(target="a5") + def kernel_no_scratch_cw(scratch_gm: pto.ptr(pto.f32, "gm")): + with pto.simt(): + x = pto.const(1.0, dtype=pto.f32) + _result = pto.simt_allreduce_sum(x, scratch=None, threads=128, scale=1) + + try: + kernel_no_scratch_cw.compile() + raise AssertionError("expected ValueError for missing scratch (cross_warp)") + except ValueError as e: + expect("requires a UB scratch buffer" in str(e), + f"error message should mention scratch (cross_warp), got: {e}") + + # ub_reduce (non-pow2) requires scratch + @pto.jit(target="a5") + def kernel_no_scratch_ub(scratch_gm: pto.ptr(pto.f32, "gm")): + with pto.simt(): + x = pto.const(1.0, dtype=pto.f32) + _result = pto.simt_allreduce_sum(x, scratch=None, threads=6, scale=1) + + try: + kernel_no_scratch_ub.compile() + raise AssertionError("expected ValueError for missing scratch (ub_reduce)") + except ValueError as e: + expect("requires a UB scratch buffer" in str(e), + f"error message should mention scratch (ub_reduce), got: {e}") + + # scratch must be a pto.ptr type — PTODSL scalar.load/store catch this + @pto.jit(target="a5") + def kernel_non_ptr(): + with pto.simt(): + x = pto.const(1.0, dtype=pto.f32) + not_ptr = pto.const(0, dtype=pto.i32) + _result = pto.simt_allreduce_sum(x, scratch=not_ptr, threads=6, scale=1) + + try: + kernel_non_ptr.compile() + raise AssertionError("expected error for non-ptr scratch") + except Exception: + pass # PTODSL scalar.store / resolve_address_access catches this + + # cross_warp: gm scratch (wrong memory space) should be rejected + @pto.jit(target="a5") + def kernel_gm_scratch(scratch_gm: pto.ptr(pto.f32, "gm")): + with pto.simt(): + x = pto.const(1.0, dtype=pto.f32) + _result = pto.simt_allreduce_sum(x, scratch=scratch_gm, threads=128, scale=1) + + try: + kernel_gm_scratch.compile() + raise AssertionError("expected error for gm scratch") + except Exception as e: + expect("ub" in str(e).lower() or "vec" in str(e).lower() or "address space" in str(e).lower() + or "memory" 
in str(e).lower(), + f"gm scratch error should mention address space, got: {e}") + + # cross_warp: i32 scratch with f32 x (dtype mismatch) should be rejected + @pto.jit(target="a5") + def kernel_dtype_mismatch(scratch_gm: pto.ptr(pto.f32, "gm")): + zero_u64 = pto.const(0, dtype=pto.ui64) + ub_i32 = pto.castptr(zero_u64, pto.ptr(pto.i32, "ub")) + with pto.simt(): + x = pto.const(1.0, dtype=pto.f32) + _result = pto.simt_allreduce_sum(x, scratch=ub_i32, threads=128, scale=1) + + try: + kernel_dtype_mismatch.compile() + raise AssertionError("expected TypeError for dtype mismatch scratch") + except TypeError as e: + err = str(e) + expect("cannot coerce" in err.lower() or "element type" in err.lower() + or "mismatch" in err.lower(), + f"dtype mismatch should mention type, got: {e}") + + # ══════════════════════════════════════════════════════════════════════════ + # Max reducer — Path 1a: warp_reduce, hw redux (threads=32, scale=1) + # ══════════════════════════════════════════════════════════════════════════ + @pto.jit(target="a5") + def kernel_max_warp_hw(scratch_gm: pto.ptr(pto.f32, "gm")): + zero_u64 = pto.const(0, dtype=pto.ui64) + ub_scratch = pto.castptr(zero_u64, pto.ptr(pto.f32, "ub")) + with pto.simt(): + x = pto.const(1.0, dtype=pto.f32) + _result = pto.simt_allreduce_max(x, threads=32, scale=1) + + compiled_max_warp = kernel_max_warp_hw.compile() + mlir_max_warp = compiled_max_warp.mlir_text() + + expect( + "pto.redux_max" in mlir_max_warp, + "Path 1a (max): IR must contain pto.redux_max", + ) + expect( + "pto.syncthreads" not in mlir_max_warp, + "Path 1a (max): single-warp hw reduce needs no syncthreads", + ) + + # ── Max reducer — Path 1c: warp_reduce, butterfly (threads=8, scale=1) ── + @pto.jit(target="a5") + def kernel_max_butterfly(scratch_gm: pto.ptr(pto.f32, "gm")): + with pto.simt(): + x = pto.const(1.0, dtype=pto.f32) + _result = pto.simt_allreduce_max(x, threads=8, scale=1) + + compiled_max_bfly = kernel_max_butterfly.compile() + mlir_max_bfly = 
str(compiled_max_bfly.mlir_text()) + + expect( + "arith.maximumf" in mlir_max_bfly, + "Path 1c (max): butterfly must emit arith.maximumf for element-wise max", + ) + expect( + "pto.shuffle_bfly" in mlir_max_bfly, + "Path 1c (max): butterfly must use pto.shuffle_bfly", + ) + expect( + "pto.redux_max" not in mlir_max_bfly, + "Path 1c (max): butterfly path should NOT use hw redux", + ) + + # ── Max reducer — Path 3: cross_warp_reduce (threads=128, scale=1) ── + @pto.jit(target="a5") + def kernel_max_cross_warp(scratch_gm: pto.ptr(pto.f32, "gm")): + zero_u64 = pto.const(0, dtype=pto.ui64) + ub_scratch = pto.castptr(zero_u64, pto.ptr(pto.f32, "ub")) + with pto.simt(): + x = pto.const(1.0, dtype=pto.f32) + _result = pto.simt_allreduce_max(x, scratch=ub_scratch, threads=128, scale=1) + + compiled_max_cw = kernel_max_cross_warp.compile() + mlir_max_cw = str(compiled_max_cw.mlir_text()) + + expect( + "pto.redux_max" in mlir_max_cw, + "Path 3 (max): cross-warp IR must contain pto.redux_max", + ) + expect( + "pto.syncthreads" in mlir_max_cw, + "Path 3 (max): cross-warp needs syncthreads barriers", + ) + + # ══════════════════════════════════════════════════════════════════════════ + # Min reducer — Path 1a: warp_reduce, hw redux (threads=32, scale=1) + # ══════════════════════════════════════════════════════════════════════════ + @pto.jit(target="a5") + def kernel_min_warp_hw(scratch_gm: pto.ptr(pto.f32, "gm")): + with pto.simt(): + x = pto.const(1.0, dtype=pto.f32) + _result = pto.simt_allreduce_min(x, threads=32, scale=1) + + compiled_min_warp = kernel_min_warp_hw.compile() + mlir_min_warp = str(compiled_min_warp.mlir_text()) + + expect( + "pto.redux_min" in mlir_min_warp, + "Path 1a (min): IR must contain pto.redux_min", + ) + expect( + "pto.syncthreads" not in mlir_min_warp, + "Path 1a (min): single-warp hw reduce needs no syncthreads", + ) + + # ── Min reducer — Path 4 (ub_reduce fallback): threads=48, non-pow2 ── + @pto.jit(target="a5") + def kernel_min_ub(scratch_gm: 
pto.ptr(pto.f32, "gm")): + zero_u64 = pto.const(0, dtype=pto.ui64) + ub_scratch = pto.castptr(zero_u64, pto.ptr(pto.f32, "ub")) + with pto.simt(): + x = pto.const(1.0, dtype=pto.f32) + _result = pto.simt_allreduce_min(x, scratch=ub_scratch, threads=48, scale=1) + + compiled_min_ub = kernel_min_ub.compile() + mlir_min_ub = str(compiled_min_ub.mlir_text()) + + expect( + "arith.minimumf" in mlir_min_ub, + "Path 4 (min): ub_reduce fallback must emit arith.minimumf", + ) + + # ── Identity smoke tests for max/min ─────────────────────────────────── + expect( + simt_allreduce_max(1.0, threads=1, scale=1) == 1.0, + "Path 0 (max): threads <= scale returns identity (value unchanged)", + ) + expect( + simt_allreduce_min(1.0, threads=2, scale=2) == 1.0, + "Path 0 (min): threads <= scale returns identity (value unchanged)", + ) + + # ══════════════════════════════════════════════════════════════════════════ + # Lowering verification — ptoas → bisheng (full AOT compilation) + # + # Tests that the allreduce MLIR survives the complete ptoas pipeline: + # MLIR (PTO dialect) → VPTO passes → LLVM IR → bisheng device codegen + # + # KNOWN TOOLCHAIN ISSUES (bisheng, not allreduce): + # a) bisheng stack-smashing on SIMT code that stores to GM + # b) bisheng stack-smashing on cross-warp scratch-buffer code (≥ 128 lanes) + # + # These are bisheng bugs — ptoas VPTO lowering succeeds; the crash is + # in the device LLVM→object step inside bisheng. 
Verified by: + # ptoas --emit-vpto-llvm-ir → valid LLVM IR (no crash) + # ptoas -o kernel.o → bisheng crash during LLVM→object + # ══════════════════════════════════════════════════════════════════════════ + + import subprocess + import tempfile + from pathlib import Path + + def _ptoas_binary() -> Path: + for p in [ + Path(__file__).resolve().parents[2] / "build" / "tools" / "ptoas" / "ptoas", + ]: + if p.is_file(): + return p + raise RuntimeError( + "ptoas binary not found; run `source scripts/ptoas_env.sh` or build ptoas" + ) + + def _lower_and_check(compiled, case_label: str, expect_pass: bool = True) -> bool: + """Run ``ptoas`` lowering on *compiled* MLIR. Returns True on success.""" + ptoas = _ptoas_binary() + mlir_text = compiled.mlir_text() + with tempfile.TemporaryDirectory() as tmpdir: + mlir_path = Path(tmpdir) / "kernel.mlir" + obj_path = Path(tmpdir) / "kernel.o" + mlir_path.write_text(mlir_text) + result = subprocess.run( + [str(ptoas), "--pto-arch=a5", "--pto-backend=vpto", + "--enable-tile-op-expand", + str(mlir_path), "-o", str(obj_path)], + capture_output=True, text=True, + ) + ok = result.returncode == 0 and obj_path.is_file() + if ok: + return True + bisheng_crash = "stack smashing" in result.stderr or "exit code 134" in result.stderr + tag = "SKIP (bisheng bug)" if bisheng_crash else "FAIL" + if expect_pass and not bisheng_crash: + # Unexpected failure — report loudly + sys.stderr.write( + f"\n [{tag}] {case_label} (exit={result.returncode})\n" + f" STDERR: {result.stderr[:500]}\n" + ) + else: + print(f" [{tag}] {case_label}") + return False + + # ── Warp-reduce (≤ 32 lanes, NO scratch, NO GM store) ── + # These are the simplest kernels — they only compute a value and return + # from the SIMT body without writing to GM. They MUST lower cleanly + # because they avoid both known bisheng issues. 
+ expect( + _lower_and_check(kernel_warp.compile(), "warp_sum_t32"), + "lowering: warp_sum (32 lanes, hw redux, no GM store) must pass", + ) + expect( + _lower_and_check(kernel_max_warp_hw.compile(), "warp_max_t32"), + "lowering: warp_max (32 lanes, hw redux, no GM store) must pass", + ) + expect( + _lower_and_check(kernel_min_warp_hw.compile(), "warp_min_t32"), + "lowering: warp_min (32 lanes, hw redux, no GM store) must pass", + ) + + # ── Cross-warp (128 lanes, UB scratch) — known bisheng crash ── + # ptoas VPTO lowering succeeds; bisheng crashes on the device LLVM IR. + @pto.jit(target="a5") + def _kernel_cross_lowering(scratch_gm: pto.ptr(pto.f32, "gm")): + zero_u64 = pto.const(0, dtype=pto.ui64) + ub_scratch = pto.castptr(zero_u64, pto.ptr(pto.f32, "ub")) + with pto.simt(): + x = pto.const(1.0, dtype=pto.f32) + _result = pto.simt_allreduce_sum(x, scratch=ub_scratch, threads=128, scale=1) + _lower_and_check(_kernel_cross_lowering.compile(), "cross_sum_t128", expect_pass=False) + + print("ptodsl_allreduce: PASS") + + +if __name__ == "__main__": + main() diff --git a/ptodsl/tests/test_jit_compile.py b/ptodsl/tests/test_jit_compile.py index 2157acdb1..dd759bcf0 100644 --- a/ptodsl/tests/test_jit_compile.py +++ b/ptodsl/tests/test_jit_compile.py @@ -694,6 +694,17 @@ def inline_subkernel_scope_probe(*, TRACE_TOKEN: pto.const_expr = 0): pto.pipe_barrier(pto.Pipe.ALL) +@pto.jit(target="a5", mode="explicit") +def inline_simt_launch_dims_probe( + gm: pto.ptr(pto.i32, "gm"), + *, + TRACE_TOKEN: pto.const_expr = 0, +): + with pto.simt(32, 2, 1): + tid = pto.get_tid_x() + pto.stg(tid, gm, scalar.index_cast(tid)) + + @pto.simt def simt_tid_probe(): pto.get_tid_x() @@ -864,6 +875,50 @@ def simt_helper_lowering_probe(*, TRACE_TOKEN: pto.const_expr = 0): simt_tid_probe() +@pto.simt +def alloc_buffer_local_helper(): + _ = pto.alloc_buffer((32,), pto.f32) + + +@pto.jit(target="a5", mode="explicit") +def alloc_buffer_local_probe(): + alloc_buffer_local_helper() + + +@pto.simt 
+def rmsnorm_alloc_buffer_frag_helper( + w_ub: pto.ptr(pto.f32, pto.MemorySpace.UB), + x_ub: pto.ptr(pto.f32, pto.MemorySpace.UB), +): + _ = pto.get_tid_x() + _ = w_ub + _ = x_ub + _ = pto.alloc_buffer((32,), pto.f32) + _ = pto.alloc_buffer((1,), pto.f32) + + +@pto.jit(target="a5", mode="explicit", dyn_shared_memory_buf=82496) +def rmsnorm_alloc_buffer_layout_probe( + X: pto.ptr(pto.f32, "gm"), + W: pto.ptr(pto.f32, "gm"), + Y: pto.ptr(pto.f32, "gm"), + RSTD: pto.ptr(pto.f32, "gm"), +): + ub_base = pto.castptr(pto.const(0, dtype=pto.ui64), pto.ptr(pto.f32, "ub")) + w_ub = pto.addptr(ub_base, 0) + reduce_scratch = pto.addptr(ub_base, 4096) + x_ub = pto.addptr(ub_base, 4224) + y_ub = pto.addptr(ub_base, 12416) + rstd_ub = pto.addptr(ub_base, 20608) + + pto.mte_gm_ub(W, w_ub, 0, 4096 * 4, nburst=(1, 0, 0)) + pto.mte_gm_ub(X, x_ub, 0, 4096 * 4, nburst=(1, 0, 0)) + rmsnorm_alloc_buffer_frag_helper(w_ub, x_ub) + pto.mte_ub_gm(y_ub, Y, 4096 * 4, nburst=(1, 0, 0)) + pto.mte_ub_gm(rstd_ub, RSTD, 4, nburst=(1, 0, 0)) + _ = reduce_scratch + + @pto.jit(target="a5") def simt_explicit_launch_probe(*, TRACE_TOKEN: pto.const_expr = 0): pto.simt_launch(simt_query_probe, dims=(32, 2, 1)) @@ -1563,6 +1618,45 @@ def scalar_pointer_offset_probe(): _ = valid_cols +@pto.jit(target="a5") +def scalar_contiguous_vector_probe(): + data_tile = pto.alloc_tile(shape=[1, 16], dtype=pto.f32, valid_shape=[1, 16]) + data_ptr = data_tile.as_ptr() + x4 = scalar.load(data_ptr, 0, contiguous=4) + scale4 = pto.vec(pto.f32, 4, init=1.0) + y4 = x4 * scale4 + scalar.store(y4, data_ptr, 4) + + +@pto.simt +def scalar_contiguous_local_alloc_buffer_helper(): + data = pto.alloc_buffer((16,), pto.f32) + x4 = scalar.load(data, 0, contiguous=4) + scale4 = pto.vec(pto.f32, 4, init=1.0) + y4 = x4 * scale4 + scalar.store(y4, data, 4) + + +@pto.jit(target="a5", mode="explicit") +def scalar_contiguous_local_alloc_buffer_probe(): + scalar_contiguous_local_alloc_buffer_helper() + + +@pto.jit(target="a5") +def 
scalar_contiguous_width_mismatch_probe(): + data_tile = pto.alloc_tile(shape=[1, 16], dtype=pto.f32, valid_shape=[1, 16]) + data_ptr = data_tile.as_ptr() + x4 = scalar.load(data_ptr, 0, contiguous=4) + scalar.store(x4, data_ptr, 4, contiguous=2) + + +@pto.jit(target="a5") +def scalar_contiguous_scalar_store_probe(): + data_tile = pto.alloc_tile(shape=[1, 16], dtype=pto.f32, valid_shape=[1, 16]) + data_ptr = data_tile.as_ptr() + scalar.store(1.0, data_ptr, 0, contiguous=4) + + @pto.jit(target="a5") def addptr_surface_probe(): meta_tile = pto.alloc_tile(shape=[1, 8], dtype=pto.i32, valid_shape=[1, 4]) @@ -2922,6 +3016,7 @@ def main() -> None: fixed_integer_index_coercion_probe.verify() integer_loop_bound_probe.verify() scalar_pointer_offset_probe.verify() + scalar_contiguous_vector_probe.verify() addptr_surface_probe.verify() simt_pointer_offset_probe.verify() scalar_store_element_coercion_probe.verify() @@ -3887,6 +3982,31 @@ def fake_run_ptoas_cmd(cmd, *, cwd=None): "outlined inline helpers should preserve the authored SIMD/Cube sections and SIMT scalar ops", ) + inline_simt_launch_text = inline_simt_launch_dims_probe.compile(TRACE_TOKEN=1).mlir_text() + expect_parse_roundtrip_and_verify(inline_simt_launch_text, "inline simt launch-dims specialization") + expect( + re.search(r"pto\.simt_launch @inline_simt_[0-9]+__ptodsl_[0-9a-f]+<<<", inline_simt_launch_text) + is not None, + "with pto.simt(dim_x, dim_y, dim_z) should emit VPTO simt_launch sugar", + ) + expect( + "pto.store_vfsimt_info" not in inline_simt_launch_text, + "with pto.simt(dim_x, dim_y, dim_z) should leave launch metadata to simt_launch expansion", + ) + expect( + re.search( + r"func\.func @inline_simt_[0-9]+__ptodsl_[0-9a-f]+\(%arg0: !pto\.ptr\) attributes \{[^}]*pto\.simt_entry[^}]*\}", + inline_simt_launch_text, + ) + is not None, + "inline SIMT launch-dims helper should capture enclosing values as helper arguments", + ) + expect_raises( + TypeError, + lambda: pto.simt(32, 1), + "expects exactly 
three", + ) + simt_text = simt_helper_lowering_probe.compile(TRACE_TOKEN=1).mlir_text() expect_parse_roundtrip_and_verify(simt_text, "simt helper lowering specialization") expect( @@ -3919,6 +4039,41 @@ def fake_run_ptoas_cmd(cmd, *, cwd=None): expect("pto.get_tid_y" in simt_text, "SIMT helper body should contain pto.get_tid_y") expect("pto.get_tid_z" in simt_text, "SIMT helper body should contain pto.get_tid_z") + alloc_buffer_local_text = alloc_buffer_local_probe.compile().mlir_text() + expect_parse_roundtrip_and_verify(alloc_buffer_local_text, "alloc_buffer local specialization") + expect( + "llvm.alloca" in alloc_buffer_local_text and "x f32" in alloc_buffer_local_text, + "alloc_buffer should lower to an LLVM stack allocation in the SIMT helper", + ) + expect( + re.search( + r"func\.func @alloc_buffer_local_helper__simt_\d+\(\) attributes \{pto\.simt_entry\}", + alloc_buffer_local_text, + ) + is not None, + "alloc_buffer probe should keep allocation inside the SIMT helper body", + ) + rmsnorm_alloc_buffer_text = rmsnorm_alloc_buffer_layout_probe.compile().mlir_text() + expect_parse_roundtrip_and_verify(rmsnorm_alloc_buffer_text, "RMSNorm hand-authored UB layout specialization") + expect( + "dyn_shared_memory_buf = 82496 : i64" in rmsnorm_alloc_buffer_text, + "RMSNorm hand-authored UB layout should declare the expanded RMSNorm kernel scratch size", + ) + for expected_offset in (4096, 4224, 12416, 20608): + expect( + f"arith.constant {expected_offset} : index" in rmsnorm_alloc_buffer_text, + f"RMSNorm hand-authored UB layout should materialize f32 offset {expected_offset}", + ) + expect( + rmsnorm_alloc_buffer_text.count("llvm.alloca") == 2, + "RMSNorm alloc_buffer fragment helper should allocate x_frag and sum_sq locally", + ) + expect( + re.search(r"call @rmsnorm_alloc_buffer_frag_helper__simt_\d+\(", rmsnorm_alloc_buffer_text) + is not None, + "RMSNorm alloc_buffer layout should pass UB scratch pointers through the existing SIMT helper call path", + ) + 
simt_launch_text = simt_explicit_launch_probe.compile(TRACE_TOKEN=1).mlir_text() expect_parse_roundtrip_and_verify(simt_launch_text, "explicit simt launch specialization") expect( @@ -4596,6 +4751,40 @@ def _enter_inline_simt_with_resource_attr(): "scalar.load(ptr + 2) should lower as element offset 2", ) + scalar_contiguous_text = scalar_contiguous_vector_probe.compile().mlir_text() + expect_parse_roundtrip_and_verify(scalar_contiguous_text, "scalar contiguous vector specialization") + expect("llvm.load" in scalar_contiguous_text, "scalar.load(..., contiguous=N) should lower to llvm.load") + expect("llvm.store" in scalar_contiguous_text, "scalar.store(vector, ...) should lower to llvm.store") + expect("vector<4xf32>" in scalar_contiguous_text, "contiguous=4 over f32 should produce vector<4xf32>") + expect("llvm.insertelement" in scalar_contiguous_text, "pto.vec(..., init=scalar) should broadcast with insertelement") + expect("arith.mulf" in scalar_contiguous_text, "VecValue multiplication should lower to arith.mulf") + expect( + "pto.load" not in scalar_contiguous_text and "pto.store" not in scalar_contiguous_text, + "contiguous vector memory access should not lower through scalar pto.load/store", + ) + scalar_contiguous_local_text = scalar_contiguous_local_alloc_buffer_probe.compile().mlir_text() + expect_parse_roundtrip_and_verify( + scalar_contiguous_local_text, + "scalar contiguous local alloc_buffer specialization", + ) + expect("llvm.alloca" in scalar_contiguous_local_text, "local alloc_buffer should lower to llvm.alloca") + expect("llvm.load" in scalar_contiguous_local_text, "local alloc_buffer contiguous load should lower to llvm.load") + expect("llvm.store" in scalar_contiguous_local_text, "local alloc_buffer contiguous store should lower to llvm.store") + expect( + "vector<4xf32>" in scalar_contiguous_local_text, + "local alloc_buffer contiguous access should preserve vector lane type", + ) + expect_raises( + ValueError, + lambda: 
scalar_contiguous_width_mismatch_probe.compile(), + "does not match vector lane count", + ) + expect_raises( + TypeError, + lambda: scalar_contiguous_scalar_store_probe.compile(), + "scalar.store(scalar, ..., contiguous=N) is not supported", + ) + addptr_surface_text = addptr_surface_probe.compile().mlir_text() expect_parse_roundtrip_and_verify(addptr_surface_text, "addptr surface specialization") expect( diff --git a/ptodsl/tests/test_rmsnorm_example_compile.py b/ptodsl/tests/test_rmsnorm_example_compile.py new file mode 100644 index 000000000..b815612e3 --- /dev/null +++ b/ptodsl/tests/test_rmsnorm_example_compile.py @@ -0,0 +1,157 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+
+from importlib.util import module_from_spec, spec_from_file_location
+from pathlib import Path
+import re
+import sys
+
+REPO_ROOT = Path(__file__).resolve().parents[2]
+sys.path.insert(0, str(REPO_ROOT / "ptodsl"))
+
+from mlir.ir import Module
+from ptodsl._bootstrap import make_context
+
+
+def expect(condition: bool, message: str) -> None:  # raise AssertionError(message) when condition is falsy
+    if not condition:
+        raise AssertionError(message)
+
+
+def expect_raises(exc_type, func, message_substring: str | None = None) -> Exception:  # assert func() raises exc_type; return the exception
+    try:
+        func()
+    except exc_type as exc:
+        if message_substring is not None and message_substring not in str(exc):
+            raise AssertionError(
+                f"expected {exc_type.__name__} containing {message_substring!r}, got {exc!r}"
+            ) from exc
+        return exc
+    except Exception as exc:
+        raise AssertionError(
+            f"expected {exc_type.__name__}, got {exc.__class__.__name__}: {exc}"
+        ) from exc
+    raise AssertionError(f"expected {exc_type.__name__} to be raised")
+
+
+def expect_parse_roundtrip_and_verify(text: str, label: str) -> None:  # parse + verify text, then require str(module) == text
+    with make_context() as ctx:
+        parsed = Module.parse(text, ctx)
+        parsed.operation.verify()
+        roundtrip_text = str(parsed)
+    expect(
+        roundtrip_text == text,
+        f"{label} should survive Module.parse(...) round-trip without textual drift",
+    )
+
+
+def load_rmsnorm_example():  # import the RMSNorm example by file path and return the module object
+    example_path = REPO_ROOT / "ptodsl" / "examples" / "rms_norm" / "rmsnorm_alloc_buffer_simt.py"
+    expect(example_path.is_file(), f"RMSNorm example is missing: {example_path}")
+
+    spec = spec_from_file_location("ptodsl_rmsnorm_alloc_buffer_simt", example_path)
+    expect(spec is not None and spec.loader is not None, f"unable to create import spec for {example_path}")
+    module = module_from_spec(spec)
+    spec.loader.exec_module(module)
+    return module
+
+
+def check_variant(compiled, *, label: str, vector_type: str, ub_size: int) -> None:  # verify one compiled variant and assert on its MLIR text
+    compiled.verify()
+    text = compiled.mlir_text()
+    expect_parse_roundtrip_and_verify(text, f"RMSNorm {label} MLIR")
+
+    expect("func.func @rmsnorm_4096_alloc_buffer_simt_context_kernel" in text, f"{label}: missing entry")
+    expect(f"dyn_shared_memory_buf = {ub_size} : i64" in text, f"{label}: unexpected UB scratch size")
+    expect("scf.for" in text, f"{label}: tokens_per_core loop should lower to scf.for")
+    expect("pto.mte_gm_ub" in text, f"{label}: missing GM->UB transfer")
+    expect("pto.mte_ub_gm" in text, f"{label}: missing UB->GM transfer")
+    expect("pto.simt_launch @rmsnorm_simt_token_body__simt_" in text,
+           f"{label}: indexed SIMT call should lower to an explicit token simt_launch op")
+    expect("pto.simt_launch @inline_simt_" not in text,
+           f"{label}: token SIMT body should be emitted as the named helper, not an inline helper")
+    expect("pto.store_vfsimt_info" not in text,
+           f"{label}: explicit simt_launch dims should not emit caller-side store_vfsimt_info")
+    expect("pto.set_flag[, , ]" in text,  # NOTE(review): the bracketed operands read as empty here ("[, , ]") — confirm the expected op text
+           f"{label}: W load should signal completion before token processing")
+    expect("pto.wait_flag[, , ]" in text,
+           f"{label}: token processing should start after the W load completes")
+    expect("pto.set_flag[, , ]" in text,
+           f"{label}: missing V->MTE2 ping-pong priming flag")
+    expect("pto.set_flag[, , ]" in text,
+           f"{label}: missing MTE3->V pong priming flag")
+    expect("pto.set_flag_dyn" in text, f"{label}: token loop should lower dynamic set_flag ops")
+    expect("pto.wait_flag_dyn" in text, f"{label}: token loop should lower dynamic wait_flag ops")
+    expect(vector_type in text, f"{label}: missing contiguous vector access type {vector_type}")
+    expect("__tl_allreduce_sum" not in text,
+           f"{label}: PR3 allreduce should inline the reduce sequence into the SIMT body")
+    expect("pto.redux_add" in text, f"{label}: PR3 inline allreduce should use redux_add")
+    expect("pto.syncthreads" in text, f"{label}: PR3 inline allreduce should synchronize through UB scratch")
+    expect("pto.sqrt" in text, f"{label}: RMSNorm runtime sqrt should lower through the PTO SIMT sqrt op")
+    expect("math.sqrt" not in text, f"{label}: RMSNorm SIMT helper should not leave math.sqrt in the MLIR")
+
+    expect("w_frag" not in text, f"{label}: W should be read directly from UB, not from a local fragment")
+    expect(
+        re.search(  # two llvm.alloca ops after the helper's func.func header, with no other func.func in between
+            r"func\.func @rmsnorm_simt_token_body__simt_[^{]+\{(?:(?!func\.func @).)*"
+            r"llvm\.alloca(?:(?!func\.func @).)*llvm\.alloca",
+            text,
+            re.S,
+        )
+        is not None,
+        f"{label}: x_frag and sum_sq should be allocated inside the token SIMT helper",
+    )
+    expect(
+        re.search(  # insertelement -> mulf -> mulf -> store, all on vector_type, inside a single function
+            rf"llvm\.insertelement .* : {re.escape(vector_type)}(?:(?!func\.func @).)*"
+            rf"arith\.mulf .* : {re.escape(vector_type)}(?:(?!func\.func @).)*"
+            rf"arith\.mulf .* : {re.escape(vector_type)}(?:(?!func\.func @).)*"
+            rf"llvm\.store .* : {re.escape(vector_type)}",
+            text,
+            re.S,
+        )
+        is not None,
+        f"{label}: y = x * rstd * w should lower as vector broadcast/mul/store",
+    )
+
+
+def main() -> None:  # check the hidden_size guard, then both build_x128() and build_x64()
+    example = load_rmsnorm_example()
+
+    expect(hasattr(example, "build_x128"), "RMSNorm example should export build_x128()")
+    expect(hasattr(example, "build_x64"), "RMSNorm example should export build_x64()")
+    expect_raises(
+        AssertionError,
+        lambda: example.rmsnorm_4096_alloc_buffer_simt_context_kernel.compile(
+            threads=128,
+            rounds=16,
+            lanes=2,
+            hidden_size=4097,
+        ),
+        "threads * rounds * lanes must equal hidden_size",
+    )
+
+    check_variant(
+        example.build_x128(),
+        label="x128",
+        vector_type="vector<4xf32>",
+        ub_size=82496,
+    )
+    check_variant(
+        example.build_x64(),
+        label="x64",
+        vector_type="vector<4xf32>",
+        ub_size=82496,
+    )
+
+    print("ptodsl_rmsnorm_example_compile: PASS")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/test/vpto/cases/micro-op/simt/allreduce_cross_max/compare.py b/test/vpto/cases/micro-op/simt/allreduce_cross_max/compare.py
new file mode 100644
index 000000000..83136b372
--- /dev/null
+++ b/test/vpto/cases/micro-op/simt/allreduce_cross_max/compare.py
@@ -0,0 +1,20 @@
+#!/usr/bin/env python3
+"""Compare out.bin against golden_out.bin (float32); exit with status 2 on any mismatch."""
+import sys, numpy as np
+
+def main():
+    golden = np.fromfile("golden_out.bin", dtype=np.float32)
+    out = np.fromfile("out.bin", dtype=np.float32)
+    # Report a size mismatch on its own: np.isclose cannot compare arrays of incompatible shapes.
+    if golden.shape != out.shape:
+        print(f"[ERROR] shape mismatch: golden={golden.shape}, out={out.shape}")
+        sys.exit(2)
+    mismatches = np.nonzero(~np.isclose(golden, out, rtol=1e-5, atol=1e-5))[0]
+    if mismatches.size:
+        idx = int(mismatches[0])
+        print(f"[ERROR] mismatch at idx={idx}, golden={golden[idx]:.6f}, out={out[idx]:.6f}")
+        sys.exit(2)
+    print("[INFO] compare passed")
+
+if __name__ == "__main__":
+    main()
diff --git a/test/vpto/cases/micro-op/simt/allreduce_cross_max/golden.py b/test/vpto/cases/micro-op/simt/allreduce_cross_max/golden.py
new file mode 100644
index 000000000..2dca5fb33
--- /dev/null
+++ b/test/vpto/cases/micro-op/simt/allreduce_cross_max/golden.py
@@ -0,0 +1,22 @@
+#!/usr/bin/env python3
+import argparse, numpy as np
+from pathlib import Path
+
+NLANES = 128  # number of float32 elements written to each file
+EXPECTED = 1.0  # value every golden element is filled with
+
+def generate(output_dir: Path) -> None:  # write a zero-filled out.bin and an EXPECTED-filled golden_out.bin
+    output_dir.mkdir(parents=True, exist_ok=True)
+    out = np.zeros(NLANES, dtype=np.float32)
+    out.tofile(output_dir / "out.bin")
+    golden = np.full(NLANES, EXPECTED, dtype=np.float32)
+    golden.tofile(output_dir / "golden_out.bin")
+
+def main():  # CLI entry point: --output-dir defaults to the current directory
+    p = argparse.ArgumentParser()
+    p.add_argument("--output-dir", type=Path, default=Path("."))
+    a = p.parse_args()
+    generate(a.output_dir)
+
+if __name__ == "__main__":
+    main()
diff --git a/test/vpto/cases/micro-op/simt/allreduce_cross_max/kernel.pto b/test/vpto/cases/micro-op/simt/allreduce_cross_max/kernel.pto
new file mode 100644
index 000000000..e9b2842bf
--- /dev/null
+++ b/test/vpto/cases/micro-op/simt/allreduce_cross_max/kernel.pto
@@ -0,0 +1,65 @@
+module attributes {pto.kernel_kind = #pto.kernel_kind, pto.mode = "auto", pto.target_arch = "a5"} {
+  func.func @_kernel(%arg0: !pto.ptr) attributes {pto.aicore} {
+    %c0_i64 = arith.constant 0 : i64
+    %0 = builtin.unrealized_conversion_cast %c0_i64 : i64 to ui64
+    %1 = pto.castptr %0 : ui64 -> !pto.ptr  // pointer built from integer address 0; passed to the SIMT body as its first argument
+    %c128_i32 = arith.constant 128 : i32
+    %c1_i32 = arith.constant 1 : i32
+    %c1_i32_0 = arith.constant 1 : i32
+    pto.simt_launch @_body__simt_0<<<%c128_i32, %c1_i32, %c1_i32_0>>>(%1, %arg0) : (!pto.ptr, !pto.ptr) -> ()  // launch dims <<<128, 1, 1>>>
+    pto.barrier
+    return
+  }
+  func.func @_body__simt_0(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.simt_entry} {
+    %0 = pto.get_tid_x : i32
+    %cst = arith.constant 1.000000e+00 : f32  // value each thread contributes to the reduction
+    %c5_i32 = arith.constant 5 : i32
+    %c31_i32 = arith.constant 31 : i32
+    %c32_i32 = arith.constant 32 : i32
+    %c1_i32 = arith.constant 1 : i32
+    %c4_i32 = arith.constant 4 : i32
+    %c0_i32 = arith.constant 0 : i32
+    %c0 = arith.constant 0 : index
+    %cst_0 = arith.constant 0xFF800000 : f32  // -inf bit pattern: neutral element for max
+    %1 = pto.get_tid_x : i32
+    %2 = arith.shrui %1, %c5_i32 : i32  // tid_x >> 5 (group of 32 threads; presumably the warp index)
+    %3 = pto.get_laneid : i32
+    %4 = pto.redux_max %cst : f32 -> f32  // first-stage reduction (presumably across the lanes of one warp — confirm op semantics)
+    %5 = arith.cmpi ult, %3, %c1_i32 : i32
+    scf.if %5 {  // lanes with laneid < 1 store their partial result at %arg0[(tid_x >> 5) * 1 + laneid]
+      %13 = arith.muli %2, %c1_i32 : i32
+      %14 = arith.addi %13, %3 : i32
+      %15 = arith.index_cast %14 : i32 to index
+      pto.store %4, %arg0[%15] : !pto.ptr, f32
+    }
+    pto.syncthreads
+    %6 = arith.cmpi ult, %1, %c32_i32 : i32
+    %7 = scf.if %6 -> (f32) {  // only threads with tid_x < 32 run the second stage; the others yield the neutral value
+      %13 = arith.cmpi ult, %3, %c4_i32 : i32
+      %14 = scf.if %13 -> (f32) {  // lanes 0..3 load one stored partial each; the other lanes use the neutral value
+        %16 = arith.index_cast %3 : i32 to index
+        %17 = pto.load %arg0[%16] : !pto.ptr -> f32
+        scf.yield %17 : f32
+      } else {
+        scf.yield %cst_0 : f32
+      }
+      %15 = pto.redux_max %14 : f32 -> f32
+      scf.yield %15 : f32
+    } else {
+      scf.yield %cst_0 : f32
+    }
+    %8 = arith.cmpi ult, %1, %c1_i32 : i32
+    scf.if %8 {  // thread 0 writes the final value to %arg0[0]
+      %13 = arith.index_cast %1 : i32 to index
+      pto.store %7, %arg0[%13] : !pto.ptr, f32
+    }
+    pto.syncthreads
+    %9 = arith.remui %1, %c1_i32 : i32  // tid_x % 1 == 0, so every thread reads %arg0[0]
+    %10 = arith.index_cast %9 : i32 to index
+    %11 = pto.load %arg0[%10] : !pto.ptr -> f32
+    pto.syncthreads
+    %12 = arith.index_cast %0 : i32 to index
+    pto.store %11, %arg1[%12] : !pto.ptr, f32  // each thread writes the reduced value to %arg1[tid_x]
+    return
+  }
+}
diff --git a/test/vpto/cases/micro-op/simt/allreduce_cross_max/launch.cpp b/test/vpto/cases/micro-op/simt/allreduce_cross_max/launch.cpp
new file mode 100644
index 000000000..94b5e9414
--- /dev/null
+++ b/test/vpto/cases/micro-op/simt/allreduce_cross_max/launch.cpp
@@ -0,0 +1,11 @@
+#ifndef __VEC_SCOPE__
+#define __VEC_SCOPE__
+#endif
+#include  // NOTE(review): the header name is missing on this line — confirm the intended include
+#ifndef __CPU_SIM
+#include "acl/acl.h"
+#endif
+extern "C" __global__ [aicore] void _kernel(__gm__ float *out);
+void Launch_kernel(float *out, void *stream) {  // host-side wrapper: launches _kernel on the given stream with out as its only argument
+    _kernel<<<1, nullptr, stream>>>((__gm__ float *)out);
+}
diff --git a/test/vpto/cases/micro-op/simt/allreduce_cross_max/main.cpp b/test/vpto/cases/micro-op/simt/allreduce_cross_max/main.cpp
new file mode 100644
index 000000000..05dad144b
--- /dev/null
+++ b/test/vpto/cases/micro-op/simt/allreduce_cross_max/main.cpp
@@ -0,0 +1,47 @@
+#include "test_common.h"
+#include "acl/acl.h"
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+
+using namespace PtoTestCommon;
+
+// Evaluate an ACL call; on failure log it, set rc = 1 and jump to the cleanup label.
+#define ACL_CHECK(expr) do { const aclError _ret = (expr); if (_ret != ACL_SUCCESS) { std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); rc = 1; goto cleanup; } } while (0)
+
+void Launch_kernel(float *out, void *stream);
+
+// Runs the kernel once on a zero-initialised 128-float buffer and writes the result to ./out.bin.
+int main() {
+    size_t elemCount = 128;
+    size_t fileSize = elemCount * sizeof(float);
+    float *outHost = nullptr;
+    float *outDevice = nullptr;
+    int rc = 0;
+    bool aclInited = false, deviceSet = false;
+    int deviceId = 0;
+    aclrtStream stream = nullptr;
+
+    ACL_CHECK(aclInit(nullptr)); aclInited = true;
+    if (const char *e = std::getenv("ACL_DEVICE_ID")) deviceId = std::atoi(e);
+    ACL_CHECK(aclrtSetDevice(deviceId)); deviceSet = true;
+    ACL_CHECK(aclrtCreateStream(&stream));
+    ACL_CHECK(aclrtMallocHost((void **)(&outHost), fileSize));
+    ACL_CHECK(aclrtMalloc((void **)&outDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST));
+    // Input: zero-initialize output buffer (kernel writes results)
+    std::memset(outHost, 0, fileSize);
+    ACL_CHECK(aclrtMemcpy(outDevice, fileSize, outHost, fileSize, ACL_MEMCPY_HOST_TO_DEVICE));
+    Launch_kernel(outDevice, stream);
+    ACL_CHECK(aclrtSynchronizeStream(stream));
+    ACL_CHECK(aclrtMemcpy(outHost, fileSize, outDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST));
+    WriteFile("./out.bin", outHost, fileSize);
+
+cleanup:
+    // Release only what was actually acquired: an early ACL_CHECK failure jumps here with null buffers.
+    if (outDevice) aclrtFree(outDevice);
+    if (outHost) aclrtFreeHost(outHost);
+    if (stream) aclrtDestroyStream(stream);
+    if (deviceSet) aclrtResetDevice(deviceId);
+    if (aclInited) aclFinalize();
+    return rc;
+}
diff --git a/test/vpto/cases/micro-op/simt/allreduce_cross_min/compare.py b/test/vpto/cases/micro-op/simt/allreduce_cross_min/compare.py
new file mode 100644
index 000000000..83136b372
--- /dev/null
+++ b/test/vpto/cases/micro-op/simt/allreduce_cross_min/compare.py
@@ -0,0 +1,20 @@
+#!/usr/bin/env python3
+"""Compare out.bin against golden_out.bin (float32); exit with status 2 on any mismatch."""
+import sys, numpy as np
+
+def main():
+    golden = np.fromfile("golden_out.bin", dtype=np.float32)
+    out = np.fromfile("out.bin", dtype=np.float32)
+    # Report a size mismatch on its own: np.isclose cannot compare arrays of incompatible shapes.
+    if golden.shape != out.shape:
+        print(f"[ERROR] shape mismatch: golden={golden.shape}, out={out.shape}")
+        sys.exit(2)
+    mismatches = np.nonzero(~np.isclose(golden, out, rtol=1e-5, atol=1e-5))[0]
+    if mismatches.size:
+        idx = int(mismatches[0])
+        print(f"[ERROR] mismatch at idx={idx}, golden={golden[idx]:.6f}, out={out[idx]:.6f}")
+        sys.exit(2)
+    print("[INFO] compare passed")
+
+if __name__ == "__main__":
+    main()
diff --git a/test/vpto/cases/micro-op/simt/allreduce_cross_min/golden.py
b/test/vpto/cases/micro-op/simt/allreduce_cross_min/golden.py
new file mode 100644
index 000000000..2dca5fb33
--- /dev/null
+++ b/test/vpto/cases/micro-op/simt/allreduce_cross_min/golden.py
@@ -0,0 +1,22 @@
+#!/usr/bin/env python3
+import argparse, numpy as np
+from pathlib import Path
+
+NLANES = 128  # number of float32 elements written to each file
+EXPECTED = 1.0  # value every golden element is filled with
+
+def generate(output_dir: Path) -> None:  # write a zero-filled out.bin and an EXPECTED-filled golden_out.bin
+    output_dir.mkdir(parents=True, exist_ok=True)
+    out = np.zeros(NLANES, dtype=np.float32)
+    out.tofile(output_dir / "out.bin")
+    golden = np.full(NLANES, EXPECTED, dtype=np.float32)
+    golden.tofile(output_dir / "golden_out.bin")
+
+def main():  # CLI entry point: --output-dir defaults to the current directory
+    p = argparse.ArgumentParser()
+    p.add_argument("--output-dir", type=Path, default=Path("."))
+    a = p.parse_args()
+    generate(a.output_dir)
+
+if __name__ == "__main__":
+    main()
diff --git a/test/vpto/cases/micro-op/simt/allreduce_cross_min/kernel.pto b/test/vpto/cases/micro-op/simt/allreduce_cross_min/kernel.pto
new file mode 100644
index 000000000..f9353c1b7
--- /dev/null
+++ b/test/vpto/cases/micro-op/simt/allreduce_cross_min/kernel.pto
@@ -0,0 +1,65 @@
+module attributes {pto.kernel_kind = #pto.kernel_kind, pto.mode = "auto", pto.target_arch = "a5"} {
+  func.func @_kernel(%arg0: !pto.ptr) attributes {pto.aicore} {
+    %c0_i64 = arith.constant 0 : i64
+    %0 = builtin.unrealized_conversion_cast %c0_i64 : i64 to ui64
+    %1 = pto.castptr %0 : ui64 -> !pto.ptr  // pointer built from integer address 0; passed to the SIMT body as its first argument
+    %c128_i32 = arith.constant 128 : i32
+    %c1_i32 = arith.constant 1 : i32
+    %c1_i32_0 = arith.constant 1 : i32
+    pto.simt_launch @_body__simt_0<<<%c128_i32, %c1_i32, %c1_i32_0>>>(%1, %arg0) : (!pto.ptr, !pto.ptr) -> ()  // launch dims <<<128, 1, 1>>>
+    pto.barrier
+    return
+  }
+  func.func @_body__simt_0(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.simt_entry} {
+    %0 = pto.get_tid_x : i32
+    %cst = arith.constant 1.000000e+00 : f32  // value each thread contributes to the reduction
+    %c5_i32 = arith.constant 5 : i32
+    %c31_i32 = arith.constant 31 : i32
+    %c32_i32 = arith.constant 32 : i32
+    %c1_i32 = arith.constant 1 : i32
+    %c4_i32 = arith.constant 4 : i32
+    %c0_i32 = arith.constant 0 : i32
+    %c0 = arith.constant 0 : index
+    %cst_0 = arith.constant 0x7F800000 : f32  // +inf bit pattern: neutral element for min
+    %1 = pto.get_tid_x : i32
+    %2 = arith.shrui %1, %c5_i32 : i32  // tid_x >> 5 (group of 32 threads; presumably the warp index)
+    %3 = pto.get_laneid : i32
+    %4 = pto.redux_min %cst : f32 -> f32  // first-stage reduction (presumably across the lanes of one warp — confirm op semantics)
+    %5 = arith.cmpi ult, %3, %c1_i32 : i32
+    scf.if %5 {  // lanes with laneid < 1 store their partial result at %arg0[(tid_x >> 5) * 1 + laneid]
+      %13 = arith.muli %2, %c1_i32 : i32
+      %14 = arith.addi %13, %3 : i32
+      %15 = arith.index_cast %14 : i32 to index
+      pto.store %4, %arg0[%15] : !pto.ptr, f32
+    }
+    pto.syncthreads
+    %6 = arith.cmpi ult, %1, %c32_i32 : i32
+    %7 = scf.if %6 -> (f32) {  // only threads with tid_x < 32 run the second stage; the others yield the neutral value
+      %13 = arith.cmpi ult, %3, %c4_i32 : i32
+      %14 = scf.if %13 -> (f32) {  // lanes 0..3 load one stored partial each; the other lanes use the neutral value
+        %16 = arith.index_cast %3 : i32 to index
+        %17 = pto.load %arg0[%16] : !pto.ptr -> f32
+        scf.yield %17 : f32
+      } else {
+        scf.yield %cst_0 : f32
+      }
+      %15 = pto.redux_min %14 : f32 -> f32
+      scf.yield %15 : f32
+    } else {
+      scf.yield %cst_0 : f32
+    }
+    %8 = arith.cmpi ult, %1, %c1_i32 : i32
+    scf.if %8 {  // thread 0 writes the final value to %arg0[0]
+      %13 = arith.index_cast %1 : i32 to index
+      pto.store %7, %arg0[%13] : !pto.ptr, f32
+    }
+    pto.syncthreads
+    %9 = arith.remui %1, %c1_i32 : i32  // tid_x % 1 == 0, so every thread reads %arg0[0]
+    %10 = arith.index_cast %9 : i32 to index
+    %11 = pto.load %arg0[%10] : !pto.ptr -> f32
+    pto.syncthreads
+    %12 = arith.index_cast %0 : i32 to index
+    pto.store %11, %arg1[%12] : !pto.ptr, f32  // each thread writes the reduced value to %arg1[tid_x]
+    return
+  }
+}
diff --git a/test/vpto/cases/micro-op/simt/allreduce_cross_min/launch.cpp b/test/vpto/cases/micro-op/simt/allreduce_cross_min/launch.cpp
new file mode 100644
index 000000000..94b5e9414
--- /dev/null
+++ b/test/vpto/cases/micro-op/simt/allreduce_cross_min/launch.cpp
@@ -0,0 +1,11 @@
+#ifndef __VEC_SCOPE__
+#define __VEC_SCOPE__
+#endif
+#include  // NOTE(review): the header name is missing on this line — confirm the intended include
+#ifndef __CPU_SIM
+#include "acl/acl.h"
+#endif
+extern "C" __global__ [aicore] void _kernel(__gm__ float *out);
+void Launch_kernel(float *out, void *stream) {  // host-side wrapper: launches _kernel on the given stream with out as its only argument
+    _kernel<<<1, nullptr, stream>>>((__gm__ float *)out);
+}
diff --git a/test/vpto/cases/micro-op/simt/allreduce_cross_min/main.cpp b/test/vpto/cases/micro-op/simt/allreduce_cross_min/main.cpp
new file mode 100644
index 000000000..05dad144b
---
/dev/null
+++ b/test/vpto/cases/micro-op/simt/allreduce_cross_min/main.cpp
@@ -0,0 +1,47 @@
+#include "test_common.h"
+#include "acl/acl.h"
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+
+using namespace PtoTestCommon;
+
+// Evaluate an ACL call; on failure log it, set rc = 1 and jump to the cleanup label.
+#define ACL_CHECK(expr) do { const aclError _ret = (expr); if (_ret != ACL_SUCCESS) { std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); rc = 1; goto cleanup; } } while (0)
+
+void Launch_kernel(float *out, void *stream);
+
+// Runs the kernel once on a zero-initialised 128-float buffer and writes the result to ./out.bin.
+int main() {
+    size_t elemCount = 128;
+    size_t fileSize = elemCount * sizeof(float);
+    float *outHost = nullptr;
+    float *outDevice = nullptr;
+    int rc = 0;
+    bool aclInited = false, deviceSet = false;
+    int deviceId = 0;
+    aclrtStream stream = nullptr;
+
+    ACL_CHECK(aclInit(nullptr)); aclInited = true;
+    if (const char *e = std::getenv("ACL_DEVICE_ID")) deviceId = std::atoi(e);
+    ACL_CHECK(aclrtSetDevice(deviceId)); deviceSet = true;
+    ACL_CHECK(aclrtCreateStream(&stream));
+    ACL_CHECK(aclrtMallocHost((void **)(&outHost), fileSize));
+    ACL_CHECK(aclrtMalloc((void **)&outDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST));
+    // Input: zero-initialize output buffer (kernel writes results)
+    std::memset(outHost, 0, fileSize);
+    ACL_CHECK(aclrtMemcpy(outDevice, fileSize, outHost, fileSize, ACL_MEMCPY_HOST_TO_DEVICE));
+    Launch_kernel(outDevice, stream);
+    ACL_CHECK(aclrtSynchronizeStream(stream));
+    ACL_CHECK(aclrtMemcpy(outHost, fileSize, outDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST));
+    WriteFile("./out.bin", outHost, fileSize);
+
+cleanup:
+    // Release only what was actually acquired: an early ACL_CHECK failure jumps here with null buffers.
+    if (outDevice) aclrtFree(outDevice);
+    if (outHost) aclrtFreeHost(outHost);
+    if (stream) aclrtDestroyStream(stream);
+    if (deviceSet) aclrtResetDevice(deviceId);
+    if (aclInited) aclFinalize();
+    return rc;
+}
diff --git a/test/vpto/cases/micro-op/simt/allreduce_cross_sum/compare.py b/test/vpto/cases/micro-op/simt/allreduce_cross_sum/compare.py
new file mode 100644
index 000000000..83136b372
--- /dev/null
+++ b/test/vpto/cases/micro-op/simt/allreduce_cross_sum/compare.py
@@ -0,0 +1,20 @@
+#!/usr/bin/env python3
+"""Compare out.bin against golden_out.bin (float32); exit with status 2 on any mismatch."""
+import sys, numpy as np
+
+def main():
+    golden = np.fromfile("golden_out.bin", dtype=np.float32)
+    out = np.fromfile("out.bin", dtype=np.float32)
+    # Report a size mismatch on its own: np.isclose cannot compare arrays of incompatible shapes.
+    if golden.shape != out.shape:
+        print(f"[ERROR] shape mismatch: golden={golden.shape}, out={out.shape}")
+        sys.exit(2)
+    mismatches = np.nonzero(~np.isclose(golden, out, rtol=1e-5, atol=1e-5))[0]
+    if mismatches.size:
+        idx = int(mismatches[0])
+        print(f"[ERROR] mismatch at idx={idx}, golden={golden[idx]:.6f}, out={out[idx]:.6f}")
+        sys.exit(2)
+    print("[INFO] compare passed")
+
+if __name__ == "__main__":
+    main()
diff --git a/test/vpto/cases/micro-op/simt/allreduce_cross_sum/golden.py b/test/vpto/cases/micro-op/simt/allreduce_cross_sum/golden.py
new file mode 100644
index 000000000..ce2b4c57d
--- /dev/null
+++ b/test/vpto/cases/micro-op/simt/allreduce_cross_sum/golden.py
@@ -0,0 +1,22 @@
+#!/usr/bin/env python3
+import argparse, numpy as np
+from pathlib import Path
+
+NLANES = 128  # number of float32 elements written to each file
+EXPECTED = 128.0  # value every golden element is filled with
+
+def generate(output_dir: Path) -> None:  # write a zero-filled out.bin and an EXPECTED-filled golden_out.bin
+    output_dir.mkdir(parents=True, exist_ok=True)
+    out = np.zeros(NLANES, dtype=np.float32)
+    out.tofile(output_dir / "out.bin")
+    golden = np.full(NLANES, EXPECTED, dtype=np.float32)
+    golden.tofile(output_dir / "golden_out.bin")
+
+def main():  # CLI entry point: --output-dir defaults to the current directory
+    p = argparse.ArgumentParser()
+    p.add_argument("--output-dir", type=Path, default=Path("."))
+    a = p.parse_args()
+    generate(a.output_dir)
+
+if __name__ == "__main__":
+    main()
diff --git a/test/vpto/cases/micro-op/simt/allreduce_cross_sum/kernel.pto b/test/vpto/cases/micro-op/simt/allreduce_cross_sum/kernel.pto
new file mode 100644
index 000000000..53f01c524
--- /dev/null
+++ b/test/vpto/cases/micro-op/simt/allreduce_cross_sum/kernel.pto
@@ -0,0 +1,65 @@
+module attributes {pto.kernel_kind = #pto.kernel_kind, pto.mode = "auto", pto.target_arch = "a5"} {
+  func.func @_kernel(%arg0: !pto.ptr) attributes {pto.aicore} {
+    %c0_i64 = arith.constant 0 : i64
+    %0 = builtin.unrealized_conversion_cast %c0_i64 : i64 to ui64
+    %1 = pto.castptr %0 : ui64 -> !pto.ptr  // pointer built from integer address 0; passed to the SIMT body as its first argument
+    %c128_i32 = arith.constant 128 : i32
+    %c1_i32 = arith.constant 1 : i32
+    %c1_i32_0 = arith.constant 1 : i32
+    pto.simt_launch @_body__simt_0<<<%c128_i32, %c1_i32, %c1_i32_0>>>(%1, %arg0) : (!pto.ptr, !pto.ptr) -> ()  // launch dims <<<128, 1, 1>>>
+    pto.barrier
+    return
+  }
+  func.func @_body__simt_0(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.simt_entry} {
+    %0 = pto.get_tid_x : i32
+    %cst = arith.constant 1.000000e+00 : f32  // value each thread contributes to the reduction
+    %c5_i32 = arith.constant 5 : i32
+    %c31_i32 = arith.constant 31 : i32
+    %c32_i32 = arith.constant 32 : i32
+    %c1_i32 = arith.constant 1 : i32
+    %c4_i32 = arith.constant 4 : i32
+    %c0_i32 = arith.constant 0 : i32
+    %c0 = arith.constant 0 : index
+    %cst_0 = arith.constant 0.000000e+00 : f32  // additive identity: neutral element for sum
+    %1 = pto.get_tid_x : i32
+    %2 = arith.shrui %1, %c5_i32 : i32  // tid_x >> 5 (group of 32 threads; presumably the warp index)
+    %3 = pto.get_laneid : i32
+    %4 = pto.redux_add %cst : f32 -> f32  // first-stage reduction (presumably across the lanes of one warp — confirm op semantics)
+    %5 = arith.cmpi ult, %3, %c1_i32 : i32
+    scf.if %5 {  // lanes with laneid < 1 store their partial result at %arg0[(tid_x >> 5) * 1 + laneid]
+      %13 = arith.muli %2, %c1_i32 : i32
+      %14 = arith.addi %13, %3 : i32
+      %15 = arith.index_cast %14 : i32 to index
+      pto.store %4, %arg0[%15] : !pto.ptr, f32
+    }
+    pto.syncthreads
+    %6 = arith.cmpi ult, %1, %c32_i32 : i32
+    %7 = scf.if %6 -> (f32) {  // only threads with tid_x < 32 run the second stage; the others yield the neutral value
+      %13 = arith.cmpi ult, %3, %c4_i32 : i32
+      %14 = scf.if %13 -> (f32) {  // lanes 0..3 load one stored partial each; the other lanes use the neutral value
+        %16 = arith.index_cast %3 : i32 to index
+        %17 = pto.load %arg0[%16] : !pto.ptr -> f32
+        scf.yield %17 : f32
+      } else {
+        scf.yield %cst_0 : f32
+      }
+      %15 = pto.redux_add %14 : f32 -> f32
+      scf.yield %15 : f32
+    } else {
+      scf.yield %cst_0 : f32
+    }
+    %8 = arith.cmpi ult, %1, %c1_i32 : i32
+    scf.if %8 {  // thread 0 writes the final value to %arg0[0]
+      %13 = arith.index_cast %1 : i32 to index
+      pto.store %7, %arg0[%13] : !pto.ptr, f32
+    }
+    pto.syncthreads
+    %9 = arith.remui %1, %c1_i32 : i32  // tid_x % 1 == 0, so every thread reads %arg0[0]
+    %10 = arith.index_cast %9 : i32 to index
+    %11 = pto.load %arg0[%10] : !pto.ptr -> f32
+    pto.syncthreads
+    %12 = arith.index_cast %0 : i32 to index
+    pto.store %11, %arg1[%12] : !pto.ptr, f32  // each thread writes the reduced value to %arg1[tid_x]
+    return
+  }
+}
diff --git a/test/vpto/cases/micro-op/simt/allreduce_cross_sum/launch.cpp
b/test/vpto/cases/micro-op/simt/allreduce_cross_sum/launch.cpp new file mode 100644 index 000000000..94b5e9414 --- /dev/null +++ b/test/vpto/cases/micro-op/simt/allreduce_cross_sum/launch.cpp @@ -0,0 +1,11 @@ +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif +extern "C" __global__ [aicore] void _kernel(__gm__ float *out); +void Launch_kernel(float *out, void *stream) {  // Launches _kernel with <<<1, nullptr, stream>>> on out. + _kernel<<<1, nullptr, stream>>>((__gm__ float *)out); +} diff --git a/test/vpto/cases/micro-op/simt/allreduce_cross_sum/main.cpp b/test/vpto/cases/micro-op/simt/allreduce_cross_sum/main.cpp new file mode 100644 index 000000000..05dad144b --- /dev/null +++ b/test/vpto/cases/micro-op/simt/allreduce_cross_sum/main.cpp @@ -0,0 +1,43 @@ +#include "test_common.h" +#include "acl/acl.h" +#include <cstdio> +#include <cstdlib> +#include <cstring> + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) do { const aclError _ret = (expr); if (_ret != ACL_SUCCESS) { std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); rc = 1; goto cleanup; } } while (0) + +void Launch_kernel(float *out, void *stream); + +int main() { + size_t elemCount = 128;  // number of float elements in the output buffer + size_t fileSize = elemCount * sizeof(float);  // buffer size in bytes + float *outHost = nullptr; + float *outDevice = nullptr; + int rc = 0; + bool aclInited = false, deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); aclInited = true; + if (const char *e = std::getenv("ACL_DEVICE_ID")) deviceId = std::atoi(e);  // optional device override from the environment + ACL_CHECK(aclrtSetDevice(deviceId)); deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&outHost), fileSize)); + ACL_CHECK(aclrtMalloc((void **)&outDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST)); + // Input: zero-initialize output buffer (kernel writes results) + std::memset(outHost, 0, fileSize); + ACL_CHECK(aclrtMemcpy(outDevice, fileSize, outHost, fileSize, ACL_MEMCPY_HOST_TO_DEVICE)); + 
Launch_kernel(outDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(outHost, fileSize, outDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./out.bin", outHost, fileSize);  // NOTE(review): the result of WriteFile is not checked; confirm whether it reports failures + +cleanup: + if (outDevice) aclrtFree(outDevice); if (outHost) aclrtFreeHost(outHost);  // free only the buffers that were allocated + if (stream) aclrtDestroyStream(stream); + if (deviceSet) aclrtResetDevice(deviceId); + if (aclInited) aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/simt/allreduce_warp_max/compare.py b/test/vpto/cases/micro-op/simt/allreduce_warp_max/compare.py new file mode 100644 index 000000000..83136b372 --- /dev/null +++ b/test/vpto/cases/micro-op/simt/allreduce_warp_max/compare.py @@ -0,0 +1,15 @@ +#!/usr/bin/env python3 +import sys, numpy as np + +def main():  # Compare out.bin with golden_out.bin as float32 arrays; exit with status 2 on mismatch. + golden = np.fromfile("golden_out.bin", dtype=np.float32) + out = np.fromfile("out.bin", dtype=np.float32) + if golden.shape != out.shape or not np.allclose(golden, out, rtol=1e-5, atol=1e-5): + mismatches = np.nonzero(~np.isclose(golden, out, rtol=1e-5, atol=1e-5))[0]  # NOTE(review): np.isclose raises when the two shapes cannot be broadcast + idx = int(mismatches[0]) if mismatches.size else 0  # first mismatching index, or 0 when none was found + print(f"[ERROR] mismatch at idx={idx}, golden={golden[idx]:.6f}, out={out[idx]:.6f}")  # report the offending index and both values + sys.exit(2) + print("[INFO] compare passed") + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/simt/allreduce_warp_max/golden.py b/test/vpto/cases/micro-op/simt/allreduce_warp_max/golden.py new file mode 100644 index 000000000..76a227fb7 --- /dev/null +++ b/test/vpto/cases/micro-op/simt/allreduce_warp_max/golden.py @@ -0,0 +1,22 @@ +#!/usr/bin/env python3 +import argparse, numpy as np +from pathlib import Path + +NLANES = 32  # number of float32 elements written to each file +EXPECTED = 1.0  # value every golden element is filled with + +def generate(output_dir: Path) -> None:  # Write a zero-filled out.bin and an EXPECTED-filled golden_out.bin into output_dir. + output_dir.mkdir(parents=True, exist_ok=True) + out = np.zeros(NLANES, dtype=np.float32) + out.tofile(output_dir / "out.bin") + golden = np.full(NLANES, EXPECTED, dtype=np.float32) + golden.tofile(output_dir / "golden_out.bin") + +def main(): + p = argparse.ArgumentParser() + 
p.add_argument("--output-dir", type=Path, default=Path(".")) + a = p.parse_args() + generate(a.output_dir) + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/simt/allreduce_warp_max/kernel.pto b/test/vpto/cases/micro-op/simt/allreduce_warp_max/kernel.pto new file mode 100644 index 000000000..e22af0898 --- /dev/null +++ b/test/vpto/cases/micro-op/simt/allreduce_warp_max/kernel.pto @@ -0,0 +1,21 @@ +module attributes {pto.kernel_kind = #pto.kernel_kind, pto.mode = "auto", pto.target_arch = "a5"} { + func.func @_kernel(%arg0: !pto.ptr) attributes {pto.aicore} { + %c32_i32 = arith.constant 32 : i32 + %c1_i32 = arith.constant 1 : i32 + %c1_i32_0 = arith.constant 1 : i32 + pto.simt_launch @_body__simt_0<<<%c32_i32, %c1_i32, %c1_i32_0>>>(%arg0) : (!pto.ptr) -> ()  // launch dimensions (32, 1, 1) + pto.barrier + return + } + func.func @_body__simt_0(%arg0: !pto.ptr) attributes {pto.simt_entry} { + %0 = pto.get_tid_x : i32 + %cst = arith.constant 1.000000e+00 : f32 + %c0_i32 = arith.constant 0 : i32 + %cst_0 = arith.constant 0xFF800000 : f32  // bit pattern of negative infinity; not referenced below + %1 = pto.get_laneid : i32 + %2 = pto.redux_max %cst : f32 -> f32  // every thread contributes the constant 1.0 + %3 = arith.index_cast %0 : i32 to index + pto.store %2, %arg0[%3] : !pto.ptr, f32  // each thread stores the result at %arg0[tid_x] + return + } +} diff --git a/test/vpto/cases/micro-op/simt/allreduce_warp_max/launch.cpp b/test/vpto/cases/micro-op/simt/allreduce_warp_max/launch.cpp new file mode 100644 index 000000000..94b5e9414 --- /dev/null +++ b/test/vpto/cases/micro-op/simt/allreduce_warp_max/launch.cpp @@ -0,0 +1,11 @@ +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif +extern "C" __global__ [aicore] void _kernel(__gm__ float *out); +void Launch_kernel(float *out, void *stream) {  // Launches _kernel with <<<1, nullptr, stream>>> on out. + _kernel<<<1, nullptr, stream>>>((__gm__ float *)out); +} diff --git a/test/vpto/cases/micro-op/simt/allreduce_warp_max/main.cpp b/test/vpto/cases/micro-op/simt/allreduce_warp_max/main.cpp new file mode 100644 index 000000000..c4fcda9b3 --- /dev/null +++ 
b/test/vpto/cases/micro-op/simt/allreduce_warp_max/main.cpp @@ -0,0 +1,43 @@ +#include "test_common.h" +#include "acl/acl.h" +#include <cstdio> +#include <cstdlib> +#include <cstring> + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) do { const aclError _ret = (expr); if (_ret != ACL_SUCCESS) { std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); rc = 1; goto cleanup; } } while (0) + +void Launch_kernel(float *out, void *stream); + +int main() { + size_t elemCount = 32;  // number of float elements in the output buffer + size_t fileSize = elemCount * sizeof(float);  // buffer size in bytes + float *outHost = nullptr; + float *outDevice = nullptr; + int rc = 0; + bool aclInited = false, deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); aclInited = true; + if (const char *e = std::getenv("ACL_DEVICE_ID")) deviceId = std::atoi(e);  // optional device override from the environment + ACL_CHECK(aclrtSetDevice(deviceId)); deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&outHost), fileSize)); + ACL_CHECK(aclrtMalloc((void **)&outDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST)); + // Input: zero-initialize output buffer (kernel writes results) + std::memset(outHost, 0, fileSize); + ACL_CHECK(aclrtMemcpy(outDevice, fileSize, outHost, fileSize, ACL_MEMCPY_HOST_TO_DEVICE)); + Launch_kernel(outDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(outHost, fileSize, outDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./out.bin", outHost, fileSize);  // NOTE(review): the result of WriteFile is not checked; confirm whether it reports failures + +cleanup: + if (outDevice) aclrtFree(outDevice); if (outHost) aclrtFreeHost(outHost);  // free only the buffers that were allocated + if (stream) aclrtDestroyStream(stream); + if (deviceSet) aclrtResetDevice(deviceId); + if (aclInited) aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/simt/allreduce_warp_min/compare.py b/test/vpto/cases/micro-op/simt/allreduce_warp_min/compare.py new file mode 100644 index 000000000..83136b372 --- /dev/null +++ b/test/vpto/cases/micro-op/simt/allreduce_warp_min/compare.py @@ -0,0 +1,15 @@ +#!/usr/bin/env python3 
+import sys, numpy as np + +def main():  # Compare out.bin with golden_out.bin as float32 arrays; exit with status 2 on mismatch. + golden = np.fromfile("golden_out.bin", dtype=np.float32) + out = np.fromfile("out.bin", dtype=np.float32) + if golden.shape != out.shape or not np.allclose(golden, out, rtol=1e-5, atol=1e-5): + mismatches = np.nonzero(~np.isclose(golden, out, rtol=1e-5, atol=1e-5))[0]  # NOTE(review): np.isclose raises when the two shapes cannot be broadcast + idx = int(mismatches[0]) if mismatches.size else 0  # first mismatching index, or 0 when none was found + print(f"[ERROR] mismatch at idx={idx}, golden={golden[idx]:.6f}, out={out[idx]:.6f}")  # report the offending index and both values + sys.exit(2) + print("[INFO] compare passed") + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/simt/allreduce_warp_min/golden.py b/test/vpto/cases/micro-op/simt/allreduce_warp_min/golden.py new file mode 100644 index 000000000..76a227fb7 --- /dev/null +++ b/test/vpto/cases/micro-op/simt/allreduce_warp_min/golden.py @@ -0,0 +1,22 @@ +#!/usr/bin/env python3 +import argparse, numpy as np +from pathlib import Path + +NLANES = 32  # number of float32 elements written to each file +EXPECTED = 1.0  # value every golden element is filled with + +def generate(output_dir: Path) -> None:  # Write a zero-filled out.bin and an EXPECTED-filled golden_out.bin into output_dir. + output_dir.mkdir(parents=True, exist_ok=True) + out = np.zeros(NLANES, dtype=np.float32) + out.tofile(output_dir / "out.bin") + golden = np.full(NLANES, EXPECTED, dtype=np.float32) + golden.tofile(output_dir / "golden_out.bin") + +def main(): + p = argparse.ArgumentParser() + p.add_argument("--output-dir", type=Path, default=Path(".")) + a = p.parse_args() + generate(a.output_dir) + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/simt/allreduce_warp_min/kernel.pto b/test/vpto/cases/micro-op/simt/allreduce_warp_min/kernel.pto new file mode 100644 index 000000000..d921ef1fe --- /dev/null +++ b/test/vpto/cases/micro-op/simt/allreduce_warp_min/kernel.pto @@ -0,0 +1,21 @@ +module attributes {pto.kernel_kind = #pto.kernel_kind, pto.mode = "auto", pto.target_arch = "a5"} { + func.func @_kernel(%arg0: !pto.ptr) attributes {pto.aicore} { + %c32_i32 = arith.constant 32 : i32 + %c1_i32 = arith.constant 1 : i32 + %c1_i32_0 = arith.constant 1 : i32 + pto.simt_launch @_body__simt_0<<<%c32_i32, 
%c1_i32, %c1_i32_0>>>(%arg0) : (!pto.ptr) -> ()  // launch dimensions (32, 1, 1) + pto.barrier + return + } + func.func @_body__simt_0(%arg0: !pto.ptr) attributes {pto.simt_entry} { + %0 = pto.get_tid_x : i32 + %cst = arith.constant 1.000000e+00 : f32 + %c0_i32 = arith.constant 0 : i32 + %cst_0 = arith.constant 0x7F800000 : f32  // bit pattern of positive infinity; not referenced below + %1 = pto.get_laneid : i32 + %2 = pto.redux_min %cst : f32 -> f32  // every thread contributes the constant 1.0 + %3 = arith.index_cast %0 : i32 to index + pto.store %2, %arg0[%3] : !pto.ptr, f32  // each thread stores the result at %arg0[tid_x] + return + } +} diff --git a/test/vpto/cases/micro-op/simt/allreduce_warp_min/launch.cpp b/test/vpto/cases/micro-op/simt/allreduce_warp_min/launch.cpp new file mode 100644 index 000000000..94b5e9414 --- /dev/null +++ b/test/vpto/cases/micro-op/simt/allreduce_warp_min/launch.cpp @@ -0,0 +1,11 @@ +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif +extern "C" __global__ [aicore] void _kernel(__gm__ float *out); +void Launch_kernel(float *out, void *stream) {  // Launches _kernel with <<<1, nullptr, stream>>> on out. + _kernel<<<1, nullptr, stream>>>((__gm__ float *)out); +} diff --git a/test/vpto/cases/micro-op/simt/allreduce_warp_min/main.cpp b/test/vpto/cases/micro-op/simt/allreduce_warp_min/main.cpp new file mode 100644 index 000000000..c4fcda9b3 --- /dev/null +++ b/test/vpto/cases/micro-op/simt/allreduce_warp_min/main.cpp @@ -0,0 +1,43 @@ +#include "test_common.h" +#include "acl/acl.h" +#include <cstdio> +#include <cstdlib> +#include <cstring> + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) do { const aclError _ret = (expr); if (_ret != ACL_SUCCESS) { std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); rc = 1; goto cleanup; } } while (0) + +void Launch_kernel(float *out, void *stream); + +int main() { + size_t elemCount = 32;  // number of float elements in the output buffer + size_t fileSize = elemCount * sizeof(float);  // buffer size in bytes + float *outHost = nullptr; + float *outDevice = nullptr; + int rc = 0; + bool aclInited = false, deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); aclInited = true; + if 
(const char *e = std::getenv("ACL_DEVICE_ID")) deviceId = std::atoi(e);  // optional device override from the environment + ACL_CHECK(aclrtSetDevice(deviceId)); deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&outHost), fileSize)); + ACL_CHECK(aclrtMalloc((void **)&outDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST)); + // Input: zero-initialize output buffer (kernel writes results) + std::memset(outHost, 0, fileSize); + ACL_CHECK(aclrtMemcpy(outDevice, fileSize, outHost, fileSize, ACL_MEMCPY_HOST_TO_DEVICE)); + Launch_kernel(outDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(outHost, fileSize, outDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./out.bin", outHost, fileSize);  // NOTE(review): the result of WriteFile is not checked; confirm whether it reports failures + +cleanup: + if (outDevice) aclrtFree(outDevice); if (outHost) aclrtFreeHost(outHost);  // free only the buffers that were allocated + if (stream) aclrtDestroyStream(stream); + if (deviceSet) aclrtResetDevice(deviceId); + if (aclInited) aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/simt/allreduce_warp_sum/compare.py b/test/vpto/cases/micro-op/simt/allreduce_warp_sum/compare.py new file mode 100644 index 000000000..83136b372 --- /dev/null +++ b/test/vpto/cases/micro-op/simt/allreduce_warp_sum/compare.py @@ -0,0 +1,15 @@ +#!/usr/bin/env python3 +import sys, numpy as np + +def main():  # Compare out.bin with golden_out.bin as float32 arrays; exit with status 2 on mismatch. + golden = np.fromfile("golden_out.bin", dtype=np.float32) + out = np.fromfile("out.bin", dtype=np.float32) + if golden.shape != out.shape or not np.allclose(golden, out, rtol=1e-5, atol=1e-5): + mismatches = np.nonzero(~np.isclose(golden, out, rtol=1e-5, atol=1e-5))[0]  # NOTE(review): np.isclose raises when the two shapes cannot be broadcast + idx = int(mismatches[0]) if mismatches.size else 0  # first mismatching index, or 0 when none was found + print(f"[ERROR] mismatch at idx={idx}, golden={golden[idx]:.6f}, out={out[idx]:.6f}")  # report the offending index and both values + sys.exit(2) + print("[INFO] compare passed") + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/simt/allreduce_warp_sum/golden.py b/test/vpto/cases/micro-op/simt/allreduce_warp_sum/golden.py new file mode 100644 index 000000000..4a5fb7b63 --- /dev/null +++ 
b/test/vpto/cases/micro-op/simt/allreduce_warp_sum/golden.py @@ -0,0 +1,22 @@ +#!/usr/bin/env python3 +import argparse, numpy as np +from pathlib import Path + +NLANES = 32  # number of float32 elements written to each file +EXPECTED = 32.0  # value every golden element is filled with + +def generate(output_dir: Path) -> None:  # Write a zero-filled out.bin and an EXPECTED-filled golden_out.bin into output_dir. + output_dir.mkdir(parents=True, exist_ok=True) + out = np.zeros(NLANES, dtype=np.float32) + out.tofile(output_dir / "out.bin") + golden = np.full(NLANES, EXPECTED, dtype=np.float32) + golden.tofile(output_dir / "golden_out.bin") + +def main(): + p = argparse.ArgumentParser() + p.add_argument("--output-dir", type=Path, default=Path(".")) + a = p.parse_args() + generate(a.output_dir) + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/simt/allreduce_warp_sum/kernel.pto b/test/vpto/cases/micro-op/simt/allreduce_warp_sum/kernel.pto new file mode 100644 index 000000000..18e137438 --- /dev/null +++ b/test/vpto/cases/micro-op/simt/allreduce_warp_sum/kernel.pto @@ -0,0 +1,21 @@ +module attributes {pto.kernel_kind = #pto.kernel_kind, pto.mode = "auto", pto.target_arch = "a5"} { + func.func @_kernel(%arg0: !pto.ptr) attributes {pto.aicore} { + %c32_i32 = arith.constant 32 : i32 + %c1_i32 = arith.constant 1 : i32 + %c1_i32_0 = arith.constant 1 : i32 + pto.simt_launch @_body__simt_0<<<%c32_i32, %c1_i32, %c1_i32_0>>>(%arg0) : (!pto.ptr) -> ()  // launch dimensions (32, 1, 1) + pto.barrier + return + } + func.func @_body__simt_0(%arg0: !pto.ptr) attributes {pto.simt_entry} { + %0 = pto.get_tid_x : i32 + %cst = arith.constant 1.000000e+00 : f32 + %c0_i32 = arith.constant 0 : i32 + %cst_0 = arith.constant 0.000000e+00 : f32  // not referenced below + %1 = pto.get_laneid : i32 + %2 = pto.redux_add %cst : f32 -> f32  // every thread contributes the constant 1.0 + %3 = arith.index_cast %0 : i32 to index + pto.store %2, %arg0[%3] : !pto.ptr, f32  // each thread stores the result at %arg0[tid_x] + return + } +} diff --git a/test/vpto/cases/micro-op/simt/allreduce_warp_sum/launch.cpp b/test/vpto/cases/micro-op/simt/allreduce_warp_sum/launch.cpp new file mode 100644 index 000000000..94b5e9414 --- /dev/null +++ b/test/vpto/cases/micro-op/simt/allreduce_warp_sum/launch.cpp @@ -0,0 +1,11 
@@ +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif +extern "C" __global__ [aicore] void _kernel(__gm__ float *out); +void Launch_kernel(float *out, void *stream) {  // Launches _kernel with <<<1, nullptr, stream>>> on out. + _kernel<<<1, nullptr, stream>>>((__gm__ float *)out); +} diff --git a/test/vpto/cases/micro-op/simt/allreduce_warp_sum/main.cpp b/test/vpto/cases/micro-op/simt/allreduce_warp_sum/main.cpp new file mode 100644 index 000000000..c4fcda9b3 --- /dev/null +++ b/test/vpto/cases/micro-op/simt/allreduce_warp_sum/main.cpp @@ -0,0 +1,43 @@ +#include "test_common.h" +#include "acl/acl.h" +#include <cstdio> +#include <cstdlib> +#include <cstring> + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) do { const aclError _ret = (expr); if (_ret != ACL_SUCCESS) { std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); rc = 1; goto cleanup; } } while (0) + +void Launch_kernel(float *out, void *stream); + +int main() { + size_t elemCount = 32;  // number of float elements in the output buffer + size_t fileSize = elemCount * sizeof(float);  // buffer size in bytes + float *outHost = nullptr; + float *outDevice = nullptr; + int rc = 0; + bool aclInited = false, deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); aclInited = true; + if (const char *e = std::getenv("ACL_DEVICE_ID")) deviceId = std::atoi(e);  // optional device override from the environment + ACL_CHECK(aclrtSetDevice(deviceId)); deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&outHost), fileSize)); + ACL_CHECK(aclrtMalloc((void **)&outDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST)); + // Input: zero-initialize output buffer (kernel writes results) + std::memset(outHost, 0, fileSize); + ACL_CHECK(aclrtMemcpy(outDevice, fileSize, outHost, fileSize, ACL_MEMCPY_HOST_TO_DEVICE)); + Launch_kernel(outDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(outHost, fileSize, outDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./out.bin", outHost, fileSize);  // NOTE(review): the result of WriteFile is not checked; confirm whether it reports failures + +cleanup: + 
if (outDevice) aclrtFree(outDevice); if (outHost) aclrtFreeHost(outHost);  // free only the buffers that were allocated + if (stream) aclrtDestroyStream(stream); + if (deviceSet) aclrtResetDevice(deviceId); + if (aclInited) aclFinalize(); + return rc; +}