From a6f092e75575858aa92d5ef275daa9b36263fe22 Mon Sep 17 00:00:00 2001 From: TaoTao-real Date: Thu, 11 Jun 2026 09:30:03 +0800 Subject: [PATCH 01/51] build: upgrade PTOAS to LLVM 21.1.8 --- .github/workflows/build_wheel.yml | 15 +- .github/workflows/build_wheel_mac.yml | 15 +- .github/workflows/ci.yml | 17 ++- .github/workflows/ci_sim.yml | 10 +- .gitignore | 2 + CMakeLists.txt | 6 +- README.md | 42 +++--- README_en.md | 16 +- ReleaseNotes.md | 2 +- docker/Dockerfile | 7 +- docs/build_with_installed_llvm.md | 9 +- docs/designs/ci-board-validation-guide.md | 2 +- include/PTO/IR/PTOTypeUtils.h | 2 + include/PTO/IR/VPTOOps.td | 4 +- lib/Bindings/Python/CMakeLists.txt | 37 ++++- lib/PTO/IR/PTO.cpp | 16 +- lib/PTO/IR/PTOTypeUtils.cpp | 12 +- .../BufferizableOpInterfaceImpl.cpp | 23 +-- lib/PTO/Transforms/ConvertToPTOOp.cpp | 12 +- lib/PTO/Transforms/ExpandTileOp.cpp | 4 +- lib/PTO/Transforms/InferPTOLayout.cpp | 6 +- .../InsertSync/InsertSyncAnalysis.cpp | 2 +- .../Transforms/InsertSync/PTOIRTranslator.cpp | 6 +- lib/PTO/Transforms/LoweringSyncToPipe.cpp | 2 +- .../Transforms/PTOMaterializeTileHandles.cpp | 8 +- lib/PTO/Transforms/PTOPlanMemory.cpp | 2 +- lib/PTO/Transforms/PTOToEmitC.cpp | 142 ++++++++++++------ lib/PTO/Transforms/PTOViewToMemref.cpp | 4 +- lib/PTO/Transforms/Utils.cpp | 6 +- lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp | 48 ++---- lib/PTO/Transforms/VPTOExpandWrapperOps.cpp | 2 +- lib/PTO/Transforms/VPTOLLVMEmitter.cpp | 48 ++---- lib/PTO/Transforms/VPTOLLVMEmitterHelper.cpp | 7 +- lib/PTO/Transforms/VPTOPtrCastCleanup.cpp | 2 +- lib/PTO/Transforms/VPTOPtrNormalize.cpp | 3 +- ptodsl/tests/test_docs_as_test.py | 26 ++-- ptodsl/tests/test_jit_compile.py | 3 +- python/pto/dialects/pto.py | 27 ++-- .../Qwen3DecodeA5/down_proj_residual.pto | 36 ++--- .../Qwen3DecodeA5/out_proj_residual.pto | 36 ++--- .../Qwen3DecodeA5/qwen3_decode_incore_1.pto | 36 ++--- .../Qwen3DecodeA5/qwen3_decode_incore_10.pto | 48 +++--- .../Qwen3DecodeA5/qwen3_decode_incore_11.pto | 48 +++--- .../Qwen3DecodeA5/qwen3_decode_incore_2.pto | 72 ++++----- .../Qwen3DecodeA5/qwen3_decode_incore_4.pto | 24 +-- .../Qwen3DecodeA5/qwen3_decode_incore_6.pto | 24 +-- tilelang-dsl/python/tilelang_dsl/kernel.py | 21 ++- tools/ptoas/ObjectEmission.cpp | 2 +- tools/ptobc/src/mlir_encode.cpp | 35 ++++- tools/ptobc/src/ptobc_decode_print.cpp | 3 +- .../testdata/recent_mx_ops_v0_roundtrip.pto | 26 ++++ .../testdata/recent_ops_v0_roundtrip.pto | 12 -- tools/ptobc/tests/recent_ops_v0_encode.sh | 11 +- 53 files changed, 585 insertions(+), 446 deletions(-) create mode 100644 tools/ptobc/testdata/recent_mx_ops_v0_roundtrip.pto diff --git a/.github/workflows/build_wheel.yml b/.github/workflows/build_wheel.yml index 209af90f95..f00d7321b1 100644 --- a/.github/workflows/build_wheel.yml +++ b/.github/workflows/build_wheel.yml @@ -24,9 +24,9 @@ permissions: contents: write env: - LLVM_REPO: https://github.com/vpto-dev/llvm-project.git - LLVM_REF: feature-vpto - LLVM_CACHE_FLAVOR: release-hardening-v1 + LLVM_REPO: https://github.com/llvm/llvm-project.git + LLVM_TAG: llvmorg-21.1.8 + LLVM_CACHE_FLAVOR: llvm21-release-hardening-v1 jobs: build_wheel: @@ -98,9 +98,9 @@ jobs: - name: Resolve LLVM source SHA id: llvm-source run: | - LLVM_SOURCE_SHA="$(git ls-remote "${LLVM_REPO}" "refs/heads/${LLVM_REF}" | awk '{print $1}')" + LLVM_SOURCE_SHA="$(git ls-remote "${LLVM_REPO}" "refs/tags/${LLVM_TAG}" | awk '{print $1}')" if [ -z "${LLVM_SOURCE_SHA}" ]; then - echo "ERROR: failed to resolve ${LLVM_REPO} ${LLVM_REF}" >&2 + echo "ERROR: failed to resolve ${LLVM_REPO} ${LLVM_TAG}" >&2 exit 1 fi echo "sha=${LLVM_SOURCE_SHA}" >> "$GITHUB_OUTPUT" @@ -121,7 +121,7 @@ jobs: git remote add origin "${LLVM_REPO}" fi git remote set-url origin "${LLVM_REPO}" - git fetch --depth 1 origin "${LLVM_REF}" + git fetch --depth 1 origin "${LLVM_TAG}" git checkout --force FETCH_HEAD - name: Warn when default-branch cache is missing for PR/release runs @@ -143,6 +143,9 @@ jobs: -DLLVM_ENABLE_ASSERTIONS=OFF \ -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ -DPython3_EXECUTABLE=${PY_PATH}/bin/python \ + -DPython_EXECUTABLE=${PY_PATH}/bin/python \ + -Dpybind11_DIR="$(${PY_PATH}/bin/python -m pybind11 --cmakedir)" \ + -Dnanobind_DIR="$(${PY_PATH}/bin/python -m nanobind --cmake_dir)" \ -DCMAKE_BUILD_TYPE=Release \ -DLLVM_TARGETS_TO_BUILD="host" ninja -C $LLVM_BUILD_DIR diff --git a/.github/workflows/build_wheel_mac.yml b/.github/workflows/build_wheel_mac.yml index 028f8c74fa..2a8106f28c 100644 --- a/.github/workflows/build_wheel_mac.yml +++ b/.github/workflows/build_wheel_mac.yml @@ -23,9 +23,9 @@ permissions: contents: write env: - LLVM_REPO: https://github.com/vpto-dev/llvm-project.git - LLVM_REF: feature-vpto - LLVM_CACHE_FLAVOR: release-v2 + LLVM_REPO: https://github.com/llvm/llvm-project.git + LLVM_TAG: llvmorg-21.1.8 + LLVM_CACHE_FLAVOR: llvm21-release-v1 jobs: build_wheel: @@ -101,9 +101,9 @@ jobs: - name: Resolve LLVM source SHA id: llvm-source run: | - LLVM_SOURCE_SHA="$(git ls-remote "${LLVM_REPO}" "refs/heads/${LLVM_REF}" | awk '{print $1}')" + LLVM_SOURCE_SHA="$(git ls-remote "${LLVM_REPO}" "refs/tags/${LLVM_TAG}" | awk '{print $1}')" if [ -z "${LLVM_SOURCE_SHA}" ]; then - echo "ERROR: failed to resolve ${LLVM_REPO} ${LLVM_REF}" >&2 + echo "ERROR: failed to resolve ${LLVM_REPO} ${LLVM_TAG}" >&2 exit 1 fi echo "sha=${LLVM_SOURCE_SHA}" >> "$GITHUB_OUTPUT" @@ -124,7 +124,7 @@ jobs: git remote add origin "${LLVM_REPO}" fi git remote set-url origin "${LLVM_REPO}" - git fetch --depth 1 origin "${LLVM_REF}" + git fetch --depth 1 origin "${LLVM_TAG}" git checkout --force FETCH_HEAD - name: Warn when default-branch cache is missing for PR/release runs @@ -145,6 +145,9 @@ jobs: -DLLVM_ENABLE_ASSERTIONS=OFF \ -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ -DPython3_EXECUTABLE=$(which python) \ + -DPython_EXECUTABLE=$(which python) \ + -Dpybind11_DIR="$(python -m pybind11 --cmakedir)" \ + -Dnanobind_DIR="$(python -m nanobind --cmake_dir)" \ -DCMAKE_BUILD_TYPE=Release \ -DLLVM_TARGETS_TO_BUILD="host" ninja -C $LLVM_BUILD_DIR diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c186affebc..7edcceb531 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -97,14 +97,14 @@ jobs: runs-on: ubuntu-22.04 env: PTOAS_CLANG_MAJOR: "15" - LLVM_REPO: https://github.com/vpto-dev/llvm-project.git - LLVM_REF: feature-vpto + LLVM_REPO: https://github.com/llvm/llvm-project.git + LLVM_TAG: llvmorg-21.1.8 LLVM_BUILD_DIR: ${{ github.workspace }}/llvm-project/llvm/build-assert LLVM_DIR: ${{ github.workspace }}/llvm-project/llvm/build-assert PTO_BUILD_DIR: ${{ github.workspace }}/build-assert PTO_INSTALL_DIR: ${{ github.workspace }}/install-assert MLIR_PYTHONPATH: ${{ github.workspace }}/llvm-project/llvm/build-assert/tools/mlir/python_packages/mlir_core - LLVM_CACHE_FLAVOR: assert-shared-mlirpy-hardening-v2 + LLVM_CACHE_FLAVOR: llvm21-assert-shared-mlirpy-hardening-v1 steps: - name: Checkout uses: actions/checkout@v4 @@ -122,7 +122,7 @@ jobs: libedit-dev zlib1g-dev libxml2-dev libzstd-dev python3 -m pip install --upgrade pip # LLVM/MLIR Python bindings are not yet compatible with pybind11 3.x. - python3 -m pip install setuptools wheel 'pybind11<3' numpy + python3 -m pip install setuptools wheel 'pybind11<3' nanobind numpy - name: Define payload paths shell: bash @@ -159,9 +159,9 @@ jobs: shell: bash run: | set -euo pipefail - LLVM_SOURCE_SHA="$(git ls-remote "${LLVM_REPO}" "refs/heads/${LLVM_REF}" | awk '{print $1}')" + LLVM_SOURCE_SHA="$(git ls-remote "${LLVM_REPO}" "refs/tags/${LLVM_TAG}" | awk '{print $1}')" [[ -n "${LLVM_SOURCE_SHA}" ]] || { - echo "ERROR: failed to resolve ${LLVM_REPO} ${LLVM_REF}" >&2 + echo "ERROR: failed to resolve ${LLVM_REPO} ${LLVM_TAG}" >&2 exit 1 } echo "LLVM_SOURCE_SHA=${LLVM_SOURCE_SHA}" >> "${GITHUB_ENV}" @@ -187,7 +187,7 @@ jobs: fi git remote set-url origin "${LLVM_REPO}" - git fetch --depth 1 origin "${LLVM_REF}" + git fetch --depth 1 origin "${LLVM_TAG}" git checkout --force FETCH_HEAD - name: Build LLVM/MLIR (only if cache miss) @@ -200,6 +200,9 @@ jobs: -DLLVM_ENABLE_ASSERTIONS=ON \ -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ -DPython3_EXECUTABLE=python3 \ + -DPython_EXECUTABLE=python3 \ + -Dpybind11_DIR="$(python3 -m pybind11 --cmakedir)" \ + -Dnanobind_DIR="$(python3 -m nanobind --cmake_dir)" \ -DCMAKE_BUILD_TYPE=Release \ -DCMAKE_C_COMPILER="${PTOAS_CMAKE_C_COMPILER}" \ -DCMAKE_CXX_COMPILER="${PTOAS_CMAKE_CXX_COMPILER}" \ diff --git a/.github/workflows/ci_sim.yml b/.github/workflows/ci_sim.yml index f8b361b273..cd41b22ba8 100644 --- a/.github/workflows/ci_sim.yml +++ b/.github/workflows/ci_sim.yml @@ -107,8 +107,8 @@ jobs: needs.detect-vpto-sim-changes.outputs.should_run == 'true' }} env: - LLVM_REPO: https://github.com/vpto-dev/llvm-project.git - LLVM_REF: feature-vpto + LLVM_REPO: https://github.com/llvm/llvm-project.git + LLVM_TAG: llvmorg-21.1.8 PTO_INSTALL_DIR: ${{ github.workspace }}/install VPTO_SIM_WORKSPACE: ${{ github.workspace }}/.work/vpto-sim-ci TILELANG_DSL_WORKSPACE: ${{ github.workspace }}/.work/tilelang-dsl-ci @@ -193,11 +193,12 @@ jobs: python3 -c "import numpy" >/dev/null 2>&1 || need_pip_install=1 python3 -c "import setuptools, wheel" >/dev/null 2>&1 || need_pip_install=1 python3 -m pybind11 --cmakedir >/dev/null 2>&1 || need_pip_install=1 + python3 -m nanobind --cmake_dir >/dev/null 2>&1 || need_pip_install=1 python3 -c "import ml_dtypes" >/dev/null 2>&1 || need_pip_install=1 if [[ "${need_pip_install}" -eq 1 ]]; then python3 -m pip install --upgrade pip - python3 -m pip install setuptools wheel 'pybind11<3' numpy ml-dtypes + python3 -m pip install setuptools wheel 'pybind11<3' nanobind numpy ml-dtypes fi - name: Clean CI work dirs @@ -251,6 +252,9 @@ jobs: -DLLVM_ENABLE_ASSERTIONS=ON \ -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ -DPython3_EXECUTABLE=python3 \ + -DPython_EXECUTABLE=python3 \ + -Dpybind11_DIR="$(python3 -m pybind11 --cmakedir)" \ + -Dnanobind_DIR="$(python3 -m nanobind --cmake_dir)" \ -DCMAKE_BUILD_TYPE=Release \ -DLLVM_TARGETS_TO_BUILD="host" diff --git a/.gitignore b/.gitignore index a2cfb9d562..4b46f5aa43 100644 --- a/.gitignore +++ b/.gitignore @@ -1,8 +1,10 @@ # Build artifacts build/ +build-*/ build_plain/ build_plan/ install/ +install-*/ # TileLang ST standalone build outputs (see temp_docs/standalone_st.md) test/tilelang_st/npu/a5/src/st/build/ diff --git a/CMakeLists.txt b/CMakeLists.txt index 23b6e94494..4908c3cbb1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -39,7 +39,7 @@ endif() message(STATUS "PTOAS CLI version: ${PTOAS_CLI_VERSION}") # ========================================================= -# [新增] 强制设置 C++17 标准 (LLVM 19 必需) +# [新增] 强制设置 C++17 标准 (LLVM/MLIR 必需) # ========================================================= set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) @@ -52,6 +52,8 @@ find_package(LLVM REQUIRED CONFIG) message(STATUS "Found LLVM ${LLVM_PACKAGE_VERSION}") message(STATUS "LLVM CMake Dir: ${LLVM_CMAKE_DIR}") message(STATUS "LLVM Include Dir: ${LLVM_INCLUDE_DIRS}") +get_filename_component(PTO_LLVM_BUILD_LIBRARY_DIR + "${LLVM_BUILD_LIBRARY_DIR}" REALPATH) # 将 LLVM 模块路径加入 CMake 搜索路径 list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}") @@ -98,7 +100,7 @@ include_directories(${MLIR_INCLUDE_DIRS}) include_directories(${PROJECT_SOURCE_DIR}/include) include_directories(${PROJECT_BINARY_DIR}/include) -link_directories(${LLVM_BUILD_LIBRARY_DIR}) +link_directories(${PTO_LLVM_BUILD_LIBRARY_DIR}) add_definitions(${LLVM_DEFINITIONS}) # ========================================================= diff --git a/README.md b/README.md index 0d8399783f..467bf94372 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ ## 1. 项目简介 (Introduction) -**ptoas** (`ptoas`) 是一个基于 **LLVM/MLIR (llvmorg-19.1.7)***(Commit cd708029e0b2869e80abe31ddb175f7c35361f90)* 框架构建的专用编译器工具链,专为 **PTO Bytecode** (Programming Tiling Operator Bytecode) 设计。 +**ptoas** (`ptoas`) 是一个基于 **LLVM/MLIR (llvmorg-21.1.8)** 框架构建的专用编译器工具链,专为 **PTO Bytecode** (Programming Tiling Operator Bytecode) 设计。 作为连接上层 AI 框架与底层各类NPU/GPGPU/CPU硬件,`ptoas` 采用 **Out-of-Tree** 架构构建,提供了完整的 C++ 与 Python 接口,主要职责包括: @@ -37,7 +37,7 @@ PTOAS/ ## 3. 构建指南 (Build Instructions) -⚠️ **重要提示**:本项目严格依赖 **LLVM llvmorg-19.1.7** 版本。 +⚠️ **重要提示**:本项目严格依赖 **LLVM llvmorg-21.1.8** 版本。 ### 3.0 环境变量配置 (Configuration) @@ -51,11 +51,11 @@ export WORKSPACE_DIR=$HOME/llvm-workspace # LLVM 源码与构建路径 export LLVM_SOURCE_DIR=$WORKSPACE_DIR/llvm-project -export LLVM_BUILD_DIR=$LLVM_SOURCE_DIR/build-shared +export LLVM_BUILD_DIR=$LLVM_SOURCE_DIR/build-shared-21 # PTOAS 源码与安装路径 export PTO_SOURCE_DIR=$WORKSPACE_DIR/PTOAS -export PTO_INSTALL_DIR=$PTO_SOURCE_DIR/install +export PTO_INSTALL_DIR=$PTO_SOURCE_DIR/install-llvm21 # ======================================================= # 创建工作目录 @@ -69,21 +69,22 @@ mkdir -p $WORKSPACE_DIR * **Compiler**: GCC >= 9 或 Clang (支持 C++17) * **Build System**: CMake >= 3.20, Ninja * **Python**: 3.8+ -* **Python Packages**: `pybind11`, `numpy` +* **Python Packages**: `pybind11<3`, `nanobind`, `numpy` ```bash -python3 -m pip install pybind11==2.12.0 numpy +python3 -m pip install 'pybind11<3' nanobind numpy ``` -> 说明:当前 LLVM/MLIR Python 绑定与 `pybind11` 3.x 不兼容。 +> 说明:当前 PTOAS Python 扩展继续使用 `pybind11`,LLVM21 的 MLIR Python 绑定构建需要 `nanobind`。 +> 当前 LLVM/MLIR Python 绑定与 `pybind11` 3.x 不兼容。 > 如果编译 LLVM 时遇到 `def_property family does not currently support keep_alive` 等报错, -> 请先执行上面的降级命令。 +> 请确认使用上面的 `pybind11<3` 依赖。 ### 3.2 第一步:构建 LLVM/MLIR (Dependency) -我们需要下载 LLVM 源码,切换到 `llvmorg-19.1.7` 标签,并以**动态库 (Shared Libs)** 模式编译,以确保 Python Binding 的正确链接。 +我们需要下载 LLVM 源码,切换到 `llvmorg-21.1.8` 标签,并以**动态库 (Shared Libs)** 模式编译,以确保 Python Binding 的正确链接。 ```bash # 1. 下载 LLVM 源码 @@ -91,8 +92,8 @@ cd $WORKSPACE_DIR git clone https://github.com/llvm/llvm-project.git cd $LLVM_SOURCE_DIR -# 2. [关键] 切换到 llvmorg-19.1.7 -git checkout llvmorg-19.1.7 +# 2. [关键] 切换到 llvmorg-21.1.8 +git checkout llvmorg-21.1.8 # 3. 配置 CMake (构建动态库并启用 Python 绑定) cmake -G Ninja -S llvm -B $LLVM_BUILD_DIR \ @@ -100,6 +101,9 @@ cmake -G Ninja -S llvm -B $LLVM_BUILD_DIR \ -DBUILD_SHARED_LIBS=ON \ -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ -DPython3_EXECUTABLE=$(which python3) \ + -DPython_EXECUTABLE=$(which python3) \ + -Dpybind11_DIR=$(python3 -m pybind11 --cmakedir) \ + -Dnanobind_DIR=$(python3 -m nanobind --cmake_dir) \ -DCMAKE_BUILD_TYPE=Release \ -DLLVM_TARGETS_TO_BUILD="host" @@ -110,7 +114,7 @@ ninja -C $LLVM_BUILD_DIR ### 3.3 第二步:构建 PTOAS (Out-of-Tree) -下载 PTOAS 源码并基于刚刚编译好的 LLVM 19 进行构建。 +下载 PTOAS 源码并基于刚刚编译好的 LLVM 21 进行构建。 ```bash # 1. 下载 PTOAS 源码 @@ -125,7 +129,7 @@ export PYBIND11_CMAKE_DIR=$(python3 -m pybind11 --cmakedir) # 注意:此处直接使用了 3.0 章节中定义的变量,无需手动修改 cmake -G Ninja \ -S . \ - -B build \ + -B build-llvm21 \ -DLLVM_DIR=$LLVM_BUILD_DIR/lib/cmake/llvm \ -DMLIR_DIR=$LLVM_BUILD_DIR/lib/cmake/mlir \ -DPython3_EXECUTABLE=$(which python3) \ @@ -136,12 +140,12 @@ cmake -G Ninja \ -DCMAKE_INSTALL_PREFIX="$PTO_INSTALL_DIR" # 4. 编译并安装 -ninja -C build -ninja -C build install +ninja -C build-llvm21 +ninja -C build-llvm21 install # 5. 检查构建产物 # build 输出(便于本地开发/调试) -$PTO_SOURCE_DIR/build/python/ +$PTO_SOURCE_DIR/build-llvm21/python/ ├── mlir │ ├── _mlir_libs │ │ └── _pto.cpython-*.so @@ -159,8 +163,8 @@ $PTO_INSTALL_DIR/ └── _pto.cpython-*.so # CLI 工具 -$PTO_SOURCE_DIR/build/tools/ptoas/ptoas -$PTO_SOURCE_DIR/build/tools/ptobc/ptobc +$PTO_SOURCE_DIR/build-llvm21/tools/ptoas/ptoas +$PTO_SOURCE_DIR/build-llvm21/tools/ptobc/ptobc ``` @@ -183,7 +187,7 @@ export PYTHONPATH=$PTO_PYTHON_ROOT:$MLIR_PYTHON_ROOT:$PYTHONPATH export LD_LIBRARY_PATH=$LLVM_BUILD_DIR/lib:$PTO_INSTALL_DIR/lib:$LD_LIBRARY_PATH # 3. PATH: 将 ptoas / ptobc 添加到命令行路径 -export PATH=$PTO_SOURCE_DIR/build/tools/ptoas:$PTO_SOURCE_DIR/build/tools/ptobc:$PATH +export PATH=$PTO_SOURCE_DIR/build-llvm21/tools/ptoas:$PTO_SOURCE_DIR/build-llvm21/tools/ptobc:$PATH ``` diff --git a/README_en.md b/README_en.md index b7a060a0fd..5528b1f688 100644 --- a/README_en.md +++ b/README_en.md @@ -2,7 +2,7 @@ ## 1. Introduction -**ptoas** is a specialized compiler toolchain built on top of **LLVM/MLIR (llvmorg-19.1.7)** *(Commit cd708029e0b2869e80abe31ddb175f7c35361f90)*, designed specifically for **PTO Bytecode** (Programming Tiling Operator Bytecode). +**ptoas** is a specialized compiler toolchain built on top of **LLVM/MLIR (llvmorg-21.1.8)** *(Commit 2078da43e25a4623cab2d0d60decddf709aaea28)*, designed specifically for **PTO Bytecode** (Programming Tiling Operator Bytecode). Acting as the bridge between upper-level AI frameworks and underlying NPU/GPGPU/CPU hardware, `ptoas` is built in an **Out-of-Tree** architecture and provides complete C++ and Python interfaces. Its primary responsibilities include: @@ -36,7 +36,7 @@ PTOAS/ ## 3. Build Instructions -⚠️ **Important**: This project strictly requires **LLVM llvmorg-19.1.7**. +⚠️ **Important**: This project strictly requires **LLVM llvmorg-21.1.8**. ### 3.0 Environment Variable Configuration @@ -67,10 +67,10 @@ mkdir -p $WORKSPACE_DIR * **Compiler**: GCC >= 9 or Clang (C++17 support required) * **Build System**: CMake >= 3.20, Ninja * **Python**: 3.8+ -* **Python Packages**: `pybind11`, `numpy` +* **Python Packages**: `pybind11<3`, `nanobind`, `numpy` ```bash -python3 -m pip install pybind11==2.12.0 numpy +python3 -m pip install "pybind11<3" nanobind numpy ``` > **Note**: The current LLVM/MLIR Python bindings are not compatible with `pybind11` 3.x. @@ -79,7 +79,7 @@ python3 -m pip install pybind11==2.12.0 numpy ### 3.2 Step 1: Build LLVM/MLIR (Dependency) -Download the LLVM source, check out the `llvmorg-19.1.7` tag, and build with **shared libraries** to ensure correct linking for Python bindings. +Download the LLVM source, check out the `llvmorg-21.1.8` tag, and build with **shared libraries** to ensure correct linking for Python bindings. ```bash # 1. Clone LLVM @@ -87,8 +87,8 @@ cd $WORKSPACE_DIR git clone https://github.com/llvm/llvm-project.git cd $LLVM_SOURCE_DIR -# 2. [Critical] Check out llvmorg-19.1.7 -git checkout llvmorg-19.1.7 +# 2. [Critical] Check out llvmorg-21.1.8 +git checkout llvmorg-21.1.8 # 3. Configure CMake (build shared libs with Python bindings enabled) cmake -G Ninja -S llvm -B $LLVM_BUILD_DIR \ @@ -105,7 +105,7 @@ ninja -C $LLVM_BUILD_DIR ### 3.3 Step 2: Build PTOAS (Out-of-Tree) -Clone the PTOAS source and build against the LLVM 19 you just compiled. +Clone the PTOAS source and build against the LLVM 21 you just compiled. ```bash # 1. Clone PTOAS diff --git a/ReleaseNotes.md b/ReleaseNotes.md index f94982946c..8aa8592e7c 100644 --- a/ReleaseNotes.md +++ b/ReleaseNotes.md @@ -8,7 +8,7 @@ - PTOAS 首次发布 ## 概述 -PTOAS(PTO Assembler & Optimizer)是面向 PTO Bytecode 的编译器工具链,基于 LLVM/MLIR release/19.x 构建。它提供 PTO Dialect 的定义、解析、验证、优化与代码生成能力,并输出可调用 `pto-isa` 的 C++ 代码。 +PTOAS(PTO Assembler & Optimizer)是面向 PTO Bytecode 的编译器工具链,基于 LLVM/MLIR llvmorg-21.1.8 构建。它提供 PTO Dialect 的定义、解析、验证、优化与代码生成能力,并输出可调用 `pto-isa` 的 C++ 代码。 PTOAS很快将集成到以下框架中,敬请期待 - PyPTO diff --git a/docker/Dockerfile b/docker/Dockerfile index d9714f8327..7ecdebcf20 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -10,7 +10,7 @@ ARG ARCH # NOTE: change $PY_VER for different Python versions (3.8 - 3.14 available) ARG PY_VER=cp311-cp311 -ARG LLVM_TAG=llvmorg-19.1.7 +ARG LLVM_TAG=llvmorg-21.1.8 ## -- usually no need to change below -- @@ -19,7 +19,7 @@ ENV PATH="${PY_PATH}/bin:${PATH}" # dependency RUN dnf install -y ninja-build cmake git ccache gcc-c++ lld zip binutils patchelf chrpath && dnf clean all -RUN pip install --no-cache-dir numpy pybind11 nanobind setuptools wheel auditwheel +RUN pip install --no-cache-dir numpy 'pybind11<3' nanobind setuptools wheel auditwheel COPY cmake/LinuxHardeningCache.cmake /tmp/LinuxHardeningCache.cmake @@ -43,6 +43,9 @@ RUN cmake -C /tmp/LinuxHardeningCache.cmake -G Ninja -S llvm -B $LLVM_BUILD_DIR -DLLVM_ENABLE_ASSERTIONS=OFF \ -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ -DPython3_EXECUTABLE=${PY_PATH}/bin/python \ + -DPython_EXECUTABLE=${PY_PATH}/bin/python \ + -Dpybind11_DIR=$(${PY_PATH}/bin/python -m pybind11 --cmakedir) \ + -Dnanobind_DIR=$(${PY_PATH}/bin/python -m nanobind --cmake_dir) \ -DCMAKE_BUILD_TYPE=Release \ -DLLVM_TARGETS_TO_BUILD="host" diff --git a/docs/build_with_installed_llvm.md b/docs/build_with_installed_llvm.md index 6c7495a5ca..316b46a8a5 100644 --- a/docs/build_with_installed_llvm.md +++ b/docs/build_with_installed_llvm.md @@ -2,7 +2,7 @@ 本文档按 [README.md](../README.md) 第 3 章的逻辑整理,适用于: -- LLVM/MLIR `19.1.7` 已经构建并安装完成。 +- LLVM/MLIR `21.1.8` 已经构建并安装完成。 - LLVM 安装路径固定为 `/opt/llvm`。 - `/opt/llvm` 是共享目录,不希望 `ptoas` 的安装步骤写入其中。 @@ -42,11 +42,12 @@ mkdir -p "$WORKSPACE_DIR" - CMake >= 3.20 - Ninja - Python 3.8+ -- `pybind11` +- `pybind11<3` +- `nanobind` - `numpy` ```bash -pip3 install pybind11 numpy +pip3 install "pybind11<3" nanobind numpy ``` ## 跳过 3.2 @@ -62,7 +63,7 @@ README 第 3.2 节是 LLVM/MLIR 的下载和编译步骤。当前场景下 LLVM 输出为: ```text -19.1.7 +21.1.8 ``` ## 3.3 第二步:构建 ptoas diff --git a/docs/designs/ci-board-validation-guide.md b/docs/designs/ci-board-validation-guide.md index b87feefd9f..969cfc6055 100644 --- a/docs/designs/ci-board-validation-guide.md +++ b/docs/designs/ci-board-validation-guide.md @@ -83,7 +83,7 @@ ```bash git clone https://github.com/llvm/llvm-project.git cd llvm-project -git checkout llvmorg-19.1.7 +git checkout llvmorg-21.1.8 cmake -G Ninja -S llvm -B llvm/build-shared \ -DLLVM_ENABLE_PROJECTS="mlir;clang" \ diff --git a/include/PTO/IR/PTOTypeUtils.h b/include/PTO/IR/PTOTypeUtils.h index f7db223fbc..45d0740487 100644 --- a/include/PTO/IR/PTOTypeUtils.h +++ b/include/PTO/IR/PTOTypeUtils.h @@ -15,6 +15,8 @@ namespace mlir::pto { bool isPTOFloat8Type(Type t); +bool isPTOFloat8E4M3LikeType(Type t); +bool isPTOFloat8E5M2LikeType(Type t); bool isPTOHiFloat8Type(Type t); bool isPTOF8E8M0Type(Type t); bool isPTOHiFloat8x2Type(Type t); diff --git a/include/PTO/IR/VPTOOps.td b/include/PTO/IR/VPTOOps.td index 253e36a22a..2cd71fed22 100644 --- a/include/PTO/IR/VPTOOps.td +++ b/include/PTO/IR/VPTOOps.td @@ -72,13 +72,13 @@ def PTO_V2F8E4M3FNType : Type< CPred<"::llvm::isa<::mlir::VectorType>($_self) && " "::llvm::cast<::mlir::VectorType>($_self).getRank() == 1 && " "::llvm::cast<::mlir::VectorType>($_self).getDimSize(0) == 2 && " - "::llvm::cast<::mlir::VectorType>($_self).getElementType().isFloat8E4M3FN()">, + "::llvm::isa<::mlir::Float8E4M3FNType>(::llvm::cast<::mlir::VectorType>($_self).getElementType())">, "vector<2xf8E4M3FN>">; def PTO_V2F8E5M2Type : Type< CPred<"::llvm::isa<::mlir::VectorType>($_self) && " "::llvm::cast<::mlir::VectorType>($_self).getRank() == 1 && " "::llvm::cast<::mlir::VectorType>($_self).getDimSize(0) == 2 && " - "::llvm::cast<::mlir::VectorType>($_self).getElementType().isFloat8E5M2()">, + "::llvm::isa<::mlir::Float8E5M2Type>(::llvm::cast<::mlir::VectorType>($_self).getElementType())">, "vector<2xf8E5M2>">; def PTO_V2HiF8Type : Type< CPred<"::llvm::isa<::mlir::pto::HiF8x2Type>($_self)">, diff --git a/lib/Bindings/Python/CMakeLists.txt b/lib/Bindings/Python/CMakeLists.txt index c87bdb30c5..16ebd8d710 100644 --- a/lib/Bindings/Python/CMakeLists.txt +++ b/lib/Bindings/Python/CMakeLists.txt @@ -19,6 +19,9 @@ find_package(pybind11 CONFIG REQUIRED) include(TableGen) include(AddMLIR) +get_filename_component(PTO_LLVM_BUILD_LIBRARY_DIR + "${LLVM_BUILD_LIBRARY_DIR}" REALPATH) + # ---- 1) Python native extension: mlir._mlir_libs._pto ---- pybind11_add_module(_pto MODULE PTOModule.cpp @@ -54,8 +57,8 @@ target_link_libraries(_pto PRIVATE # 关键:放到 mlir/_mlir_libs 下(匹配 MLIR dialect python 的 import 习惯) set_target_properties(_pto PROPERTIES LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/python/mlir/_mlir_libs" - BUILD_RPATH "${LLVM_BUILD_LIBRARY_DIR}" - INSTALL_RPATH "${LLVM_BUILD_LIBRARY_DIR}" + BUILD_RPATH "${PTO_LLVM_BUILD_LIBRARY_DIR}" + INSTALL_RPATH "${PTO_LLVM_BUILD_LIBRARY_DIR}" ) # macOS: 避免 ld 警告 "-undefined suppress is deprecated",仅保留 -undefined dynamic_lookup if(APPLE) @@ -84,18 +87,38 @@ add_dependencies(_pto PTOPythonGen) # ---- 3) Copy generated python + handwritten pto.py into build/python ---- set(PTO_PY_SRC "${CMAKE_SOURCE_DIR}/python/pto/dialects/pto.py") - -add_custom_command(TARGET _pto POST_BUILD - COMMAND ${CMAKE_COMMAND} -E make_directory "${CMAKE_BINARY_DIR}/python/mlir/dialects" +set(PTO_PY_BUILD_DIR "${CMAKE_BINARY_DIR}/python/mlir/dialects") +set(PTO_PY_BUILD "${PTO_PY_BUILD_DIR}/pto.py") +set(PTO_OPS_PY_BUILD "${PTO_PY_BUILD_DIR}/_pto_ops_gen.py") + +add_custom_command( + OUTPUT + "${PTO_PY_BUILD}" + "${PTO_OPS_PY_BUILD}" + COMMAND ${CMAKE_COMMAND} -E make_directory "${PTO_PY_BUILD_DIR}" COMMAND ${CMAKE_COMMAND} -E copy_if_different "${PTO_PY_SRC}" - "${CMAKE_BINARY_DIR}/python/mlir/dialects/pto.py" + "${PTO_PY_BUILD}" COMMAND ${CMAKE_COMMAND} -E copy_if_different "${CMAKE_CURRENT_BINARY_DIR}/_pto_ops_gen.py" - "${CMAKE_BINARY_DIR}/python/mlir/dialects/_pto_ops_gen.py" + "${PTO_OPS_PY_BUILD}" + DEPENDS + "${PTO_PY_SRC}" + "${CMAKE_CURRENT_BINARY_DIR}/_pto_ops_gen.py" VERBATIM ) +add_custom_target(PTOStagePythonModules ALL + DEPENDS + "${PTO_PY_BUILD}" + "${PTO_OPS_PY_BUILD}" +) +add_dependencies(PTOStagePythonModules PTOPythonGen) +add_dependencies(_pto PTOStagePythonModules) +if(TARGET PTOPythonModules) + add_dependencies(PTOPythonModules PTOStagePythonModules) +endif() + install(FILES "${PTO_PY_SRC}" "${CMAKE_CURRENT_BINARY_DIR}/_pto_ops_gen.py" diff --git a/lib/PTO/IR/PTO.cpp b/lib/PTO/IR/PTO.cpp index 0712fdc894..d66746d7ae 100644 --- a/lib/PTO/IR/PTO.cpp +++ b/lib/PTO/IR/PTO.cpp @@ -3655,7 +3655,7 @@ static LogicalResult verifyAsyncFlatContiguous1DGMMemRef(Operation *op, SmallVector strides; int64_t offset = 0; - if (failed(getStridesAndOffset(memTy, strides, offset))) + if (failed(memTy.getStridesAndOffset(strides, offset))) return op->emitOpError() << "expects " << name << " to be a strided memref with a known layout"; @@ -3876,10 +3876,7 @@ static bool isSupportedMGatherMScatterPayloadElemType(Operation *op, Type ty) { return true; if (!isTargetArchA5(op)) return false; - if (isPTOHiFloat8Type(ty)) - return true; - return ty.isFloat8E4M3() || ty.isFloat8E4M3FN() || ty.isFloat8E4M3FNUZ() || - ty.isFloat8E4M3B11FNUZ() || ty.isFloat8E5M2() || ty.isFloat8E5M2FNUZ(); + return isPTOHiFloat8Type(ty) || isPTOFloat8Type(ty); } static bool isSupportedMScatterAtomicPayloadElemType(Type ty, @@ -4041,8 +4038,7 @@ static bool isA5AccStorePreQuantDstType(Type srcElem, Type dstElem) { return false; return dstElem.isInteger(8) || dstElem.isF16() || dstElem.isBF16() || dstElem.isF32() || isPTOHiFloat8Type(dstElem) || - dstElem.isFloat8E4M3() || dstElem.isFloat8E4M3FN() || - dstElem.isFloat8E4M3FNUZ() || dstElem.isFloat8E4M3B11FNUZ(); + isPTOFloat8E4M3LikeType(dstElem); } static bool isA5LowPrecisionTCvtPair(Type srcElem, Type dstElem) { @@ -6619,9 +6615,7 @@ static bool isA5Fp8LikeType(Type ty) { } static bool isA5MxFp8InputType(Type ty) { - if (auto ft = dyn_cast(ty)) - return ft.isFloat8E4M3FN() || ft.isFloat8E5M2(); - return false; + return isa(ty); } static bool isA5MxInputTypePair(Type lhsTy, Type rhsTy) { @@ -12771,7 +12765,7 @@ mlir::LogicalResult mlir::pto::SimdTileToMemrefOp::verify() { SmallVector memStrides; int64_t memOffset = ShapedType::kDynamic; - if (failed(getStridesAndOffset(memTy, memStrides, memOffset))) + if (failed(memTy.getStridesAndOffset(memStrides, memOffset))) return emitOpError("expects memref to use strided layout"); if (memOffset != 0) return emitOpError("expects memref offset to be 0"); diff --git a/lib/PTO/IR/PTOTypeUtils.cpp b/lib/PTO/IR/PTOTypeUtils.cpp index 610248b954..3591c05986 100644 --- a/lib/PTO/IR/PTOTypeUtils.cpp +++ b/lib/PTO/IR/PTOTypeUtils.cpp @@ -18,8 +18,16 @@ constexpr unsigned kBitsPerByte = 8; } // namespace bool mlir::pto::isPTOFloat8Type(Type t) { - return t.isFloat8E4M3() || t.isFloat8E4M3FN() || t.isFloat8E4M3FNUZ() || - t.isFloat8E4M3B11FNUZ() || t.isFloat8E5M2() || t.isFloat8E5M2FNUZ(); + return isPTOFloat8E4M3LikeType(t) || isPTOFloat8E5M2LikeType(t); +} + +bool mlir::pto::isPTOFloat8E4M3LikeType(Type t) { + return isa(t); +} + +bool mlir::pto::isPTOFloat8E5M2LikeType(Type t) { + return isa(t); } bool mlir::pto::isPTOHiFloat8Type(Type t) { return isa(t); } diff --git a/lib/PTO/Transforms/BufferizableOpInterfaceImpl.cpp b/lib/PTO/Transforms/BufferizableOpInterfaceImpl.cpp index 21d4b98305..f58a15a376 100644 --- a/lib/PTO/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/lib/PTO/Transforms/BufferizableOpInterfaceImpl.cpp @@ -25,7 +25,7 @@ namespace { static LogicalResult bufferizeDestinationStyleOpInterface( RewriterBase &rewriter, DestinationStyleOpInterface op, - const BufferizationOptions &options, + const BufferizationOptions &options, const BufferizationState &state, bool supportMixedTensorBufferMode = true); template { LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options) const { + const BufferizationOptions &options, + const BufferizationState &state) const { return bufferizeDestinationStyleOpInterface( - rewriter, cast(op), options, + rewriter, cast(op), options, state, supportMixedTensorBufferMode); } }; @@ -59,7 +60,8 @@ struct PTOReadWriteDpsOpInterfaceBase /// Generic conversion for any DestinationStyleOpInterface on tensors. static LogicalResult bufferizeDestinationStyleOpInterface( RewriterBase &rewriter, DestinationStyleOpInterface op, - const BufferizationOptions &options, bool supportMixedTensorBufferMode) { + const BufferizationOptions &options, const BufferizationState &state, + bool supportMixedTensorBufferMode) { // Take a guard before anything else. OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(op); @@ -83,7 +85,8 @@ static LogicalResult bufferizeDestinationStyleOpInterface( newOperands.push_back(opOperand.get()); continue; } - FailureOr buffer = getBuffer(rewriter, opOperand.get(), options); + FailureOr buffer = + getBuffer(rewriter, opOperand.get(), options, state); if (failed(buffer)) { return failure(); } @@ -95,7 +98,7 @@ static LogicalResult bufferizeDestinationStyleOpInterface( for (OpResult opResult : op->getOpResults()) { OpOperand *opOperand = op.getDpsInitOperand(opResult.getResultNumber()); FailureOr resultBuffer = - getBuffer(rewriter, opOperand->get(), options); + getBuffer(rewriter, opOperand->get(), options, state); if (failed(resultBuffer)) { return failure(); } @@ -120,13 +123,15 @@ struct PTOStoreOpInterface : public DstBufferizableOpInterfaceExternalModel { LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options) const { + const BufferizationOptions &options, + const BufferizationState &state) const { auto dpsOp = cast(op); if (dpsOp.hasPureBufferSemantics()) { return success(); } if (dpsOp.hasPureTensorSemantics()) { - return bufferizeDestinationStyleOpInterface(rewriter, dpsOp, options); + return bufferizeDestinationStyleOpInterface(rewriter, dpsOp, options, + state); } // We only handle the case where fixpipe op's input is a tensor from // mmad and fixpipe op's output is a memref type. @@ -141,7 +146,7 @@ struct PTOStoreOpInterface OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(op); - FailureOr buffer = getBuffer(rewriter, srcOp->get(), options); + FailureOr buffer = getBuffer(rewriter, srcOp->get(), options, state); if (failed(buffer)) { return failure(); } diff --git a/lib/PTO/Transforms/ConvertToPTOOp.cpp b/lib/PTO/Transforms/ConvertToPTOOp.cpp index ca3c18ce2a..7e72901f56 100644 --- a/lib/PTO/Transforms/ConvertToPTOOp.cpp +++ b/lib/PTO/Transforms/ConvertToPTOOp.cpp @@ -53,11 +53,11 @@ std::optional getLeftPadNum(PatternRewriter &rewriter, if (auto subviewOp = llvm::dyn_cast(user)) { auto offsets = subviewOp.getMixedOffsets(); auto offset = offsets[offsets.size() - 1]; - Value offsetValue = - offset.is() - ? dyn_cast(offset) - : rewriter.create( - subviewOp->getLoc(), getConstantIntValue(offset).value()); + Value offsetValue = dyn_cast(offset); + if (!offsetValue) { + offsetValue = rewriter.create( + subviewOp->getLoc(), getConstantIntValue(offset).value()); + } return offsetValue; } } @@ -199,7 +199,7 @@ void ConvertToPTOOpPass::runOnOperation() { moduleOp->walk([&](func::FuncOp funcOp) { RewritePatternSet patterns(ctx); populatePTOOpRewritingRule(patterns); - (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); + (void)applyPatternsGreedily(funcOp, std::move(patterns)); }); } diff --git a/lib/PTO/Transforms/ExpandTileOp.cpp b/lib/PTO/Transforms/ExpandTileOp.cpp index 0e32e132c7..3045901099 100644 --- a/lib/PTO/Transforms/ExpandTileOp.cpp +++ b/lib/PTO/Transforms/ExpandTileOp.cpp @@ -596,7 +596,7 @@ static void populateViewShapeAndStrides(Value value, shape.assign(memrefTy.getShape().begin(), memrefTy.getShape().end()); if (strides.empty()) { int64_t offset = ShapedType::kDynamic; - if (succeeded(getStridesAndOffset(memrefTy, strides, offset))) { + if (succeeded(memrefTy.getStridesAndOffset(strides, offset))) { // strides populated — dynamic dims remain ShapedType::kDynamic. } } @@ -644,7 +644,7 @@ static std::optional buildOperandTypeInfo(Value value) { info.viewShape.assign(mrTy.getShape().begin(), mrTy.getShape().end()); if (info.viewStrides.empty()) { int64_t offset = ShapedType::kDynamic; - if (succeeded(getStridesAndOffset(mrTy, info.viewStrides, offset))) { + if (succeeded(mrTy.getStridesAndOffset(info.viewStrides, offset))) { // strides populated — dynamic dims remain ShapedType::kDynamic. } } diff --git a/lib/PTO/Transforms/InferPTOLayout.cpp b/lib/PTO/Transforms/InferPTOLayout.cpp index e6272cc46c..a485272a17 100644 --- a/lib/PTO/Transforms/InferPTOLayout.cpp +++ b/lib/PTO/Transforms/InferPTOLayout.cpp @@ -69,7 +69,7 @@ static std::optional getConstInt(OpFoldResult ofr) { return ia.getInt(); return std::nullopt; } - return getConstInt(ofr.get()); + return getConstInt(ofr.dyn_cast()); } static unsigned elemByteSize(Type ty) { @@ -287,7 +287,7 @@ static std::optional inferFromStaticMemRefTy(MemRefType mrTy) { return std::nullopt; SmallVector strideInts; int64_t offset = ShapedType::kDynamic; - if (failed(getStridesAndOffset(mrTy, strideInts, offset))) + if (failed(mrTy.getStridesAndOffset(strideInts, offset))) return std::nullopt; if (offset == ShapedType::kDynamic || llvm::any_of(strideInts, @@ -632,7 +632,7 @@ struct InferPTOLayoutPass SmallVector strideInts; int64_t offset = ShapedType::kDynamic; - if (failed(getStridesAndOffset(srcTy, strideInts, offset)) || + if (failed(srcTy.getStridesAndOffset(strideInts, offset)) || offset == ShapedType::kDynamic || llvm::any_of(strideInts, [](int64_t s) { return s == ShapedType::kDynamic; })) { diff --git a/lib/PTO/Transforms/InsertSync/InsertSyncAnalysis.cpp b/lib/PTO/Transforms/InsertSync/InsertSyncAnalysis.cpp index 0d93b1d61c..564094de20 100644 --- a/lib/PTO/Transforms/InsertSync/InsertSyncAnalysis.cpp +++ b/lib/PTO/Transforms/InsertSync/InsertSyncAnalysis.cpp @@ -113,7 +113,7 @@ static std::optional getKnownBLayout(Type ty) { if (auto memRefTy = dyn_cast(ty)) { SmallVector strides; int64_t offset = 0; - if (failed(getStridesAndOffset(memRefTy, strides, offset)) || + if (failed(memRefTy.getStridesAndOffset(strides, offset)) || strides.size() != 2) { return std::nullopt; } diff --git a/lib/PTO/Transforms/InsertSync/PTOIRTranslator.cpp b/lib/PTO/Transforms/InsertSync/PTOIRTranslator.cpp index d190885ec6..63d8b34eca 100644 --- a/lib/PTO/Transforms/InsertSync/PTOIRTranslator.cpp +++ b/lib/PTO/Transforms/InsertSync/PTOIRTranslator.cpp @@ -153,7 +153,7 @@ getMemrefSubViewBaseAddresses(memref::SubViewOp op, MemRefType sourceType, SmallVector strides; int64_t baseOffset = ShapedType::kDynamic; - if (failed(mlir::getStridesAndOffset(sourceType, strides, baseOffset)) || + if (failed(sourceType.getStridesAndOffset(strides, baseOffset)) || strides.size() != 2 || llvm::is_contained(strides, ShapedType::kDynamic)) return std::nullopt; @@ -257,7 +257,7 @@ static std::pair getStaticOffsetAndSize(Operation *op, Value s if (auto subView = dyn_cast(op)) { int64_t baseOffset; StrideVector strides; - if (failed(mlir::getStridesAndOffset(srcType, strides, baseOffset))) { + if (failed(srcType.getStridesAndOffset(strides, baseOffset))) { return {-1, -1}; } @@ -944,7 +944,7 @@ void PTOIRTranslator::UpdateMemrefSubViewAliasBufferInfo(memref::SubViewOp op) { SmallVector strides; int64_t baseOffset = ShapedType::kDynamic; - if (failed(mlir::getStridesAndOffset(sourceType, strides, baseOffset)) || + if (failed(sourceType.getStridesAndOffset(strides, baseOffset)) || strides.size() != 2) { UpdateConservativeAliasBufferInfo(result, source); return; diff --git a/lib/PTO/Transforms/LoweringSyncToPipe.cpp b/lib/PTO/Transforms/LoweringSyncToPipe.cpp index bc2672a0ba..15d02d7a98 100644 --- a/lib/PTO/Transforms/LoweringSyncToPipe.cpp +++ b/lib/PTO/Transforms/LoweringSyncToPipe.cpp @@ -138,7 +138,7 @@ struct LoweringSyncToPipe patterns.add( context); - if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) + if (failed(applyPatternsGreedily(func, std::move(patterns)))) signalPassFailure(); } }; diff --git a/lib/PTO/Transforms/PTOMaterializeTileHandles.cpp b/lib/PTO/Transforms/PTOMaterializeTileHandles.cpp index 4d868a20f6..7b1a85c3f4 100644 --- a/lib/PTO/Transforms/PTOMaterializeTileHandles.cpp +++ b/lib/PTO/Transforms/PTOMaterializeTileHandles.cpp @@ -360,8 +360,8 @@ getMaterializedTileShape(MemRefType memTy, const TileHandleMetadata &meta) { SmallVector inheritedStrides; int64_t inheritedOffset = ShapedType::kDynamic; - if (failed(getStridesAndOffset(sourceMrTy, inheritedStrides, - inheritedOffset)) || + if (failed(sourceMrTy.getStridesAndOffset(inheritedStrides, + inheritedOffset)) || inheritedStrides.size() < 2) return shape; @@ -434,7 +434,7 @@ static Value materializeOffset(OpFoldResult ofr, OpBuilder &builder, return makeI64Constant(builder, loc, intAttr.getInt()); return Value(); } - return ensureI64(ofr.get(), builder, loc); + return ensureI64(cast(ofr), builder, loc); } static Value addI64(Value lhs, Value rhs, OpBuilder &builder, Location loc) { @@ -474,7 +474,7 @@ static Value computeSubviewAddress(memref::SubViewOp subview, SmallVector sourceStrides; int64_t sourceOffset = ShapedType::kDynamic; - if (failed(getStridesAndOffset(sourceTy, sourceStrides, sourceOffset))) + if (failed(sourceTy.getStridesAndOffset(sourceStrides, sourceOffset))) return Value(); auto mixedOffsets = subview.getMixedOffsets(); diff --git a/lib/PTO/Transforms/PTOPlanMemory.cpp b/lib/PTO/Transforms/PTOPlanMemory.cpp index aa3196a672..adae1f8168 100644 --- a/lib/PTO/Transforms/PTOPlanMemory.cpp +++ b/lib/PTO/Transforms/PTOPlanMemory.cpp @@ -2374,7 +2374,7 @@ void PlanMemoryPass::runOnOperation() { RewritePatternSet patterns(&getContext()); populateBufferAddressToAllocOp(patterns, memPlan.GetBuffer2Offsets()); - if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { + if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) { return signalPassFailure(); } } diff --git a/lib/PTO/Transforms/PTOToEmitC.cpp b/lib/PTO/Transforms/PTOToEmitC.cpp index 963b01c89c..c0a99f13ba 100644 --- a/lib/PTO/Transforms/PTOToEmitC.cpp +++ b/lib/PTO/Transforms/PTOToEmitC.cpp @@ -123,6 +123,20 @@ static pto::AddressSpace getAddressSpaceOrGM(Attribute memorySpace) { return pto::AddressSpace::GM; } +static Type getEmitCVariableResultType(Type valueType) { + if (isa(valueType)) + return valueType; + return emitc::LValueType::get(valueType); +} + +static Value loadEmitCVariableIfNeeded(OpBuilder &builder, Location loc, + Value value) { + if (auto lvalueTy = dyn_cast(value.getType())) + return builder.create(loc, lvalueTy.getValueType(), value) + .getResult(); + return value; +} + [[maybe_unused]] static constexpr llvm::StringLiteral kLoweredSetValidShapeAttrName = "__pto.lowered_set_validshape"; [[maybe_unused]] static constexpr llvm::StringLiteral kLoweredSetValidShapeConfigAttrName = @@ -415,7 +429,7 @@ getGatherScatterShapeLayoutInfo(Type ty) { SmallVector strides; int64_t offset = ShapedType::kDynamic; - if (failed(getStridesAndOffset(memRefTy, strides, offset)) || + if (failed(memRefTy.getStridesAndOffset(strides, offset)) || strides.size() != 2) return std::nullopt; @@ -511,12 +525,9 @@ static bool isF8E8M0ElemType(Type elemTy) { } static std::string getEmitCScalarTypeToken(Type elemTy) { - if (pto::isPTOFloat8Type(elemTy) && - (elemTy.isFloat8E4M3() || elemTy.isFloat8E4M3FN() || - elemTy.isFloat8E4M3FNUZ() || elemTy.isFloat8E4M3B11FNUZ())) + if (pto::isPTOFloat8E4M3LikeType(elemTy)) return "float8_e4m3_t"; - if (pto::isPTOFloat8Type(elemTy) && - (elemTy.isFloat8E5M2() || elemTy.isFloat8E5M2FNUZ())) + if (pto::isPTOFloat8E5M2LikeType(elemTy)) return "float8_e5m2_t"; if (isF8E8M0ElemType(elemTy)) return "float8_e8m0_t"; @@ -858,10 +869,9 @@ class PTOToEmitCTypeConverter : public TypeConverter { // 1. 基本类型 (f32, i32, index) // --------------------------------------------------------- addConversion([Ctx](FloatType type) -> Type { - if (type.isFloat8E4M3() || type.isFloat8E4M3FN() || - type.isFloat8E4M3FNUZ() || type.isFloat8E4M3B11FNUZ()) + if (pto::isPTOFloat8E4M3LikeType(type)) return emitc::OpaqueType::get(Ctx, "float8_e4m3_t"); - if (type.isFloat8E5M2() || type.isFloat8E5M2FNUZ()) + if (pto::isPTOFloat8E5M2LikeType(type)) return emitc::OpaqueType::get(Ctx, "float8_e5m2_t"); if (type.isF32()) return emitc::OpaqueType::get(Ctx, "float"); if (type.isF16()) return emitc::OpaqueType::get(Ctx, "half"); @@ -1066,8 +1076,6 @@ class PTOToEmitCTypeConverter : public TypeConverter { addSourceMaterialization(materializeCast); addTargetMaterialization(materializeCast); - // Needed for region/block signature conversions (e.g. CFG block args). - addArgumentMaterialization(materializeCast); } }; @@ -3480,7 +3488,7 @@ struct SubviewToEmitCPattern : public OpConversionPattern { if (auto intAttr = dyn_cast(attr)) return intAttr.getInt(); } else { - Value v = ofr.get(); + Value v = ofr.dyn_cast(); if (auto cOp = v.getDefiningOp()) { if (auto iAttr = dyn_cast(cOp.getValue())) return iAttr.getInt(); @@ -3570,7 +3578,8 @@ struct SubviewToEmitCPattern : public OpConversionPattern { } else { SmallVector strideInts; int64_t offset = ShapedType::kDynamic; - bool useTypeStrides = succeeded(getStridesAndOffset(srcType, strideInts, offset)); + bool useTypeStrides = + succeeded(srcType.getStridesAndOffset(strideInts, offset)); (void)offset; if (useTypeStrides) { for (int64_t s : strideInts) { @@ -3941,7 +3950,7 @@ static bool hasStaticShape(MemRefType mrTy) { static bool getStaticMemrefLayout(MemRefType mrTy, SmallVectorImpl &strides, int64_t &offset) { - if (failed(getStridesAndOffset(mrTy, strides, offset))) { + if (failed(mrTy.getStridesAndOffset(strides, offset))) { strides.clear(); int64_t stride = 1; ArrayRef shape = mrTy.getShape(); @@ -4349,9 +4358,11 @@ static FailureOr buildAsyncScratchTileValue( Value tile = rewriter .create( - loc, emitc::OpaqueType::get(ctx, tileTypeStr), + loc, getEmitCVariableResultType( + emitc::OpaqueType::get(ctx, tileTypeStr)), emitc::OpaqueAttr::get(ctx, "")) .getResult(); + tile = loadEmitCVariableIfNeeded(rewriter, loc, tile); auto addr = rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); Value scratchAddr = rewriter @@ -4411,9 +4422,11 @@ static FailureOr buildSyncAllWorkspaceTileValue( auto tileEmitTy = emitc::OpaqueType::get(ctx, *tileTypeString); Value tile = rewriter - .create(loc, tileEmitTy, + .create(loc, + getEmitCVariableResultType(tileEmitTy), emitc::OpaqueAttr::get(ctx, "")) .getResult(); + tile = loadEmitCVariableIfNeeded(rewriter, loc, tile); Value rawPtr = workspace; auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); @@ -4752,10 +4765,10 @@ struct PointerCastConversion : public OpConversionPattern { // 静态情况 (Tile v;) auto varOp = rewriter.create( loc, - tileType, + getEmitCVariableResultType(tileType), emitc::OpaqueAttr::get(ctx, "") ); - resultValue = varOp.getResult(); + resultValue = loadEmitCVariableIfNeeded(rewriter, loc, varOp.getResult()); } // TASSIGN: pto-isa expects an integral address. @@ -6860,8 +6873,10 @@ struct PTOBuildAsyncSessionToEmitC Value session = rewriter .create( - loc, sessionTy, emitc::OpaqueAttr::get(ctx, "")) + loc, getEmitCVariableResultType(sessionTy), + emitc::OpaqueAttr::get(ctx, "")) .getResult(); + session = loadEmitCVariableIfNeeded(rewriter, loc, session); auto u32Ty = emitc::OpaqueType::get(ctx, "uint32_t"); @@ -6890,12 +6905,13 @@ struct PTOBuildAsyncSessionToEmitC Value baseConfig = rewriter .create( - loc, baseConfigTy, + loc, getEmitCVariableResultType(baseConfigTy), emitc::OpaqueAttr::get( ctx, "{" + std::to_string(blockBytes) + "ULL, " + std::to_string(commBlockOffset) + "ULL, " + std::to_string(queueNum) + "u}")) .getResult(); + baseConfig = loadEmitCVariableIfNeeded(rewriter, loc, baseConfig); rewriter.create( loc, TypeRange{}, "pto::comm::BuildAsyncSession", @@ -7025,7 +7041,7 @@ static FailureOr buildCollectiveParallelGroup( firstTy); auto groupArray = cast>( rewriter - .create(loc, arrayTy, + .create(loc, getEmitCVariableResultType(arrayTy), emitc::OpaqueAttr::get(ctx, "{}")) .getResult()); @@ -7429,9 +7445,10 @@ struct PTODeclareGlobalToEmitC } } auto var = rewriter.create( - op.getLoc(), convertedType, + op.getLoc(), getEmitCVariableResultType(convertedType), emitc::OpaqueAttr::get(rewriter.getContext(), "")); - rewriter.replaceOp(op, var.getResult()); + rewriter.replaceOp( + op, loadEmitCVariableIfNeeded(rewriter, op.getLoc(), var.getResult())); return success(); } }; @@ -7452,9 +7469,10 @@ struct PTODeclareEventIdArrayToEmitC auto array = rewriter .create( - op.getLoc(), arrayTy, + op.getLoc(), getEmitCVariableResultType(arrayTy), emitc::OpaqueAttr::get(rewriter.getContext(), "")) .getResult(); + array = loadEmitCVariableIfNeeded(rewriter, op.getLoc(), array); rewriter.replaceOp(op, array); return success(); } @@ -7521,9 +7539,10 @@ struct PTODeclareLocalArrayToEmitC auto var = rewriter .create( - op.getLoc(), arrayTy, + op.getLoc(), getEmitCVariableResultType(arrayTy), emitc::OpaqueAttr::get(rewriter.getContext(), "")) .getResult(); + var = loadEmitCVariableIfNeeded(rewriter, op.getLoc(), var); rewriter.replaceOp(op, var); return success(); } @@ -7557,11 +7576,12 @@ struct PTOLocalArrayGetToEmitC auto snapshot = rewriter .create( - op.getLoc(), resultTy, + op.getLoc(), getEmitCVariableResultType(resultTy), emitc::OpaqueAttr::get(rewriter.getContext(), "")) .getResult(); rewriter.create(op.getLoc(), snapshot, sub.getResult()); - rewriter.replaceOp(op, snapshot); + rewriter.replaceOp( + op, loadEmitCVariableIfNeeded(rewriter, op.getLoc(), snapshot)); return success(); } }; @@ -8083,9 +8103,11 @@ struct ReinterpretCastToEmitC : public OpConversionPattern(loc, tileType, + .create(loc, + getEmitCVariableResultType(tileType), emitc::OpaqueAttr::get(ctx, "")) .getResult(); + tile = loadEmitCVariableIfNeeded(rewriter, loc, tile); // Compute an integer address and assign it to the new tile. // NOTE: pto-isa TASSIGN requires an integral address (not a pointer). @@ -9834,13 +9856,24 @@ struct PTOQuantToEmitC : public OpConversionPattern { Value offsetPtr; if (op.getOffset()) { Value offset = peelUnrealized(adaptor.getOffset()); - auto offsetOT = mlir::dyn_cast(offset.getType()); - if (offsetOT) { - offsetPtr = rewriter - .create( - loc, emitc::PointerType::get(offsetOT), "&", offset) - .getResult(); + Type offsetValueTy = offset.getType(); + Value offsetLValue = offset; + if (auto lvalueTy = dyn_cast(offsetValueTy)) { + offsetValueTy = lvalueTy.getValueType(); + } else { + offsetLValue = + rewriter + .create( + loc, getEmitCVariableResultType(offsetValueTy), + emitc::OpaqueAttr::get(ctx, "")) + .getResult(); + rewriter.create(loc, offsetLValue, offset); } + offsetPtr = + rewriter + .create( + loc, emitc::PointerType::get(offsetValueTy), "&", offsetLValue) + .getResult(); } // Optional tmp tile: when supplied it selects the tmp-aware 5-arg TQUANT @@ -12053,8 +12086,8 @@ struct PTOBindTileToEmitC : public OpConversionPattern { if (!pto::isPTOFloat4PackedType(elemTy) && subRows != ShapedType::kDynamic && subCols != ShapedType::kDynamic && - succeeded(getStridesAndOffset(subMrTy, inheritedStrides, - inheritedOffset)) && + succeeded(subMrTy.getStridesAndOffset(inheritedStrides, + inheritedOffset)) && inheritedStrides.size() >= 2) { int64_t childRowStride = 0; int64_t childColStride = 0; @@ -12245,9 +12278,12 @@ struct PTOBindTileToEmitC : public OpConversionPattern { .getResult(0); } - return rewriter - .create(loc, tileType, emitc::OpaqueAttr::get(ctx, "")) - .getResult(); + Value tile = rewriter + .create( + loc, getEmitCVariableResultType(tileType), + emitc::OpaqueAttr::get(ctx, "")) + .getResult(); + return loadEmitCVariableIfNeeded(rewriter, loc, tile); }; auto emitElemTypeToString = [&](Type elemTy) -> std::string { @@ -12475,8 +12511,10 @@ struct PTOAllocTileToEmitC tile = rewriter .create( - loc, convertedTy, emitc::OpaqueAttr::get(ctx, "")) + loc, getEmitCVariableResultType(convertedTy), + emitc::OpaqueAttr::get(ctx, "")) .getResult(); + tile = loadEmitCVariableIfNeeded(rewriter, loc, tile); } Value addr = adaptor.getAddr(); @@ -12519,10 +12557,12 @@ createEmitCTileVariable(ConversionPatternRewriter &rewriter, Location loc, if (!convertedTy) convertedTy = emitc::OpaqueType::get(rewriter.getContext(), *tileTypeString); - return rewriter - .create( - loc, convertedTy, emitc::OpaqueAttr::get(rewriter.getContext(), "")) - .getResult(); + Value tile = rewriter + .create( + loc, getEmitCVariableResultType(convertedTy), + emitc::OpaqueAttr::get(rewriter.getContext(), "")) + .getResult(); + return loadEmitCVariableIfNeeded(rewriter, loc, tile); } struct PTOTReshapeToEmitC : public OpConversionPattern { @@ -12719,10 +12759,12 @@ struct PTOMaterializeTileToEmitC .getResult(0); } - return rewriter - .create(loc, convertedTy, - emitc::OpaqueAttr::get(ctx, "")) - .getResult(); + Value tile = rewriter + .create( + loc, getEmitCVariableResultType(convertedTy), + emitc::OpaqueAttr::get(ctx, "")) + .getResult(); + return loadEmitCVariableIfNeeded(rewriter, loc, tile); }; if (!isSubview && !forceDynamicValid && isTileLike(source)) { @@ -13751,7 +13793,7 @@ static void populatePTOToEmitCPatterns(RewritePatternSet &patterns, patterns.add(typeConverter, ctx); - populateSCFToEmitCConversionPatterns(patterns); + populateSCFToEmitCConversionPatterns(patterns, typeConverter); // Keep CFG-style branches type-consistent when block argument types are // converted (e.g. after lowering scf.while to cf.br/cf.cond_br). populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter); @@ -13966,7 +14008,7 @@ static AICORE inline void ptoas_auto_sync_tail( scfLoweringPatterns.add(ctx); - (void)applyPatternsAndFoldGreedily(mop, std::move(scfLoweringPatterns)); + (void)applyPatternsGreedily(mop, std::move(scfLoweringPatterns)); bool hasUnsupportedSCF = false; mop.walk([&](Operation *op) { diff --git a/lib/PTO/Transforms/PTOViewToMemref.cpp b/lib/PTO/Transforms/PTOViewToMemref.cpp index 0f445970a0..941a889924 100644 --- a/lib/PTO/Transforms/PTOViewToMemref.cpp +++ b/lib/PTO/Transforms/PTOViewToMemref.cpp @@ -1442,7 +1442,7 @@ static LogicalResult lowerSubViewOps(func::FuncOp func, MLIRContext *ctx) { SmallVector srcStrides; int64_t srcOffset = ShapedType::kDynamic; - if (failed(getStridesAndOffset(srcMrTy, srcStrides, srcOffset))) + if (failed(srcMrTy.getStridesAndOffset(srcStrides, srcOffset))) srcStrides = computeCompactStrides(srcMrTy.getShape()); // Keep parent physical shape + strides for bound tile semantics. @@ -1996,7 +1996,7 @@ struct PTOViewToMemrefPass SmallVector staticStrides; int64_t offset = ShapedType::kDynamic; - if (succeeded(getStridesAndOffset(mrTy, staticStrides, offset)) && + if (succeeded(mrTy.getStridesAndOffset(staticStrides, offset)) && dimIndex < (int64_t)staticStrides.size() && staticStrides[dimIndex] != ShapedType::kDynamic) { rewriter.replaceOpWithNewOp( diff --git a/lib/PTO/Transforms/Utils.cpp b/lib/PTO/Transforms/Utils.cpp index 58e68c77e2..e328ec125a 100644 --- a/lib/PTO/Transforms/Utils.cpp +++ b/lib/PTO/Transforms/Utils.cpp @@ -106,12 +106,10 @@ std::optional> getOperationAliasInfo(Operation *op) { dyn_cast(op)) { return std::make_pair(extractStridedMetadataOp.getBaseBuffer(), extractStridedMetadataOp.getViewSource()); - } else if (auto toMemrefOp = dyn_cast(op)) { - return std::make_pair(toMemrefOp.getResult(), toMemrefOp.getOperand()); + } else if (auto toBufferOp = dyn_cast(op)) { + return std::make_pair(toBufferOp.getBuffer(), toBufferOp.getTensor()); } else if (auto toTensorOp = dyn_cast(op)) { return std::make_pair(toTensorOp.getResult(), toTensorOp.getOperand()); - } else if (auto toMemrefOp = dyn_cast(op)) { - return std::make_pair(toMemrefOp.getResult(), toMemrefOp.getOperand()); } return std::nullopt; } diff --git a/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp b/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp index 8362aea64b..953b75d88d 100644 --- a/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp +++ b/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp @@ -68,31 +68,28 @@ static std::optional getElementCountFromVectorLike(Type type); static Type getLowPrecisionLLVMType(Type type, MLIRContext *context) { if (pto::isPTOHiFloat8Type(type)) - return LLVM::LLVMHiFloat8Type::get(context); + return Float8E4M3FNType::get(context); if (isa(type)) - return LLVM::LLVMFloat4E1M2x2Type::get(context); + return IntegerType::get(context, 8); if (isa(type)) - return LLVM::LLVMFloat4E2M1x2Type::get(context); - if (type.isFloat8E4M3() || type.isFloat8E4M3FN() || - type.isFloat8E4M3FNUZ() || type.isFloat8E4M3B11FNUZ()) - return LLVM::LLVMFloat8E4M3Type::get(context); - if (type.isFloat8E5M2() || type.isFloat8E5M2FNUZ()) - return LLVM::LLVMFloat8E5M2Type::get(context); + return IntegerType::get(context, 8); + if (pto::isPTOFloat8E4M3LikeType(type)) + return Float8E4M3Type::get(context); + if (pto::isPTOFloat8E5M2LikeType(type)) + return Float8E5M2Type::get(context); return {}; } static Type getLLVMCompatibleVectorType(ArrayRef shape, Type elementType, ArrayRef scalableDims = {}) { - if (shape.size() == 1 && !elementType.isIntOrIndexOrFloat()) - return LLVM::LLVMFixedVectorType::get(elementType, shape.front()); return VectorType::get(shape, elementType, scalableDims); } static Type normalizePayloadTypeForLLVMLowering(Type type, Builder &builder) { if (pto::isPTOHiFloat8x2Type(type)) - return LLVM::LLVMFixedVectorType::get( - LLVM::LLVMHiFloat8Type::get(builder.getContext()), 2); + return getLLVMCompatibleVectorType( + {2}, Float8E4M3FNType::get(builder.getContext())); if (Type lowpType = getLowPrecisionLLVMType(type, builder.getContext())) return lowpType; @@ -177,12 +174,6 @@ static unsigned getNaturalByteAlignment(Type type) { elems *= dim; return elemAlign * static_cast(elems); } - if (auto vecType = dyn_cast(type)) { - unsigned elemAlign = getNaturalByteAlignment(vecType.getElementType()); - if (!elemAlign) - return 0; - return elemAlign * vecType.getNumElements(); - } if (auto intType = dyn_cast(type)) return llvm::divideCeil(unsigned(intType.getWidth()), 8u); if (pto::isPTOHiFloat8x2Type(type)) @@ -227,7 +218,6 @@ class VPTOTypeConverter final : public TypeConverter { }); addSourceMaterialization(materializeVPTOCast); addTargetMaterialization(materializeVPTOCast); - addArgumentMaterialization(materializeVPTOCast); } }; @@ -364,8 +354,7 @@ static std::string getMadRhsFragment(Type type) { } static bool isMadE4M3ElementType(Type type) { - return type.isFloat8E4M3() || type.isFloat8E4M3FN() || - type.isFloat8E4M3FNUZ() || type.isFloat8E4M3B11FNUZ(); + return pto::isPTOFloat8E4M3LikeType(type); } static bool isMadE5M2ElementType(Type type) { @@ -541,10 +530,9 @@ static std::string getLowPrecisionElementFragment(Type type) { return "f4e1m2x2"; if (isa(type)) return "f4e2m1x2"; - if (type.isFloat8E4M3() || type.isFloat8E4M3FN() || - type.isFloat8E4M3FNUZ() || type.isFloat8E4M3B11FNUZ()) + if (pto::isPTOFloat8E4M3LikeType(type)) return "f8e4m3"; - if (type.isFloat8E5M2() || type.isFloat8E5M2FNUZ()) + if (pto::isPTOFloat8E5M2LikeType(type)) return "f8e5m2"; return {}; } @@ -712,8 +700,6 @@ static Type getElementTypeFromVectorLike(Type type) { return vecType.getElementType(); if (auto vecType = dyn_cast(type)) return vecType.getElementType(); - if (auto vecType = dyn_cast(type)) - return vecType.getElementType(); return {}; } @@ -725,8 +711,6 @@ static std::optional getElementCountFromVectorLike(Type type) { return std::nullopt; return vecType.getShape().front(); } - if (auto vecType = dyn_cast(type)) - return vecType.getNumElements(); return std::nullopt; } @@ -8286,7 +8270,7 @@ class LowerKeepOpPattern final : public OpConversionPattern { op.getLoc(), TypeRange{}, payloads, asmString, appendSimtKeepResumeClobbers( buildRepeatedInlineAsmConstraints("R", payloads.size()), clobbers), - true, false, + true, false, LLVM::tailcallkind::TailCallKind::None, LLVM::AsmDialectAttr::get(op.getContext(), LLVM::AsmDialect::AD_ATT), ArrayAttr{}); for (pto::KeepOp keep : llvm::reverse(keepOps)) @@ -8359,7 +8343,7 @@ class LowerResumeOpPattern final : public OpConversionPattern { op.getLoc(), TypeRange{asmResultType}, ValueRange{}, asmString, appendSimtKeepResumeClobbers( buildRepeatedInlineAsmConstraints("=R", resumeOps.size()), clobbers), - true, false, + true, false, LLVM::tailcallkind::TailCallKind::None, LLVM::AsmDialectAttr::get(op.getContext(), LLVM::AsmDialect::AD_ATT), ArrayAttr{}); @@ -10472,7 +10456,6 @@ static void applySimtEntryCallingConvention( const llvm::StringSet &simtEntryNames) { for (llvm::Function &function : llvmModule) { if (simtEntryNames.contains(function.getName())) { - function.setCallingConv(llvm::CallingConv::SimtEntry); function.addFnAttr(llvm::Attribute::NoInline); // Match Bisheng's C++ frontend shape for SIMT outlined bodies. The // exported wrapper owns the real kernel metadata, while the SIMT body is @@ -10494,7 +10477,6 @@ static void applySimtEntryCallingConvention( auto *callee = call->getCalledFunction(); if (!callee || !simtEntryNames.contains(callee->getName())) continue; - call->setCallingConv(llvm::CallingConv::SimtEntry); } } } @@ -10552,7 +10534,7 @@ static LogicalResult runPipeline(ModuleOp module, llvm::raw_ostream &diagOS, kernelModulePM.addPass( std::make_unique()); kernelModulePM.addPass(arith::createArithExpandOpsPass()); - kernelModulePM.addPass(createConvertSCFToCFPass()); + kernelModulePM.addPass(createSCFToControlFlowPass()); kernelModulePM.addPass(createArithToLLVMConversionPass()); kernelModulePM.addPass(createConvertIndexToLLVMPass()); kernelModulePM.addPass(createFinalizeMemRefToLLVMConversionPass()); diff --git a/lib/PTO/Transforms/VPTOExpandWrapperOps.cpp b/lib/PTO/Transforms/VPTOExpandWrapperOps.cpp index 60f4d18ce7..799cbc1b0b 100644 --- a/lib/PTO/Transforms/VPTOExpandWrapperOps.cpp +++ b/lib/PTO/Transforms/VPTOExpandWrapperOps.cpp @@ -1885,7 +1885,7 @@ struct VPTOExpandWrapperOpsPass ExpandMadSemanticPattern, ExpandMadSemanticPattern, ExpandMadSemanticPattern>(&getContext()); - if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) + if (failed(applyPatternsGreedily(func, std::move(patterns)))) signalPassFailure(); } }; diff --git a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp index 35f8cc51a3..f41169c796 100644 --- a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp +++ b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp @@ -69,31 +69,28 @@ static std::optional getElementCountFromVectorLike(Type type); static Type getLowPrecisionLLVMType(Type type, MLIRContext *context) { if (pto::isPTOHiFloat8Type(type)) - return LLVM::LLVMHiFloat8Type::get(context); + return Float8E4M3FNType::get(context); if (isa(type)) - return LLVM::LLVMFloat4E1M2x2Type::get(context); + return IntegerType::get(context, 8); if (isa(type)) - return LLVM::LLVMFloat4E2M1x2Type::get(context); - if (type.isFloat8E4M3() || type.isFloat8E4M3FN() || - type.isFloat8E4M3FNUZ() || type.isFloat8E4M3B11FNUZ()) - return LLVM::LLVMFloat8E4M3Type::get(context); - if (type.isFloat8E5M2() || type.isFloat8E5M2FNUZ()) - return LLVM::LLVMFloat8E5M2Type::get(context); + return IntegerType::get(context, 8); + if (pto::isPTOFloat8E4M3LikeType(type)) + return Float8E4M3Type::get(context); + if (pto::isPTOFloat8E5M2LikeType(type)) + return Float8E5M2Type::get(context); return {}; } static Type getLLVMCompatibleVectorType(ArrayRef shape, Type elementType, ArrayRef scalableDims = {}) { - if (shape.size() == 1 && !elementType.isIntOrIndexOrFloat()) - return LLVM::LLVMFixedVectorType::get(elementType, shape.front()); return VectorType::get(shape, elementType, scalableDims); } static Type normalizePayloadTypeForLLVMLowering(Type type, Builder &builder) { if (pto::isPTOHiFloat8x2Type(type)) - return LLVM::LLVMFixedVectorType::get( - LLVM::LLVMHiFloat8Type::get(builder.getContext()), 2); + return getLLVMCompatibleVectorType( + {2}, Float8E4M3FNType::get(builder.getContext())); if (Type lowpType = getLowPrecisionLLVMType(type, builder.getContext())) return lowpType; @@ -178,12 +175,6 @@ static unsigned getNaturalByteAlignment(Type type) { elems *= dim; return elemAlign * static_cast(elems); } - if (auto vecType = dyn_cast(type)) { - unsigned elemAlign = getNaturalByteAlignment(vecType.getElementType()); - if (!elemAlign) - return 0; - return elemAlign * vecType.getNumElements(); - } if (auto intType = dyn_cast(type)) return llvm::divideCeil(unsigned(intType.getWidth()), 8u); if (pto::isPTOHiFloat8x2Type(type)) @@ -228,7 +219,6 @@ class VPTOTypeConverter final : public TypeConverter { }); addSourceMaterialization(materializeVPTOCast); addTargetMaterialization(materializeVPTOCast); - addArgumentMaterialization(materializeVPTOCast); } }; @@ -365,8 +355,7 @@ static std::string getMadRhsFragment(Type type) { } static bool isMadE4M3ElementType(Type type) { - return type.isFloat8E4M3() || type.isFloat8E4M3FN() || - type.isFloat8E4M3FNUZ() || type.isFloat8E4M3B11FNUZ(); + return pto::isPTOFloat8E4M3LikeType(type); } static bool isMadE5M2ElementType(Type type) { @@ -498,10 +487,9 @@ static std::string getLowPrecisionElementFragment(Type type) { return "f4e1m2x2"; if (isa(type)) return "f4e2m1x2"; - if (type.isFloat8E4M3() || type.isFloat8E4M3FN() || - type.isFloat8E4M3FNUZ() || type.isFloat8E4M3B11FNUZ()) + if (pto::isPTOFloat8E4M3LikeType(type)) return "f8e4m3"; - if (type.isFloat8E5M2() || type.isFloat8E5M2FNUZ()) + if (pto::isPTOFloat8E5M2LikeType(type)) return "f8e5m2"; return {}; } @@ -679,8 +667,6 @@ static Type getElementTypeFromVectorLike(Type type) { return vecType.getElementType(); if (auto vecType = dyn_cast(type)) return vecType.getElementType(); - if (auto vecType = dyn_cast(type)) - return vecType.getElementType(); return {}; } @@ -692,8 +678,6 @@ static std::optional getElementCountFromVectorLike(Type type) { return std::nullopt; return vecType.getShape().front(); } - if (auto vecType = dyn_cast(type)) - return vecType.getNumElements(); return std::nullopt; } @@ -8228,7 +8212,7 @@ class LowerKeepOpPattern final : public OpConversionPattern { op.getLoc(), TypeRange{}, payloads, asmString, appendSimtKeepResumeClobbers( buildRepeatedInlineAsmConstraints("R", payloads.size()), clobbers), - true, false, + true, false, LLVM::tailcallkind::TailCallKind::None, LLVM::AsmDialectAttr::get(op.getContext(), LLVM::AsmDialect::AD_ATT), ArrayAttr{}); for (pto::KeepOp keep : llvm::reverse(keepOps)) @@ -8302,7 +8286,7 @@ class LowerResumeOpPattern final : public OpConversionPattern { op.getLoc(), TypeRange{asmResultType}, ValueRange{}, asmString, appendSimtKeepResumeClobbers( buildRepeatedInlineAsmConstraints("=R", resumeOps.size()), clobbers), - true, false, + true, false, LLVM::tailcallkind::TailCallKind::None, LLVM::AsmDialectAttr::get(op.getContext(), LLVM::AsmDialect::AD_ATT), ArrayAttr{}); @@ -10433,7 +10417,6 @@ static void applySimtEntryCallingConvention( const llvm::StringSet &simtEntryNames) { for (llvm::Function &function : llvmModule) { if (simtEntryNames.contains(function.getName())) { - function.setCallingConv(llvm::CallingConv::SimtEntry); function.addFnAttr(llvm::Attribute::NoInline); // Match Bisheng's C++ frontend shape for SIMT outlined bodies. The // exported wrapper owns the real kernel metadata, while the SIMT body is @@ -10455,7 +10438,6 @@ static void applySimtEntryCallingConvention( auto *callee = call->getCalledFunction(); if (!callee || !simtEntryNames.contains(callee->getName())) continue; - call->setCallingConv(llvm::CallingConv::SimtEntry); } } } @@ -10513,7 +10495,7 @@ static LogicalResult runPipeline(ModuleOp module, llvm::raw_ostream &diagOS, kernelModulePM.addPass( std::make_unique()); kernelModulePM.addPass(arith::createArithExpandOpsPass()); - kernelModulePM.addPass(createConvertSCFToCFPass()); + kernelModulePM.addPass(createSCFToControlFlowPass()); kernelModulePM.addPass(createArithToLLVMConversionPass()); kernelModulePM.addPass(createConvertIndexToLLVMPass()); kernelModulePM.addPass(createFinalizeMemRefToLLVMConversionPass()); diff --git a/lib/PTO/Transforms/VPTOLLVMEmitterHelper.cpp b/lib/PTO/Transforms/VPTOLLVMEmitterHelper.cpp index 8b5eb7a165..402c4272eb 100644 --- a/lib/PTO/Transforms/VPTOLLVMEmitterHelper.cpp +++ b/lib/PTO/Transforms/VPTOLLVMEmitterHelper.cpp @@ -569,13 +569,14 @@ void attachHIVMKernelAnnotations(llvm::Module &llvmModule, simtConfigByName[symName] = {maxThreads, maxRegisters}; }); - auto callsSimtEntry = [](llvm::Function &function) { + auto callsSimtEntry = [&](llvm::Function &function) { for (llvm::BasicBlock &block : function) { for (llvm::Instruction &inst : block) { auto *call = llvm::dyn_cast(&inst); if (!call) continue; - if (call->getCallingConv() == llvm::CallingConv::SimtEntry) + auto *callee = call->getCalledFunction(); + if (callee && simtConfigByName.contains(callee->getName())) return true; } } @@ -609,7 +610,7 @@ void attachHIVMKernelAnnotations(llvm::Module &llvmModule, for (llvm::Function &function : llvmModule) { if (function.isDeclaration()) continue; - if (function.getCallingConv() == llvm::CallingConv::SimtEntry) { + if (simtConfigByName.contains(function.getName())) { uint32_t maxThreads = kDefaultSimtMaxThreads; uint32_t maxRegisters = kDefaultSimtMaxRegisters; if (auto it = simtConfigByName.find(function.getName()); diff --git a/lib/PTO/Transforms/VPTOPtrCastCleanup.cpp b/lib/PTO/Transforms/VPTOPtrCastCleanup.cpp index dbaa9a780d..fc238fac5f 100644 --- a/lib/PTO/Transforms/VPTOPtrCastCleanup.cpp +++ b/lib/PTO/Transforms/VPTOPtrCastCleanup.cpp @@ -69,7 +69,7 @@ struct VPTOPtrCastCleanupPass void runOnOperation() override { RewritePatternSet patterns(&getContext()); patterns.add(&getContext()); - if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) signalPassFailure(); } }; diff --git a/lib/PTO/Transforms/VPTOPtrNormalize.cpp b/lib/PTO/Transforms/VPTOPtrNormalize.cpp index a692f47271..7ca37cbb40 100644 --- a/lib/PTO/Transforms/VPTOPtrNormalize.cpp +++ b/lib/PTO/Transforms/VPTOPtrNormalize.cpp @@ -117,7 +117,7 @@ static LogicalResult computeSubviewElementOffset(memref::SubViewOp op, SmallVector strides; int64_t baseOffset = 0; - if (failed(getStridesAndOffset(sourceType, strides, baseOffset))) + if (failed(sourceType.getStridesAndOffset(strides, baseOffset))) return failure(); // The SSA source already names the base address after bind_tile/pointer_cast // normalization. A dynamic memref layout offset here is metadata we can @@ -773,7 +773,6 @@ struct VPTOPtrNormalizePass [](Type type) { return convertSubviewResultType(type); }); typeConverter.addTargetMaterialization(materializeUnrealizedCast); typeConverter.addSourceMaterialization(materializeUnrealizedCast); - typeConverter.addArgumentMaterialization(materializeUnrealizedCast); ConversionTarget target(*context); target.addLegalDialect str: @@ -66,9 +66,9 @@ class DocBlockMetadata: @dataclass(frozen=True) class DocTestDirective: mode: str - symbol: str | None = None - compile_kwargs: dict[str, object] | None = None - fixture: str | None = None + symbol: Optional[str] = None + compile_kwargs: Optional[dict[str, object]] = None + fixture: Optional[str] = None @dataclass(frozen=True) @@ -85,12 +85,12 @@ def expect(condition: bool, message: str) -> None: raise AssertionError(message) -def format_doc_context(path: Path, start_line: int, symbol: str | None = None) -> str: +def format_doc_context(path: Path, start_line: int, symbol: Optional[str] = None) -> str: symbol_text = symbol if symbol is not None else "" return f"{path}:{start_line} [symbol={symbol_text}]" -def fail_doc(path: Path, start_line: int, message: str, symbol: str | None = None) -> None: +def fail_doc(path: Path, start_line: int, message: str, symbol: Optional[str] = None) -> None: raise AssertionError(f"{format_doc_context(path, start_line, symbol)}: {message}") @@ -98,7 +98,7 @@ def iter_markdown_files(root: Path) -> Iterable[Path]: yield from sorted(root.glob("*.md")) -def parse_metadata_line(path: Path, line: str, line_number: int) -> DocBlockMetadata | None: +def parse_metadata_line(path: Path, line: str, line_number: int) -> Optional[DocBlockMetadata]: match = META_RE.match(line) if match is None: return None @@ -116,7 +116,7 @@ def parse_metadata_line(path: Path, line: str, line_number: int) -> DocBlockMeta return DocBlockMetadata(kind=kind, body=body, line=line_number, raw=line.rstrip("\n")) -def find_block_metadata(path: Path, lines: list[str], fence_line: int) -> DocBlockMetadata | None: +def find_block_metadata(path: Path, lines: list[str], fence_line: int) -> Optional[DocBlockMetadata]: candidate = fence_line - 2 while candidate >= 0 and not lines[candidate].strip(): candidate -= 1 @@ -128,7 +128,7 @@ def find_block_metadata(path: Path, lines: list[str], fence_line: int) -> DocBlo return parse_metadata_line(path, line, candidate + 1) -def block_label(block: MarkdownCodeBlock, symbol: str | None = None) -> str: +def block_label(block: MarkdownCodeBlock, symbol: Optional[str] = None) -> str: return format_doc_context(block.path, block.start_line, symbol) @@ -313,9 +313,9 @@ def parse_test_directive(block: MarkdownCodeBlock) -> DocTestDirective: def execute_source( source: str, block: MarkdownCodeBlock, - symbol: str | None = None, + symbol: Optional[str] = None, *, - extra_namespace: dict[str, object] | None = None, + extra_namespace: Optional[dict[str, object]] = None, ) -> dict[str, object]: namespace: dict[str, object] = { "__builtins__": __builtins__, @@ -493,7 +493,7 @@ def scan_markdown_file(path: Path) -> MarkdownScanResult: block_language = "" block_start = 0 block_lines: list[str] = [] - metadata: DocBlockMetadata | None = None + metadata: Optional[DocBlockMetadata] = None for index, line in enumerate(lines, start=1): fence_match = FENCE_RE.match(line.rstrip("\n")) diff --git a/ptodsl/tests/test_jit_compile.py b/ptodsl/tests/test_jit_compile.py index 2157acdb1a..fd1b391644 100644 --- a/ptodsl/tests/test_jit_compile.py +++ b/ptodsl/tests/test_jit_compile.py @@ -12,6 +12,7 @@ import re import sys from tempfile import TemporaryDirectory +from typing import Optional from unittest import mock @@ -35,7 +36,7 @@ def expect(condition: bool, message: str) -> None: raise AssertionError(message) -def expect_raises(exc_type, func, message_substring: str | None = None) -> Exception: +def expect_raises(exc_type, func, message_substring: Optional[str] = None) -> Exception: try: func() except exc_type as exc: diff --git a/python/pto/dialects/pto.py b/python/pto/dialects/pto.py index 43552e26c6..908c089d8a 100644 --- a/python/pto/dialects/pto.py +++ b/python/pto/dialects/pto.py @@ -11,10 +11,16 @@ import functools import os from pathlib import Path +from typing import Optional from mlir import ir as _ods_ir from . import _pto_ops_gen as _pto_ops_gen +from ._ods_common import ( + get_default_loc_context as _ods_get_default_loc_context, + get_op_result_or_value as _get_op_result_or_value, + get_op_results_or_values as _get_op_results_or_values, +) def _candidate_pto_ext_dirs(): @@ -82,7 +88,7 @@ def _export_generated_symbols(): def get_op_result_or_value(value): - return getattr(_pto_ops_gen, "_get_op_result_or_value")(value) + return _get_op_result_or_value(value) def _export_optional_cext_symbol(name): @@ -169,7 +175,6 @@ def _export_optional_cext_symbol(name): _ptr_type_get_impl = PtrType.get -_ods_get_default_loc_context = getattr(_pto_ops_gen, "_ods_get_default_loc_context") def _ptr_type_get_compat(cls, element_type, memory_space=None, context=None): @@ -849,7 +854,7 @@ def _init_inferred(self, source, offsets, sizes, args, loc, ip): sizes, *args = args if args: raise TypeError(f"too many positional arguments: {len(args)}") - source_value = _pto_ops_gen._get_op_result_or_value(source) + source_value = _get_op_result_or_value(source) source_type = source_value.type result = PartitionTensorViewType.get(source_type.rank, source_type.element_type) self._init_explicit(result, source_value, offsets, sizes, (), loc, ip) @@ -866,9 +871,9 @@ def _init_explicit(self, result, source, offsets, sizes, args, loc, ip): if args: raise TypeError(f"too many positional arguments: {len(args)}") operands = [ - _pto_ops_gen._get_op_result_or_value(source), - _pto_ops_gen._get_op_results_or_values(offsets), - _pto_ops_gen._get_op_results_or_values(sizes), + _get_op_result_or_value(source), + _get_op_results_or_values(offsets), + _get_op_results_or_values(sizes), ] op = self.build_generic( attributes={}, @@ -913,12 +918,12 @@ def _is_mask_pattern(value): def _value_type(value): try: - return _pto_ops_gen._get_op_result_or_value(value).type + return _get_op_result_or_value(value).type except Exception: return None def _matches_src_type(value): - src_value = _pto_ops_gen._get_op_result_or_value(src) + src_value = _get_op_result_or_value(src) value_type = _value_type(value) return value_type is not None and value_type == src_value.type @@ -1135,9 +1140,9 @@ class _VKernelCompileError(Exception): @_dataclass class _VKValue: - name: str | None = None - type: _VKernelType | None = None - literal: object | None = None + name: Optional[str] = None + type: Optional[_VKernelType] = None + literal: Optional[object] = None def render_type(self): if self.type is None: diff --git a/test/samples/Qwen3DecodeA5/down_proj_residual.pto b/test/samples/Qwen3DecodeA5/down_proj_residual.pto index b81f7269ba..9e55077eba 100644 --- a/test/samples/Qwen3DecodeA5/down_proj_residual.pto +++ b/test/samples/Qwen3DecodeA5/down_proj_residual.pto @@ -47,44 +47,44 @@ module attributes {pto.target_arch = "a5"} { %23 = arith.cmpi eq, %18, %c0_index : index scf.if %23 { %down_acc__tile_l0_init = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - %down_acc__tile_l0_a = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%down_mlp_chunk__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%down_acc__tile_l0_a : !pto.tile_buf) + %down_acc__tile_l0_a = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%down_mlp_chunk__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%down_acc__tile_l0_a : !pto.tile_buf) %down_acc__tile_l0_b = pto.alloc_tile addr = %c0_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%w_down_chunk__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%down_acc__tile_l0_b : !pto.tile_buf) - %2 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%down_mlp_chunk__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%2 : !pto.tile_buf) + %2 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%down_mlp_chunk__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%2 : !pto.tile_buf) %3 = pto.alloc_tile addr = %c32768_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%w_down_chunk__tile, %c64_index, %c0_index : !pto.tile_buf, index, index) outs(%3 : !pto.tile_buf) %down_acc__tile_l0_c_first = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul ins(%down_acc__tile_l0_a, %down_acc__tile_l0_b : !pto.tile_buf, !pto.tile_buf) outs(%down_acc__tile_l0_c_first : !pto.tile_buf) + pto.tmatmul ins(%down_acc__tile_l0_a, %down_acc__tile_l0_b : !pto.tile_buf, !pto.tile_buf) outs(%down_acc__tile_l0_c_first : !pto.tile_buf) %down_acc__tile_l0_c_acc = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%down_acc__tile_l0_c_acc, %2, %3 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%down_acc__tile_l0_c_acc : !pto.tile_buf) + pto.tmatmul.acc ins(%down_acc__tile_l0_c_acc, %2, %3 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%down_acc__tile_l0_c_acc : !pto.tile_buf) } else { - %4 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%down_mlp_chunk__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%4 : !pto.tile_buf) + %4 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%down_mlp_chunk__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%4 : !pto.tile_buf) %5 = pto.alloc_tile addr = %c0_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%w_down_chunk__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%5 : !pto.tile_buf) - %6 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%down_mlp_chunk__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%6 : !pto.tile_buf) + %6 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%down_mlp_chunk__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%6 : !pto.tile_buf) %7 = pto.alloc_tile addr = %c32768_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%w_down_chunk__tile, %c64_index, %c0_index : !pto.tile_buf, index, index) outs(%7 : !pto.tile_buf) %8 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%8, %4, %5 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%8 : !pto.tile_buf) + pto.tmatmul.acc ins(%8, %4, %5 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%8 : !pto.tile_buf) %9 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%9, %6, %7 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%9 : !pto.tile_buf) + pto.tmatmul.acc ins(%9, %6, %7 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%9 : !pto.tile_buf) } - %10 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%0, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%10 : !pto.tile_buf) + %10 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%0, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%10 : !pto.tile_buf) %11 = pto.alloc_tile addr = %c0_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%1, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%11 : !pto.tile_buf) - %12 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%0, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%12 : !pto.tile_buf) + %12 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%0, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%12 : !pto.tile_buf) %13 = pto.alloc_tile addr = %c32768_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%1, %c64_index, %c0_index : !pto.tile_buf, index, index) outs(%13 : !pto.tile_buf) %14 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%14, %10, %11 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%14 : !pto.tile_buf) + pto.tmatmul.acc ins(%14, %10, %11 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%14 : !pto.tile_buf) %15 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%15, %12, %13 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%15 : !pto.tile_buf) + pto.tmatmul.acc ins(%15, %12, %13 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%15 : !pto.tile_buf) } pto.tpush_to_aiv(%down_acc__tile : !pto.tile_buf) {split = 1} } diff --git a/test/samples/Qwen3DecodeA5/out_proj_residual.pto b/test/samples/Qwen3DecodeA5/out_proj_residual.pto index 991a3b48c9..87c90e3a7a 100644 --- a/test/samples/Qwen3DecodeA5/out_proj_residual.pto +++ b/test/samples/Qwen3DecodeA5/out_proj_residual.pto @@ -45,44 +45,44 @@ module attributes {pto.target_arch = "a5"} { %23 = arith.cmpi eq, %18, %c0_index : index scf.if %23 { %o_acc__tile_l0_init = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - %o_acc__tile_l0_a = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%a_chunk__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%o_acc__tile_l0_a : !pto.tile_buf) + %o_acc__tile_l0_a = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%a_chunk__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%o_acc__tile_l0_a : !pto.tile_buf) %o_acc__tile_l0_b = pto.alloc_tile addr = %c0_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%w_chunk__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%o_acc__tile_l0_b : !pto.tile_buf) - %2 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%a_chunk__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%2 : !pto.tile_buf) + %2 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%a_chunk__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%2 : !pto.tile_buf) %3 = pto.alloc_tile addr = %c32768_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%w_chunk__tile, %c64_index, %c0_index : !pto.tile_buf, index, index) outs(%3 : !pto.tile_buf) %o_acc__tile_l0_c_first = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul ins(%o_acc__tile_l0_a, %o_acc__tile_l0_b : !pto.tile_buf, !pto.tile_buf) outs(%o_acc__tile_l0_c_first : !pto.tile_buf) + pto.tmatmul ins(%o_acc__tile_l0_a, %o_acc__tile_l0_b : !pto.tile_buf, !pto.tile_buf) outs(%o_acc__tile_l0_c_first : !pto.tile_buf) %o_acc__tile_l0_c_acc = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%o_acc__tile_l0_c_acc, %2, %3 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%o_acc__tile_l0_c_acc : !pto.tile_buf) + pto.tmatmul.acc ins(%o_acc__tile_l0_c_acc, %2, %3 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%o_acc__tile_l0_c_acc : !pto.tile_buf) } else { - %4 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%a_chunk__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%4 : !pto.tile_buf) + %4 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%a_chunk__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%4 : !pto.tile_buf) %5 = pto.alloc_tile addr = %c0_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%w_chunk__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%5 : !pto.tile_buf) - %6 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%a_chunk__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%6 : !pto.tile_buf) + %6 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%a_chunk__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%6 : !pto.tile_buf) %7 = pto.alloc_tile addr = %c32768_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%w_chunk__tile, %c64_index, %c0_index : !pto.tile_buf, index, index) outs(%7 : !pto.tile_buf) %8 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%8, %4, %5 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%8 : !pto.tile_buf) + pto.tmatmul.acc ins(%8, %4, %5 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%8 : !pto.tile_buf) %9 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%9, %6, %7 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%9 : !pto.tile_buf) + pto.tmatmul.acc ins(%9, %6, %7 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%9 : !pto.tile_buf) } - %10 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%0, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%10 : !pto.tile_buf) + %10 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%0, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%10 : !pto.tile_buf) %11 = pto.alloc_tile addr = %c0_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%1, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%11 : !pto.tile_buf) - %12 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%0, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%12 : !pto.tile_buf) + %12 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%0, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%12 : !pto.tile_buf) %13 = pto.alloc_tile addr = %c32768_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%1, %c64_index, %c0_index : !pto.tile_buf, index, index) outs(%13 : !pto.tile_buf) %14 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%14, %10, %11 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%14 : !pto.tile_buf) + pto.tmatmul.acc ins(%14, %10, %11 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%14 : !pto.tile_buf) %15 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%15, %12, %13 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%15 : !pto.tile_buf) + pto.tmatmul.acc ins(%15, %12, %13 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%15 : !pto.tile_buf) } pto.tpush_to_aiv(%o_acc__tile : !pto.tile_buf) {split = 1} } diff --git a/test/samples/Qwen3DecodeA5/qwen3_decode_incore_1.pto b/test/samples/Qwen3DecodeA5/qwen3_decode_incore_1.pto index 770c9f7404..231b7202ff 100644 --- a/test/samples/Qwen3DecodeA5/qwen3_decode_incore_1.pto +++ b/test/samples/Qwen3DecodeA5/qwen3_decode_incore_1.pto @@ -40,44 +40,44 @@ module attributes {pto.target_arch = "a5"} { %22 = arith.cmpi eq, %17, %c0_index : index scf.if %22 { %q_acc__tile_l0_init = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - %q_acc__tile_l0_a = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%tile_a_i__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%q_acc__tile_l0_a : !pto.tile_buf) + %q_acc__tile_l0_a = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%tile_a_i__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%q_acc__tile_l0_a : !pto.tile_buf) %q_acc__tile_l0_b = pto.alloc_tile addr = %c0_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%tile_b_i__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%q_acc__tile_l0_b : !pto.tile_buf) - %2 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%tile_a_i__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%2 : !pto.tile_buf) + %2 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%tile_a_i__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%2 : !pto.tile_buf) %3 = pto.alloc_tile addr = %c32768_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%tile_b_i__tile, %c64_index, %c0_index : !pto.tile_buf, index, index) outs(%3 : !pto.tile_buf) %q_acc__tile_l0_c_first = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul ins(%q_acc__tile_l0_a, %q_acc__tile_l0_b : !pto.tile_buf, !pto.tile_buf) outs(%q_acc__tile_l0_c_first : !pto.tile_buf) + pto.tmatmul ins(%q_acc__tile_l0_a, %q_acc__tile_l0_b : !pto.tile_buf, !pto.tile_buf) outs(%q_acc__tile_l0_c_first : !pto.tile_buf) %q_acc__tile_l0_c_acc = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%q_acc__tile_l0_c_acc, %2, %3 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%q_acc__tile_l0_c_acc : !pto.tile_buf) + pto.tmatmul.acc ins(%q_acc__tile_l0_c_acc, %2, %3 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%q_acc__tile_l0_c_acc : !pto.tile_buf) } else { - %4 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%tile_a_i__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%4 : !pto.tile_buf) + %4 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%tile_a_i__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%4 : !pto.tile_buf) %5 = pto.alloc_tile addr = %c0_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%tile_b_i__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%5 : !pto.tile_buf) - %6 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%tile_a_i__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%6 : !pto.tile_buf) + %6 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%tile_a_i__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%6 : !pto.tile_buf) %7 = pto.alloc_tile addr = %c32768_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%tile_b_i__tile, %c64_index, %c0_index : !pto.tile_buf, index, index) outs(%7 : !pto.tile_buf) %8 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%8, %4, %5 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%8 : !pto.tile_buf) + pto.tmatmul.acc ins(%8, %4, %5 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%8 : !pto.tile_buf) %9 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%9, %6, %7 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%9 : !pto.tile_buf) + pto.tmatmul.acc ins(%9, %6, %7 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%9 : !pto.tile_buf) } - %10 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%0, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%10 : !pto.tile_buf) + %10 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%0, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%10 : !pto.tile_buf) %11 = pto.alloc_tile addr = %c0_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%1, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%11 : !pto.tile_buf) - %12 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%0, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%12 : !pto.tile_buf) + %12 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%0, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%12 : !pto.tile_buf) %13 = pto.alloc_tile addr = %c32768_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%1, %c64_index, %c0_index : !pto.tile_buf, index, index) outs(%13 : !pto.tile_buf) %14 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%14, %10, %11 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%14 : !pto.tile_buf) + pto.tmatmul.acc ins(%14, %10, %11 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%14 : !pto.tile_buf) %15 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%15, %12, %13 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%15 : !pto.tile_buf) + pto.tmatmul.acc ins(%15, %12, %13 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%15 : !pto.tile_buf) } %q_proj__ssa_v0_pview = pto.partition_view %q_proj__ssa_v0_view, offsets = [%c0_index, %16], sizes = [%c16_index, %c256_index] : !pto.tensor_view -> !pto.partition_tensor_view<16x256xf32> pto.tstore ins(%q_acc__tile : !pto.tile_buf) outs(%q_proj__ssa_v0_pview : !pto.partition_tensor_view<16x256xf32>) diff --git a/test/samples/Qwen3DecodeA5/qwen3_decode_incore_10.pto b/test/samples/Qwen3DecodeA5/qwen3_decode_incore_10.pto index 3b589db9e2..1b419b00a6 100644 --- a/test/samples/Qwen3DecodeA5/qwen3_decode_incore_10.pto +++ b/test/samples/Qwen3DecodeA5/qwen3_decode_incore_10.pto @@ -33,33 +33,33 @@ module attributes {pto.target_arch = "a5"} { %w_gate__ssa_v0_pview = pto.partition_view %w_gate__ssa_v0_view, offsets = [%c0_index, %23], sizes = [%c128_index, %c256_index] : !pto.tensor_view -> !pto.partition_tensor_view<128x256xbf16> pto.tload ins(%w_gate__ssa_v0_pview : !pto.partition_tensor_view<128x256xbf16>) outs(%wg_0__tile : !pto.tile_buf) %gate_acc__tile_l0_init = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - %gate_acc__tile_l0_a = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%post_chunk_0__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%gate_acc__tile_l0_a : !pto.tile_buf) + %gate_acc__tile_l0_a = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%post_chunk_0__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%gate_acc__tile_l0_a : !pto.tile_buf) %gate_acc__tile_l0_b = pto.alloc_tile addr = %c32768_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%wg_0__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%gate_acc__tile_l0_b : !pto.tile_buf) - %0 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%post_chunk_0__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%0 : !pto.tile_buf) + %0 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%post_chunk_0__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%0 : !pto.tile_buf) %1 = pto.alloc_tile addr = %c0_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%wg_0__tile, %c64_index, %c0_index : !pto.tile_buf, index, index) outs(%1 : !pto.tile_buf) %gate_acc__tile_l0_c_first = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul ins(%gate_acc__tile_l0_a, %gate_acc__tile_l0_b : !pto.tile_buf, !pto.tile_buf) outs(%gate_acc__tile_l0_c_first : !pto.tile_buf) + pto.tmatmul ins(%gate_acc__tile_l0_a, %gate_acc__tile_l0_b : !pto.tile_buf, !pto.tile_buf) outs(%gate_acc__tile_l0_c_first : !pto.tile_buf) %gate_acc__tile_l0_c_acc = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%gate_acc__tile_l0_c_acc, %0, %1 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%gate_acc__tile_l0_c_acc : !pto.tile_buf) + pto.tmatmul.acc ins(%gate_acc__tile_l0_c_acc, %0, %1 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%gate_acc__tile_l0_c_acc : !pto.tile_buf) %wg_1__tile = pto.alloc_tile addr = %c73728_i64 valid_row = %c128_index valid_col = %c256_index : !pto.tile_buf %26 = pto.partition_view %w_gate__ssa_v0_view, offsets = [%c128_index, %23], sizes = [%c128_index, %c256_index] : !pto.tensor_view -> !pto.partition_tensor_view<128x256xbf16> pto.tload ins(%26 : !pto.partition_tensor_view<128x256xbf16>) outs(%wg_1__tile : !pto.tile_buf) - %2 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%post_chunk_1__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%2 : !pto.tile_buf) + %2 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%post_chunk_1__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%2 : !pto.tile_buf) %3 = pto.alloc_tile addr = %c32768_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%wg_1__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%3 : !pto.tile_buf) - %4 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%post_chunk_1__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%4 : !pto.tile_buf) + %4 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%post_chunk_1__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%4 : !pto.tile_buf) %5 = pto.alloc_tile addr = %c0_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%wg_1__tile, %c64_index, %c0_index : !pto.tile_buf, index, index) outs(%5 : !pto.tile_buf) %6 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%6, %2, %3 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%6 : !pto.tile_buf) + pto.tmatmul.acc ins(%6, %2, %3 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%6 : !pto.tile_buf) %7 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%7, %4, %5 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%7 : !pto.tile_buf) + pto.tmatmul.acc ins(%7, %4, %5 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%7 : !pto.tile_buf) scf.for %kb__idx_v0 = %c2_index to %c64_index step %c2_index { %27 = arith.muli %kb__idx_v0, %c128_index : index %28 = arith.muli %kb__idx_v0, %c128_index : index @@ -76,30 +76,30 @@ module attributes {pto.target_arch = "a5"} { %9 = pto.alloc_tile addr = %c0_i64 valid_row = %c128_index valid_col = %c256_index : !pto.tile_buf %33 = pto.partition_view %w_gate__ssa_v0_view, offsets = [%29, %23], sizes = [%c128_index, %c256_index] : !pto.tensor_view -> !pto.partition_tensor_view<128x256xbf16> pto.tload ins(%33 : !pto.partition_tensor_view<128x256xbf16>) outs(%9 : !pto.tile_buf) - %10 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%post_chunk__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%10 : !pto.tile_buf) + %10 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%post_chunk__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%10 : !pto.tile_buf) %11 = pto.alloc_tile addr = %c32768_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%wg__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%11 : !pto.tile_buf) - %12 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%post_chunk__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%12 : !pto.tile_buf) + %12 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%post_chunk__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%12 : !pto.tile_buf) %13 = pto.alloc_tile addr = %c0_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%wg__tile, %c64_index, %c0_index : !pto.tile_buf, index, index) outs(%13 : !pto.tile_buf) %14 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%14, %10, %11 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%14 : !pto.tile_buf) + pto.tmatmul.acc ins(%14, %10, %11 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%14 : !pto.tile_buf) %15 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%15, %12, %13 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%15 : !pto.tile_buf) - %16 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%8, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%16 : !pto.tile_buf) + pto.tmatmul.acc ins(%15, %12, %13 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%15 : !pto.tile_buf) + %16 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%8, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%16 : !pto.tile_buf) %17 = pto.alloc_tile addr = %c32768_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%9, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%17 : !pto.tile_buf) - %18 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%8, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%18 : !pto.tile_buf) + %18 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%8, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%18 : !pto.tile_buf) %19 = pto.alloc_tile addr = %c0_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%9, %c64_index, %c0_index : !pto.tile_buf, index, index) outs(%19 : !pto.tile_buf) %20 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%20, %16, %17 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%20 : !pto.tile_buf) + pto.tmatmul.acc ins(%20, %16, %17 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%20 : !pto.tile_buf) %21 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%21, %18, %19 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%21 : !pto.tile_buf) + pto.tmatmul.acc ins(%21, %18, %19 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%21 : !pto.tile_buf) } %gate_group__ssa_v0_pview = pto.partition_view %gate_group__ssa_v0_view, offsets = [%c0_index, %24], sizes = [%c16_index, %c256_index] : !pto.tensor_view -> !pto.partition_tensor_view<16x256xf32> pto.tstore ins(%7 : !pto.tile_buf) outs(%gate_group__ssa_v0_pview : !pto.partition_tensor_view<16x256xf32>) diff --git a/test/samples/Qwen3DecodeA5/qwen3_decode_incore_11.pto b/test/samples/Qwen3DecodeA5/qwen3_decode_incore_11.pto index 6048201d6e..148552f4ea 100644 --- a/test/samples/Qwen3DecodeA5/qwen3_decode_incore_11.pto +++ b/test/samples/Qwen3DecodeA5/qwen3_decode_incore_11.pto @@ -33,33 +33,33 @@ module attributes {pto.target_arch = "a5"} { %w_up__ssa_v0_pview = pto.partition_view %w_up__ssa_v0_view, offsets = [%c0_index, %23], sizes = [%c128_index, %c256_index] : !pto.tensor_view -> !pto.partition_tensor_view<128x256xbf16> pto.tload ins(%w_up__ssa_v0_pview : !pto.partition_tensor_view<128x256xbf16>) outs(%wu_0__tile : !pto.tile_buf) %up_acc__tile_l0_init = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - %up_acc__tile_l0_a = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%post_chunk_0__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%up_acc__tile_l0_a : !pto.tile_buf) + %up_acc__tile_l0_a = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%post_chunk_0__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%up_acc__tile_l0_a : !pto.tile_buf) %up_acc__tile_l0_b = pto.alloc_tile addr = %c32768_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%wu_0__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%up_acc__tile_l0_b : !pto.tile_buf) - %0 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%post_chunk_0__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%0 : !pto.tile_buf) + %0 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%post_chunk_0__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%0 : !pto.tile_buf) %1 = pto.alloc_tile addr = %c0_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%wu_0__tile, %c64_index, %c0_index : !pto.tile_buf, index, index) outs(%1 : !pto.tile_buf) %up_acc__tile_l0_c_first = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul ins(%up_acc__tile_l0_a, %up_acc__tile_l0_b : !pto.tile_buf, !pto.tile_buf) outs(%up_acc__tile_l0_c_first : !pto.tile_buf) + pto.tmatmul ins(%up_acc__tile_l0_a, %up_acc__tile_l0_b : !pto.tile_buf, !pto.tile_buf) outs(%up_acc__tile_l0_c_first : !pto.tile_buf) %up_acc__tile_l0_c_acc = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%up_acc__tile_l0_c_acc, %0, %1 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%up_acc__tile_l0_c_acc : !pto.tile_buf) + pto.tmatmul.acc ins(%up_acc__tile_l0_c_acc, %0, %1 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%up_acc__tile_l0_c_acc : !pto.tile_buf) %wu_1__tile = pto.alloc_tile addr = %c73728_i64 valid_row = %c128_index valid_col = %c256_index : !pto.tile_buf %26 = pto.partition_view %w_up__ssa_v0_view, offsets = [%c128_index, %23], sizes = [%c128_index, %c256_index] : !pto.tensor_view -> !pto.partition_tensor_view<128x256xbf16> pto.tload ins(%26 : !pto.partition_tensor_view<128x256xbf16>) outs(%wu_1__tile : !pto.tile_buf) - %2 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%post_chunk_1__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%2 : !pto.tile_buf) + %2 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%post_chunk_1__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%2 : !pto.tile_buf) %3 = pto.alloc_tile addr = %c32768_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%wu_1__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%3 : !pto.tile_buf) - %4 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%post_chunk_1__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%4 : !pto.tile_buf) + %4 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%post_chunk_1__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%4 : !pto.tile_buf) %5 = pto.alloc_tile addr = %c0_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%wu_1__tile, %c64_index, %c0_index : !pto.tile_buf, index, index) outs(%5 : !pto.tile_buf) %6 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%6, %2, %3 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%6 : !pto.tile_buf) + pto.tmatmul.acc ins(%6, %2, %3 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%6 : !pto.tile_buf) %7 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%7, %4, %5 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%7 : !pto.tile_buf) + pto.tmatmul.acc ins(%7, %4, %5 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%7 : !pto.tile_buf) scf.for %kb__idx_v0 = %c2_index to %c64_index step %c2_index { %27 = arith.muli %kb__idx_v0, %c128_index : index %28 = arith.muli %kb__idx_v0, %c128_index : index @@ -76,30 +76,30 @@ module attributes {pto.target_arch = "a5"} { %9 = pto.alloc_tile addr = %c0_i64 valid_row = %c128_index valid_col = %c256_index : !pto.tile_buf %33 = pto.partition_view %w_up__ssa_v0_view, offsets = [%29, %23], sizes = [%c128_index, %c256_index] : !pto.tensor_view -> !pto.partition_tensor_view<128x256xbf16> pto.tload ins(%33 : !pto.partition_tensor_view<128x256xbf16>) outs(%9 : !pto.tile_buf) - %10 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%post_chunk__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%10 : !pto.tile_buf) + %10 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%post_chunk__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%10 : !pto.tile_buf) %11 = pto.alloc_tile addr = %c32768_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%wu__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%11 : !pto.tile_buf) - %12 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%post_chunk__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%12 : !pto.tile_buf) + %12 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%post_chunk__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%12 : !pto.tile_buf) %13 = pto.alloc_tile addr = %c0_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%wu__tile, %c64_index, %c0_index : !pto.tile_buf, index, index) outs(%13 : !pto.tile_buf) %14 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%14, %10, %11 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%14 : !pto.tile_buf) + pto.tmatmul.acc ins(%14, %10, %11 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%14 : !pto.tile_buf) %15 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%15, %12, %13 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%15 : !pto.tile_buf) - %16 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%8, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%16 : !pto.tile_buf) + pto.tmatmul.acc ins(%15, %12, %13 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%15 : !pto.tile_buf) + %16 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%8, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%16 : !pto.tile_buf) %17 = pto.alloc_tile addr = %c32768_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%9, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%17 : !pto.tile_buf) - %18 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%8, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%18 : !pto.tile_buf) + %18 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%8, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%18 : !pto.tile_buf) %19 = pto.alloc_tile addr = %c0_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%9, %c64_index, %c0_index : !pto.tile_buf, index, index) outs(%19 : !pto.tile_buf) %20 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%20, %16, %17 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%20 : !pto.tile_buf) + pto.tmatmul.acc ins(%20, %16, %17 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%20 : !pto.tile_buf) %21 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%21, %18, %19 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%21 : !pto.tile_buf) + pto.tmatmul.acc ins(%21, %18, %19 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%21 : !pto.tile_buf) } %up_group__ssa_v0_pview = pto.partition_view %up_group__ssa_v0_view, offsets = [%c0_index, %24], sizes = [%c16_index, %c256_index] : !pto.tensor_view -> !pto.partition_tensor_view<16x256xf32> pto.tstore ins(%7 : !pto.tile_buf) outs(%up_group__ssa_v0_pview : !pto.partition_tensor_view<16x256xf32>) diff --git a/test/samples/Qwen3DecodeA5/qwen3_decode_incore_2.pto b/test/samples/Qwen3DecodeA5/qwen3_decode_incore_2.pto index 6b12c37e9c..642f1e5776 100644 --- a/test/samples/Qwen3DecodeA5/qwen3_decode_incore_2.pto +++ b/test/samples/Qwen3DecodeA5/qwen3_decode_incore_2.pto @@ -54,81 +54,81 @@ module attributes {pto.target_arch = "a5"} { %38 = arith.cmpi eq, %32, %c0_index : index scf.if %38 { %k_acc__tile_l0_init = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - %k_acc__tile_l0_a = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%tile_a_i__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%k_acc__tile_l0_a : !pto.tile_buf) + %k_acc__tile_l0_a = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%tile_a_i__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%k_acc__tile_l0_a : !pto.tile_buf) %k_acc__tile_l0_b = pto.alloc_tile addr = %c0_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%tile_wk_i__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%k_acc__tile_l0_b : !pto.tile_buf) - %3 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%tile_a_i__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%3 : !pto.tile_buf) + %3 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%tile_a_i__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%3 : !pto.tile_buf) %4 = pto.alloc_tile addr = %c32768_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%tile_wk_i__tile, %c64_index, %c0_index : !pto.tile_buf, index, index) outs(%4 : !pto.tile_buf) %k_acc__tile_l0_c_first = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul ins(%k_acc__tile_l0_a, %k_acc__tile_l0_b : !pto.tile_buf, !pto.tile_buf) outs(%k_acc__tile_l0_c_first : !pto.tile_buf) + pto.tmatmul ins(%k_acc__tile_l0_a, %k_acc__tile_l0_b : !pto.tile_buf, !pto.tile_buf) outs(%k_acc__tile_l0_c_first : !pto.tile_buf) %k_acc__tile_l0_c_acc = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%k_acc__tile_l0_c_acc, %3, %4 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%k_acc__tile_l0_c_acc : !pto.tile_buf) + pto.tmatmul.acc ins(%k_acc__tile_l0_c_acc, %3, %4 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%k_acc__tile_l0_c_acc : !pto.tile_buf) %v_acc__tile_l0_init = pto.alloc_tile addr = %c16384_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - %v_acc__tile_l0_a = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%tile_a_i__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%v_acc__tile_l0_a : !pto.tile_buf) + %v_acc__tile_l0_a = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%tile_a_i__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%v_acc__tile_l0_a : !pto.tile_buf) %v_acc__tile_l0_b = pto.alloc_tile addr = %c0_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%tile_wv_i__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%v_acc__tile_l0_b : !pto.tile_buf) - %5 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%tile_a_i__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%5 : !pto.tile_buf) + %5 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%tile_a_i__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%5 : !pto.tile_buf) %6 = pto.alloc_tile addr = %c32768_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%tile_wv_i__tile, %c64_index, %c0_index : !pto.tile_buf, index, index) outs(%6 : !pto.tile_buf) %v_acc__tile_l0_c_first = pto.alloc_tile addr = %c16384_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul ins(%v_acc__tile_l0_a, %v_acc__tile_l0_b : !pto.tile_buf, !pto.tile_buf) outs(%v_acc__tile_l0_c_first : !pto.tile_buf) + pto.tmatmul ins(%v_acc__tile_l0_a, %v_acc__tile_l0_b : !pto.tile_buf, !pto.tile_buf) outs(%v_acc__tile_l0_c_first : !pto.tile_buf) %v_acc__tile_l0_c_acc = pto.alloc_tile addr = %c16384_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%v_acc__tile_l0_c_acc, %5, %6 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%v_acc__tile_l0_c_acc : !pto.tile_buf) + pto.tmatmul.acc ins(%v_acc__tile_l0_c_acc, %5, %6 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%v_acc__tile_l0_c_acc : !pto.tile_buf) } else { - %7 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%tile_a_i__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%7 : !pto.tile_buf) + %7 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%tile_a_i__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%7 : !pto.tile_buf) %8 = pto.alloc_tile addr = %c0_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%tile_wk_i__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%8 : !pto.tile_buf) - %9 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%tile_a_i__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%9 : !pto.tile_buf) + %9 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%tile_a_i__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%9 : !pto.tile_buf) %10 = pto.alloc_tile addr = %c32768_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%tile_wk_i__tile, %c64_index, %c0_index : !pto.tile_buf, index, index) outs(%10 : !pto.tile_buf) %11 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%11, %7, %8 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%11 : !pto.tile_buf) + pto.tmatmul.acc ins(%11, %7, %8 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%11 : !pto.tile_buf) %12 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%12, %9, %10 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%12 : !pto.tile_buf) - %13 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%tile_a_i__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%13 : !pto.tile_buf) + pto.tmatmul.acc ins(%12, %9, %10 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%12 : !pto.tile_buf) + %13 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%tile_a_i__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%13 : !pto.tile_buf) %14 = pto.alloc_tile addr = %c0_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%tile_wv_i__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%14 : !pto.tile_buf) - %15 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%tile_a_i__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%15 : !pto.tile_buf) + %15 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%tile_a_i__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%15 : !pto.tile_buf) %16 = pto.alloc_tile addr = %c32768_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%tile_wv_i__tile, %c64_index, %c0_index : !pto.tile_buf, index, index) outs(%16 : !pto.tile_buf) %17 = pto.alloc_tile addr = %c16384_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%17, %13, %14 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%17 : !pto.tile_buf) + pto.tmatmul.acc ins(%17, %13, %14 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%17 : !pto.tile_buf) %18 = pto.alloc_tile addr = %c16384_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%18, %15, %16 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%18 : !pto.tile_buf) + pto.tmatmul.acc ins(%18, %15, %16 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%18 : !pto.tile_buf) } - %19 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%0, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%19 : !pto.tile_buf) + %19 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%0, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%19 : !pto.tile_buf) %20 = pto.alloc_tile addr = %c0_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%1, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%20 : !pto.tile_buf) - %21 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%0, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%21 : !pto.tile_buf) + %21 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%0, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%21 : !pto.tile_buf) %22 = pto.alloc_tile addr = %c32768_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%1, %c64_index, %c0_index : !pto.tile_buf, index, index) outs(%22 : !pto.tile_buf) %23 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%23, %19, %20 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%23 : !pto.tile_buf) + pto.tmatmul.acc ins(%23, %19, %20 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%23 : !pto.tile_buf) %24 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%24, %21, %22 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%24 : !pto.tile_buf) - %25 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%0, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%25 : !pto.tile_buf) + pto.tmatmul.acc ins(%24, %21, %22 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%24 : !pto.tile_buf) + %25 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%0, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%25 : !pto.tile_buf) %26 = pto.alloc_tile addr = %c0_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%2, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%26 : !pto.tile_buf) - %27 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%0, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%27 : !pto.tile_buf) + %27 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%0, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%27 : !pto.tile_buf) %28 = pto.alloc_tile addr = %c32768_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%2, %c64_index, %c0_index : !pto.tile_buf, index, index) outs(%28 : !pto.tile_buf) %29 = pto.alloc_tile addr = %c16384_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%29, %25, %26 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%29 : !pto.tile_buf) + pto.tmatmul.acc ins(%29, %25, %26 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%29 : !pto.tile_buf) %30 = pto.alloc_tile addr = %c16384_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%30, %27, %28 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%30 : !pto.tile_buf) + pto.tmatmul.acc ins(%30, %27, %28 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%30 : !pto.tile_buf) } %k_proj__ssa_v0_pview = pto.partition_view %k_proj__ssa_v0_view, offsets = [%c0_index, %31], sizes = [%c16_index, %c256_index] : !pto.tensor_view -> !pto.partition_tensor_view<16x256xf32> pto.tstore ins(%k_acc__tile : !pto.tile_buf) outs(%k_proj__ssa_v0_pview : !pto.partition_tensor_view<16x256xf32>) diff --git a/test/samples/Qwen3DecodeA5/qwen3_decode_incore_4.pto b/test/samples/Qwen3DecodeA5/qwen3_decode_incore_4.pto index 97ee54d918..1ce3839319 100644 --- a/test/samples/Qwen3DecodeA5/qwen3_decode_incore_4.pto +++ b/test/samples/Qwen3DecodeA5/qwen3_decode_incore_4.pto @@ -46,18 +46,18 @@ module attributes {pto.target_arch = "a5"} { %k_cache__rv_v4_dn_view_pview = pto.partition_view %k_cache__rv_v4_dn_view, offsets = [%c0_index, %18], sizes = [%c128_index, %c256_index] : !pto.tensor_view -> !pto.partition_tensor_view<128x256xbf16> pto.tload ins(%k_cache__rv_v4_dn_view_pview : !pto.partition_tensor_view<128x256xbf16>) outs(%k_tile_0__tile : !pto.tile_buf) %raw_scores_0__tile_l0_init = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - %raw_scores_0__tile_l0_a = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%q_padded0__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%raw_scores_0__tile_l0_a : !pto.tile_buf) + %raw_scores_0__tile_l0_a = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%q_padded0__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%raw_scores_0__tile_l0_a : !pto.tile_buf) %raw_scores_0__tile_l0_b = pto.alloc_tile addr = %c32768_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%k_tile_0__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%raw_scores_0__tile_l0_b : !pto.tile_buf) - %0 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%q_padded0__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%0 : !pto.tile_buf) + %0 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%q_padded0__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%0 : !pto.tile_buf) %1 = pto.alloc_tile addr = %c0_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%k_tile_0__tile, %c64_index, %c0_index : !pto.tile_buf, index, index) outs(%1 : !pto.tile_buf) %raw_scores_0__tile_l0_c_first = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul ins(%raw_scores_0__tile_l0_a, %raw_scores_0__tile_l0_b : !pto.tile_buf, !pto.tile_buf) outs(%raw_scores_0__tile_l0_c_first : !pto.tile_buf) + pto.tmatmul ins(%raw_scores_0__tile_l0_a, %raw_scores_0__tile_l0_b : !pto.tile_buf, !pto.tile_buf) outs(%raw_scores_0__tile_l0_c_first : !pto.tile_buf) %raw_scores_0__tile_l0_c_acc = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%raw_scores_0__tile_l0_c_acc, %0, %1 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%raw_scores_0__tile_l0_c_acc : !pto.tile_buf) + pto.tmatmul.acc ins(%raw_scores_0__tile_l0_c_acc, %0, %1 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%raw_scores_0__tile_l0_c_acc : !pto.tile_buf) %19 = arith.muli %4, %c256_index : index %20 = arith.muli %sb__idx_v0, %c16_index : index %21 = arith.addi %19, %20 : index @@ -71,18 +71,18 @@ module attributes {pto.target_arch = "a5"} { %26 = pto.partition_view %k_cache__rv_v4_dn_view, offsets = [%c0_index, %25], sizes = [%c128_index, %c256_index] : !pto.tensor_view -> !pto.partition_tensor_view<128x256xbf16> pto.tload ins(%26 : !pto.partition_tensor_view<128x256xbf16>) outs(%k_tile_1__tile : !pto.tile_buf) %raw_scores_1__tile_l0_init = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - %raw_scores_1__tile_l0_a = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%q_padded1__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%raw_scores_1__tile_l0_a : !pto.tile_buf) + %raw_scores_1__tile_l0_a = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%q_padded1__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%raw_scores_1__tile_l0_a : !pto.tile_buf) %raw_scores_1__tile_l0_b = pto.alloc_tile addr = %c32768_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%k_tile_1__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%raw_scores_1__tile_l0_b : !pto.tile_buf) - %2 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%q_padded1__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%2 : !pto.tile_buf) + %2 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%q_padded1__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%2 : !pto.tile_buf) %3 = pto.alloc_tile addr = %c0_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%k_tile_1__tile, %c64_index, %c0_index : !pto.tile_buf, index, index) outs(%3 : !pto.tile_buf) %raw_scores_1__tile_l0_c_first = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul ins(%raw_scores_1__tile_l0_a, %raw_scores_1__tile_l0_b : !pto.tile_buf, !pto.tile_buf) outs(%raw_scores_1__tile_l0_c_first : !pto.tile_buf) + pto.tmatmul ins(%raw_scores_1__tile_l0_a, %raw_scores_1__tile_l0_b : !pto.tile_buf, !pto.tile_buf) outs(%raw_scores_1__tile_l0_c_first : !pto.tile_buf) %raw_scores_1__tile_l0_c_acc = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%raw_scores_1__tile_l0_c_acc, %2, %3 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%raw_scores_1__tile_l0_c_acc : !pto.tile_buf) + pto.tmatmul.acc ins(%raw_scores_1__tile_l0_c_acc, %2, %3 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%raw_scores_1__tile_l0_c_acc : !pto.tile_buf) %28 = arith.muli %6, %c256_index : index %29 = arith.muli %sb__idx_v0, %c16_index : index %30 = arith.addi %28, %29 : index diff --git a/test/samples/Qwen3DecodeA5/qwen3_decode_incore_6.pto b/test/samples/Qwen3DecodeA5/qwen3_decode_incore_6.pto index c8ffed3cde..53617d3d46 100644 --- a/test/samples/Qwen3DecodeA5/qwen3_decode_incore_6.pto +++ b/test/samples/Qwen3DecodeA5/qwen3_decode_incore_6.pto @@ -37,18 +37,18 @@ module attributes {pto.target_arch = "a5"} { %v_cache__rv_v4_pview = pto.partition_view %v_cache__rv_v4_view, offsets = [%11, %c0_index], sizes = [%c256_index, %c128_index] : !pto.tensor_view -> !pto.partition_tensor_view<256x128xbf16> pto.tload ins(%v_cache__rv_v4_pview : !pto.partition_tensor_view<256x128xbf16>) outs(%v_tile_0__tile : !pto.tile_buf) %oi_tmp_0__tile_l0_init = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c128_index : !pto.tile_buf - %oi_tmp_0__tile_l0_a = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c128_index : !pto.tile_buf - pto.textract ins(%exp_tile_0__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%oi_tmp_0__tile_l0_a : !pto.tile_buf) + %oi_tmp_0__tile_l0_a = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c128_index : !pto.tile_buf + pto.textract ins(%exp_tile_0__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%oi_tmp_0__tile_l0_a : !pto.tile_buf) %oi_tmp_0__tile_l0_b = pto.alloc_tile addr = %c0_i64 valid_row = %c128_index valid_col = %c128_index : !pto.tile_buf pto.textract ins(%v_tile_0__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%oi_tmp_0__tile_l0_b : !pto.tile_buf) - %0 = pto.alloc_tile addr = %c4096_i64 valid_row = %c16_index valid_col = %c128_index : !pto.tile_buf - pto.textract ins(%exp_tile_0__tile, %c0_index, %c128_index : !pto.tile_buf, index, index) outs(%0 : !pto.tile_buf) + %0 = pto.alloc_tile addr = %c4096_i64 valid_row = %c16_index valid_col = %c128_index : !pto.tile_buf + pto.textract ins(%exp_tile_0__tile, %c0_index, %c128_index : !pto.tile_buf, index, index) outs(%0 : !pto.tile_buf) %1 = pto.alloc_tile addr = %c32768_i64 valid_row = %c128_index valid_col = %c128_index : !pto.tile_buf pto.textract ins(%v_tile_0__tile, %c128_index, %c0_index : !pto.tile_buf, index, index) outs(%1 : !pto.tile_buf) %oi_tmp_0__tile_l0_c_first = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c128_index : !pto.tile_buf - pto.tmatmul ins(%oi_tmp_0__tile_l0_a, %oi_tmp_0__tile_l0_b : !pto.tile_buf, !pto.tile_buf) outs(%oi_tmp_0__tile_l0_c_first : !pto.tile_buf) + pto.tmatmul ins(%oi_tmp_0__tile_l0_a, %oi_tmp_0__tile_l0_b : !pto.tile_buf, !pto.tile_buf) outs(%oi_tmp_0__tile_l0_c_first : !pto.tile_buf) %oi_tmp_0__tile_l0_c_acc = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c128_index : !pto.tile_buf - pto.tmatmul.acc ins(%oi_tmp_0__tile_l0_c_acc, %0, %1 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%oi_tmp_0__tile_l0_c_acc : !pto.tile_buf) + pto.tmatmul.acc ins(%oi_tmp_0__tile_l0_c_acc, %0, %1 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%oi_tmp_0__tile_l0_c_acc : !pto.tile_buf) %15 = arith.muli %4, %c256_index : index %16 = arith.muli %sb__idx_v0, %c16_index : index %17 = arith.addi %15, %16 : index @@ -68,18 +68,18 @@ module attributes {pto.target_arch = "a5"} { %26 = pto.partition_view %v_cache__rv_v4_view, offsets = [%21, %c0_index], sizes = [%c256_index, %c128_index] : !pto.tensor_view -> !pto.partition_tensor_view<256x128xbf16> pto.tload ins(%26 : !pto.partition_tensor_view<256x128xbf16>) outs(%v_tile_1__tile : !pto.tile_buf) %oi_tmp_1__tile_l0_init = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c128_index : !pto.tile_buf - %oi_tmp_1__tile_l0_a = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c128_index : !pto.tile_buf - pto.textract ins(%exp_tile_1__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%oi_tmp_1__tile_l0_a : !pto.tile_buf) + %oi_tmp_1__tile_l0_a = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c128_index : !pto.tile_buf + pto.textract ins(%exp_tile_1__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%oi_tmp_1__tile_l0_a : !pto.tile_buf) %oi_tmp_1__tile_l0_b = pto.alloc_tile addr = %c0_i64 valid_row = %c128_index valid_col = %c128_index : !pto.tile_buf pto.textract ins(%v_tile_1__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%oi_tmp_1__tile_l0_b : !pto.tile_buf) - %2 = pto.alloc_tile addr = %c4096_i64 valid_row = %c16_index valid_col = %c128_index : !pto.tile_buf - pto.textract ins(%exp_tile_1__tile, %c0_index, %c128_index : !pto.tile_buf, index, index) outs(%2 : !pto.tile_buf) + %2 = pto.alloc_tile addr = %c4096_i64 valid_row = %c16_index valid_col = %c128_index : !pto.tile_buf + pto.textract ins(%exp_tile_1__tile, %c0_index, %c128_index : !pto.tile_buf, index, index) outs(%2 : !pto.tile_buf) %3 = pto.alloc_tile addr = %c32768_i64 valid_row = %c128_index valid_col = %c128_index : !pto.tile_buf pto.textract ins(%v_tile_1__tile, %c128_index, %c0_index : !pto.tile_buf, index, index) outs(%3 : !pto.tile_buf) %oi_tmp_1__tile_l0_c_first = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c128_index : !pto.tile_buf - pto.tmatmul ins(%oi_tmp_1__tile_l0_a, %oi_tmp_1__tile_l0_b : !pto.tile_buf, !pto.tile_buf) outs(%oi_tmp_1__tile_l0_c_first : !pto.tile_buf) + pto.tmatmul ins(%oi_tmp_1__tile_l0_a, %oi_tmp_1__tile_l0_b : !pto.tile_buf, !pto.tile_buf) outs(%oi_tmp_1__tile_l0_c_first : !pto.tile_buf) %oi_tmp_1__tile_l0_c_acc = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c128_index : !pto.tile_buf - pto.tmatmul.acc ins(%oi_tmp_1__tile_l0_c_acc, %2, %3 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%oi_tmp_1__tile_l0_c_acc : !pto.tile_buf) + pto.tmatmul.acc ins(%oi_tmp_1__tile_l0_c_acc, %2, %3 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%oi_tmp_1__tile_l0_c_acc : !pto.tile_buf) %28 = arith.muli %6, %c256_index : index %29 = arith.muli %sb__idx_v0, %c16_index : index %30 = arith.addi %28, %29 : index diff --git a/tilelang-dsl/python/tilelang_dsl/kernel.py b/tilelang-dsl/python/tilelang_dsl/kernel.py index 09a599ad7f..b56a03fdb9 100644 --- a/tilelang-dsl/python/tilelang_dsl/kernel.py +++ b/tilelang-dsl/python/tilelang_dsl/kernel.py @@ -735,7 +735,26 @@ def _load_function_source_info(py_fn: Callable[..., Any]) -> _FunctionSourceInfo return None source = textwrap.dedent("".join(source_lines)) - module = ast.parse(source) + try: + module = ast.parse(source) + except SyntaxError: + try: + module = ast.parse(Path(path).read_text(encoding="utf-8")) + except (OSError, IOError, SyntaxError): + return None + code_line = getattr(py_fn, "__code__", None) + code_first_line = code_line.co_firstlineno if code_line is not None else start_line + for node in ast.walk(module): + if not isinstance(node, ast.FunctionDef) or node.name != py_fn.__name__: + continue + first_line = min( + [node.lineno] + [decorator.lineno for decorator in node.decorator_list] + ) + last_line = getattr(node, "end_lineno", node.lineno) + if first_line <= code_first_line <= last_line: + return _FunctionSourceInfo(path=path, start_line=1, function_def=node) + return None + for node in module.body: if isinstance(node, ast.FunctionDef) and node.name == py_fn.__name__: return _FunctionSourceInfo(path=path, start_line=start_line, function_def=node) diff --git a/tools/ptoas/ObjectEmission.cpp b/tools/ptoas/ObjectEmission.cpp index 09f7604aa5..cd85645cd6 100644 --- a/tools/ptoas/ObjectEmission.cpp +++ b/tools/ptoas/ObjectEmission.cpp @@ -63,7 +63,7 @@ static bool writeTextFile(StringRef path, StringRef content, static void stripUnsupportedBishengAttrs(llvm::Module &module) { for (llvm::Function &function : module) { - // LLVM 19 prints memory effect attributes in textual form like + // LLVM prints memory effect attributes in textual form like // `memory(none)`. beta.1 Bisheng cannot parse that syntax, so remove only // the unsupported memory-effect attribute before serializing the module. function.setAttributes( diff --git a/tools/ptobc/src/mlir_encode.cpp b/tools/ptobc/src/mlir_encode.cpp index bfab87be73..87bcd5f0f0 100644 --- a/tools/ptobc/src/mlir_encode.cpp +++ b/tools/ptobc/src/mlir_encode.cpp @@ -121,6 +121,35 @@ static mlir::DictionaryAttr stripKnownImmediateAttrs( } } +static bool hasDefaultZeroPipeId(llvm::StringRef opName) { + return llvm::StringSwitch(opName) + .Case("pto.aic_initialize_pipe", true) + .Case("pto.aiv_initialize_pipe", true) + .Case("pto.talloc_to_aiv", true) + .Case("pto.talloc_to_aic", true) + .Case("pto.tpush_to_aiv", true) + .Case("pto.tpush_to_aic", true) + .Case("pto.tpop_from_aiv", true) + .Case("pto.tpop_from_aic", true) + .Case("pto.tfree_from_aiv", true) + .Case("pto.tfree_from_aic", true) + .Default(false); +} + +static mlir::DictionaryAttr stripDefaultZeroPipeId(mlir::MLIRContext *ctx, + mlir::DictionaryAttr dict, + mlir::Operation &op) { + if (!dict || dict.empty() || + !hasDefaultZeroPipeId(op.getName().getStringRef())) + return dict; + + auto idAttr = op.getAttrOfType("id"); + if (!idAttr || idAttr.getInt() != 0) + return dict; + + return stripAttrs(ctx, dict, {"id"}); +} + static uint64_t internAttr(PTOBCFile& f, mlir::DictionaryAttr dict) { if (!dict || dict.empty()) return 0; std::string s = printAttrDict(dict); @@ -468,6 +497,8 @@ void Encoder::encodeKnownOpImmediates( mask |= 0x1; if (at.getValidCol()) mask |= 0x2; + if (at.getAddr()) + mask |= 0x4; out.appendU8(mask); imms.push_back(mask); return; @@ -540,7 +571,8 @@ void Encoder::encodeKnownOpOperands( if (imms.empty()) throw std::runtime_error("optmask operands missing immediate"); uint8_t mask = uint8_t(imms.front()); - emitOperands(((mask & 0x1) ? 1 : 0) + ((mask & 0x2) ? 1 : 0)); + emitOperands(((mask & 0x1) ? 1 : 0) + ((mask & 0x2) ? 1 : 0) + + ((mask & 0x4) ? 1 : 0)); return; } default: @@ -557,6 +589,7 @@ void Encoder::encodeKnownOp(mlir::Operation &op, Buffer &out, out.appendU16LE(variantInfo.opcode); mlir::DictionaryAttr dict = op.getAttrDictionary(); dict = stripKnownImmediateAttrs(op.getContext(), dict, info); + dict = stripDefaultZeroPipeId(op.getContext(), dict, op); writeULEB128(internAttr(file, dict), out.bytes); if (info.has_variant_u8) diff --git a/tools/ptobc/src/ptobc_decode_print.cpp b/tools/ptobc/src/ptobc_decode_print.cpp index 3c13ae31b9..cdb0280902 100644 --- a/tools/ptobc/src/ptobc_decode_print.cpp +++ b/tools/ptobc/src/ptobc_decode_print.cpp @@ -487,7 +487,8 @@ readKnownOperandIds(BuildCtx &bc, Reader &r, uint16_t opcode, uint8_t variant, case 0x04: return reorderLegacyIndexedTscatter( readValueIds(r, ((imms.optMask & 0x1) ? 1 : 0) + - ((imms.optMask & 0x2) ? 1 : 0))); + ((imms.optMask & 0x2) ? 1 : 0) + + ((imms.optMask & 0x4) ? 1 : 0))); default: (void)bc; throw std::runtime_error("unknown operand_mode"); diff --git a/tools/ptobc/testdata/recent_mx_ops_v0_roundtrip.pto b/tools/ptobc/testdata/recent_mx_ops_v0_roundtrip.pto new file mode 100644 index 0000000000..51b4ec2653 --- /dev/null +++ b/tools/ptobc/testdata/recent_mx_ops_v0_roundtrip.pto @@ -0,0 +1,26 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5"} { + func.func @recent_mx_ops_v0() { + %a = pto.alloc_tile : !pto.tile_buf + %a_scale = pto.alloc_tile : !pto.tile_buf + %b = pto.alloc_tile : !pto.tile_buf + %b_scale = pto.alloc_tile : !pto.tile_buf + %c_in = pto.alloc_tile : !pto.tile_buf + %bias = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + %dst_acc = pto.alloc_tile : !pto.tile_buf + %dst_bias = pto.alloc_tile : !pto.tile_buf + + pto.tgemv.mx ins(%a, %a_scale, %b, %b_scale : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) + pto.tgemv.mx.acc ins(%c_in, %a, %a_scale, %b, %b_scale : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst_acc : !pto.tile_buf) + pto.tgemv.mx.bias ins(%a, %a_scale, %b, %b_scale, %bias : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst_bias : !pto.tile_buf) + return + } +} diff --git a/tools/ptobc/testdata/recent_ops_v0_roundtrip.pto b/tools/ptobc/testdata/recent_ops_v0_roundtrip.pto index ecbf78fa89..af4254b1ff 100644 --- a/tools/ptobc/testdata/recent_ops_v0_roundtrip.pto +++ b/tools/ptobc/testdata/recent_ops_v0_roundtrip.pto @@ -42,18 +42,6 @@ module { pto.trsqrt ins(%src : !pto.tile_buf) outs(%rs_dst0 : !pto.tile_buf) pto.trsqrt ins(%src, %rs_tmp : !pto.tile_buf, !pto.tile_buf) outs(%rs_dst1 : !pto.tile_buf) pto.tpartmul ins(%part0, %part1 : !pto.tile_buf, !pto.tile_buf) outs(%partdst : !pto.tile_buf) - %a = pto.alloc_tile : !pto.tile_buf - %a_scale = pto.alloc_tile : !pto.tile_buf - %b = pto.alloc_tile : !pto.tile_buf - %b_scale = pto.alloc_tile : !pto.tile_buf - %c_in = pto.alloc_tile : !pto.tile_buf - %bias_mx = pto.alloc_tile : !pto.tile_buf - %dst_mx = pto.alloc_tile : !pto.tile_buf - %dst_mx_acc = pto.alloc_tile : !pto.tile_buf - %dst_mx_bias = pto.alloc_tile : !pto.tile_buf - pto.tgemv.mx ins(%a, %a_scale, %b, %b_scale : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst_mx : !pto.tile_buf) - pto.tgemv.mx.acc ins(%c_in, %a, %a_scale, %b, %b_scale : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst_mx_acc : !pto.tile_buf) - pto.tgemv.mx.bias ins(%a, %a_scale, %b, %b_scale, %bias_mx : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst_mx_bias : !pto.tile_buf) return } } diff --git a/tools/ptobc/tests/recent_ops_v0_encode.sh b/tools/ptobc/tests/recent_ops_v0_encode.sh index b8c302bf53..8517b4174b 100755 --- a/tools/ptobc/tests/recent_ops_v0_encode.sh +++ b/tools/ptobc/tests/recent_ops_v0_encode.sh @@ -22,14 +22,19 @@ if [[ -z "${TESTDATA_DIR}" ]]; then fi IN="${TESTDATA_DIR}/recent_ops_v0_roundtrip.pto" +MX_IN="${TESTDATA_DIR}/recent_mx_ops_v0_roundtrip.pto" OUT_DIR=${OUT_DIR:-"${PWD}/ptobc_recent_ops_out"} mkdir -p "${OUT_DIR}" BC="${OUT_DIR}/recent_ops_v0_roundtrip.ptobc" ROUNDTRIP="${OUT_DIR}/recent_ops_v0_roundtrip.roundtrip.pto" +MX_BC="${OUT_DIR}/recent_mx_ops_v0_roundtrip.ptobc" +MX_ROUNDTRIP="${OUT_DIR}/recent_mx_ops_v0_roundtrip.roundtrip.pto" "${PTOBC_BIN}" encode "${IN}" -o "${BC}" "${PTOBC_BIN}" decode "${BC}" -o "${ROUNDTRIP}" +"${PTOBC_BIN}" encode "${MX_IN}" -o "${MX_BC}" +"${PTOBC_BIN}" decode "${MX_BC}" -o "${MX_ROUNDTRIP}" grep -F "pto.subview " "${ROUNDTRIP}" >/dev/null grep -F "pto.tprint ins(" "${ROUNDTRIP}" >/dev/null @@ -41,6 +46,6 @@ grep -F "pto.trowexpandmul ins(" "${ROUNDTRIP}" >/dev/null grep -F "pto.trsqrt ins(" "${ROUNDTRIP}" >/dev/null grep -E "pto\\.trsqrt ins\\(%[^,]+, %[^:]+ :" "${ROUNDTRIP}" >/dev/null grep -F "pto.tpartmul ins(" "${ROUNDTRIP}" >/dev/null -grep -F "pto.tgemv.mx ins(" "${ROUNDTRIP}" >/dev/null -grep -F "pto.tgemv.mx.acc ins(" "${ROUNDTRIP}" >/dev/null -grep -F "pto.tgemv.mx.bias ins(" "${ROUNDTRIP}" >/dev/null +grep -F "pto.tgemv.mx ins(" "${MX_ROUNDTRIP}" >/dev/null +grep -F "pto.tgemv.mx.acc ins(" "${MX_ROUNDTRIP}" >/dev/null +grep -F "pto.tgemv.mx.bias ins(" "${MX_ROUNDTRIP}" >/dev/null From 0655206d6821da3679367560dbc4a7935e0c6c31 Mon Sep 17 00:00:00 2001 From: TaoTao-real Date: Thu, 11 Jun 2026 14:57:07 +0800 Subject: [PATCH 02/51] fix: complete LLVM 21 VPTO type migration --- lib/PTO/IR/VPTO.cpp | 5 +- lib/PTO/Transforms/ExpandTileOp.cpp | 5 +- lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp | 47 ++++++++++--------- lib/PTO/Transforms/VPTOLLVMEmitter.cpp | 47 ++++++++++--------- 4 files changed, 53 insertions(+), 51 deletions(-) diff --git a/lib/PTO/IR/VPTO.cpp b/lib/PTO/IR/VPTO.cpp index 6c30ccca71..908c604a16 100644 --- a/lib/PTO/IR/VPTO.cpp +++ b/lib/PTO/IR/VPTO.cpp @@ -1420,10 +1420,9 @@ static VcvtElemKind classifyVcvtElemType(Type type) { return VcvtElemKind::BF16; if (type.isF32()) return VcvtElemKind::F32; - if (type.isFloat8E4M3() || type.isFloat8E4M3FN() || - type.isFloat8E4M3FNUZ() || type.isFloat8E4M3B11FNUZ()) + if (pto::isPTOFloat8E4M3LikeType(type)) return VcvtElemKind::F8E4M3; - if (type.isFloat8E5M2() || type.isFloat8E5M2FNUZ()) + if (pto::isPTOFloat8E5M2LikeType(type)) return VcvtElemKind::F8E5M2; if (pto::isPTOHiFloat8Type(type)) return VcvtElemKind::HiF8; diff --git a/lib/PTO/Transforms/ExpandTileOp.cpp b/lib/PTO/Transforms/ExpandTileOp.cpp index 3045901099..2dac9d936e 100644 --- a/lib/PTO/Transforms/ExpandTileOp.cpp +++ b/lib/PTO/Transforms/ExpandTileOp.cpp @@ -26,6 +26,7 @@ // #include "PTO/IR/PTO.h" +#include "PTO/IR/PTOTypeUtils.h" #include "PTO/Transforms/Passes.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -199,8 +200,8 @@ static std::string getDtypeString(Type elemTy) { if (elemTy.isF32()) return "f32"; if (elemTy.isF16()) return "f16"; if (elemTy.isBF16()) return "bf16"; - if (elemTy.isFloat8E4M3FN()) return "f8e4m3"; - if (elemTy.isFloat8E5M2()) return "f8e5m2"; + if (pto::isPTOFloat8E4M3LikeType(elemTy)) return "f8e4m3"; + if (pto::isPTOFloat8E5M2LikeType(elemTy)) return "f8e5m2"; if (isa(elemTy)) return "hif8"; if (isa(elemTy)) return "f4e1m2x2"; if (isa(elemTy)) return "f4e2m1x2"; diff --git a/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp b/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp index 953b75d88d..fc57a57622 100644 --- a/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp +++ b/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp @@ -86,6 +86,18 @@ static Type getLLVMCompatibleVectorType(ArrayRef shape, return VectorType::get(shape, elementType, scalableDims); } +static bool isLowpPayloadABIElementType(Type type) { + return pto::isPTOFloat8Type(type) || pto::isPTOHiFloat8Type(type) || + pto::isPTOFloat4PackedType(type); +} + +static Type getLowpPayloadABIElementType(Type elementType, + MLIRContext *context) { + if (!isLowpPayloadABIElementType(elementType)) + return {}; + return IntegerType::get(context, 8); +} + static Type normalizePayloadTypeForLLVMLowering(Type type, Builder &builder) { if (pto::isPTOHiFloat8x2Type(type)) return getLLVMCompatibleVectorType( @@ -117,10 +129,6 @@ static Type normalizeGEPElementTypeForLLVMLowering(Type type, return builder.getI16Type(); if (pto::isPTOLowPrecisionType(type)) return builder.getI8Type(); - if (isa(type)) - return builder.getI8Type(); if (auto vecType = dyn_cast(type)) { Type normalizedElement = @@ -132,23 +140,17 @@ static Type normalizeGEPElementTypeForLLVMLowering(Type type, vecType.getScalableDims()); } - if (auto vecType = dyn_cast(type)) { - Type normalizedElement = - normalizeGEPElementTypeForLLVMLowering(vecType.getElementType(), - builder); - if (normalizedElement == vecType.getElementType()) - return normalizePayloadTypeForLLVMLowering(type, builder); - return LLVM::LLVMFixedVectorType::get(normalizedElement, - vecType.getNumElements()); - } - return normalizePayloadTypeForLLVMLowering(type, builder); } static Type convertVPTOType(Type type, Builder &builder) { if (auto vecType = dyn_cast(type)) { - Type elementType = - normalizePayloadTypeForLLVMLowering(vecType.getElementType(), builder); + Type sourceElementType = vecType.getElementType(); + Type elementType = getLowpPayloadABIElementType(sourceElementType, + builder.getContext()); + if (!elementType) + elementType = normalizePayloadTypeForLLVMLowering(sourceElementType, + builder); return getLLVMCompatibleVectorType({vecType.getElementCount()}, elementType); } @@ -546,8 +548,7 @@ static std::string getMemoryElementTypeFragment(Type type) { } static bool isLowpPayloadElementType(Type type) { - return pto::isPTOFloat8Type(type) || pto::isPTOHiFloat8Type(type) || - pto::isPTOFloat4PackedType(type); + return isLowpPayloadABIElementType(type); } struct LowpPayloadABI { @@ -557,9 +558,10 @@ struct LowpPayloadABI { static std::optional getLowpPayloadABI(Type elementType, MLIRContext *context) { - if (!isLowpPayloadElementType(elementType)) + Type carrierElementType = getLowpPayloadABIElementType(elementType, context); + if (!carrierElementType) return std::nullopt; - return LowpPayloadABI{IntegerType::get(context, 8), "u8"}; + return LowpPayloadABI{carrierElementType, "u8"}; } static Type getLowpPayloadCarrierType(Type vectorLikeType, @@ -1063,10 +1065,9 @@ static VcvtElemKind classifyVcvtElemType(Type type) { return VcvtElemKind::BF16; if (type.isF32()) return VcvtElemKind::F32; - if (type.isFloat8E4M3() || type.isFloat8E4M3FN() || - type.isFloat8E4M3FNUZ() || type.isFloat8E4M3B11FNUZ()) + if (pto::isPTOFloat8E4M3LikeType(type)) return VcvtElemKind::F8E4M3; - if (type.isFloat8E5M2() || type.isFloat8E5M2FNUZ()) + if (pto::isPTOFloat8E5M2LikeType(type)) return VcvtElemKind::F8E5M2; if (pto::isPTOHiFloat8Type(type)) return VcvtElemKind::HiF8; diff --git a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp index f41169c796..bd432c62b4 100644 --- a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp +++ b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp @@ -87,6 +87,18 @@ static Type getLLVMCompatibleVectorType(ArrayRef shape, return VectorType::get(shape, elementType, scalableDims); } +static bool isLowpPayloadABIElementType(Type type) { + return pto::isPTOFloat8Type(type) || pto::isPTOHiFloat8Type(type) || + pto::isPTOFloat4PackedType(type); +} + +static Type getLowpPayloadABIElementType(Type elementType, + MLIRContext *context) { + if (!isLowpPayloadABIElementType(elementType)) + return {}; + return IntegerType::get(context, 8); +} + static Type normalizePayloadTypeForLLVMLowering(Type type, Builder &builder) { if (pto::isPTOHiFloat8x2Type(type)) return getLLVMCompatibleVectorType( @@ -118,10 +130,6 @@ static Type normalizeGEPElementTypeForLLVMLowering(Type type, return builder.getI16Type(); if (pto::isPTOLowPrecisionType(type)) return builder.getI8Type(); - if (isa(type)) - return builder.getI8Type(); if (auto vecType = dyn_cast(type)) { Type normalizedElement = @@ -133,23 +141,17 @@ static Type normalizeGEPElementTypeForLLVMLowering(Type type, vecType.getScalableDims()); } - if (auto vecType = dyn_cast(type)) { - Type normalizedElement = - normalizeGEPElementTypeForLLVMLowering(vecType.getElementType(), - builder); - if (normalizedElement == vecType.getElementType()) - return normalizePayloadTypeForLLVMLowering(type, builder); - return LLVM::LLVMFixedVectorType::get(normalizedElement, - vecType.getNumElements()); - } - return normalizePayloadTypeForLLVMLowering(type, builder); } static Type convertVPTOType(Type type, Builder &builder) { if (auto vecType = dyn_cast(type)) { - Type elementType = - normalizePayloadTypeForLLVMLowering(vecType.getElementType(), builder); + Type sourceElementType = vecType.getElementType(); + Type elementType = getLowpPayloadABIElementType(sourceElementType, + builder.getContext()); + if (!elementType) + elementType = normalizePayloadTypeForLLVMLowering(sourceElementType, + builder); return getLLVMCompatibleVectorType({vecType.getElementCount()}, elementType); } @@ -501,8 +503,7 @@ static std::string getMemoryElementTypeFragment(Type type) { } static bool isLowpPayloadElementType(Type type) { - return pto::isPTOFloat8Type(type) || pto::isPTOHiFloat8Type(type) || - pto::isPTOFloat4PackedType(type); + return isLowpPayloadABIElementType(type); } struct LowpPayloadABI { @@ -512,9 +513,10 @@ struct LowpPayloadABI { static std::optional getLowpPayloadABI(Type elementType, MLIRContext *context) { - if (!isLowpPayloadElementType(elementType)) + Type carrierElementType = getLowpPayloadABIElementType(elementType, context); + if (!carrierElementType) return std::nullopt; - return LowpPayloadABI{IntegerType::get(context, 8), "u8"}; + return LowpPayloadABI{carrierElementType, "u8"}; } static Type getLowpPayloadCarrierType(Type vectorLikeType, @@ -1031,10 +1033,9 @@ static VcvtElemKind classifyVcvtElemType(Type type) { return VcvtElemKind::BF16; if (type.isF32()) return VcvtElemKind::F32; - if (type.isFloat8E4M3() || type.isFloat8E4M3FN() || - type.isFloat8E4M3FNUZ() || type.isFloat8E4M3B11FNUZ()) + if (pto::isPTOFloat8E4M3LikeType(type)) return VcvtElemKind::F8E4M3; - if (type.isFloat8E5M2() || type.isFloat8E5M2FNUZ()) + if (pto::isPTOFloat8E5M2LikeType(type)) return VcvtElemKind::F8E5M2; if (pto::isPTOHiFloat8Type(type)) return VcvtElemKind::HiF8; From 5a71cf2c06d51922a3f1985704bf40197a0dbc22 Mon Sep 17 00:00:00 2001 From: TaoTao-real Date: Thu, 11 Jun 2026 16:24:45 +0800 Subject: [PATCH 03/51] fix: adapt EmitC array lowering for LLVM 21 --- lib/PTO/Transforms/PTOToEmitC.cpp | 34 ++++++++----------- test/lit/pto/async_put_get_emitc.pto | 11 +++--- ...edge_nested_same_pipe_prune_regression.pto | 2 +- test/lit/pto/eventid_array_dyn_sync.pto | 10 +++--- test/lit/pto/eventid_array_get_set_get.pto | 11 +++--- test/lit/pto/eventid_array_no_cse.pto | 6 ++-- .../lit/pto/issue428_cube_sync_regression.pto | 4 +-- ...p_if_else_loop_carried_sync_regression.pto | 4 +-- ..._else_loop_carried_sync_regression_gss.pto | 4 +-- ..._nested_loop_same_pipe_pair_regression.pto | 4 +-- ...ted_loop_same_pipe_pair_regression_gss.pto | 2 +- ...ssue533_loop_zero_trip_sync_regression.pto | 2 +- ...533_loop_zero_trip_sync_regression_gss.pto | 2 +- .../issue556_tpop_live_values_no_alias.pto | 10 +++--- ...ue564_k_loop_mte1_mte2_wait_regression.pto | 2 +- ...6_a5_tmov_treshape_dynamic_valid_shape.pto | 6 ++-- ...ov_treshape_dynamic_valid_shape_level2.pto | 6 ++-- .../pto/issue713_local_array_get_snapshot.pto | 9 ++--- test/lit/pto/local_array_1d_emitc.pto | 3 +- test/lit/pto/local_array_2d_emitc.pto | 3 +- test/lit/pto/local_array_get_rvalue_emitc.pto | 6 ++-- .../lit/pto/syncfinder_zero_loop_if_probe.pto | 4 +-- .../pto/syncfinder_zero_loop_if_probe_gss.pto | 4 +-- test/lit/pto/tassign_level3_loop_rebind.pto | 3 +- .../pto/tassign_level3_loop_rebind_gss.pto | 3 +- test/lit/pto/tci_i16_emitc.pto | 4 +-- test/lit/pto/tci_ui32_emitc.pto | 4 +-- test/lit/pto/tprint_alloc_tile_no_rebind.pto | 5 +-- .../tpush_tpop_globaltensor_frontend_a3.pto | 6 ++-- .../pto/treshape_static_valid_shape_emitc.pto | 6 ++-- .../mark_last_use_slot_mask_level2.pto | 10 +++--- 31 files changed, 99 insertions(+), 91 deletions(-) diff --git a/lib/PTO/Transforms/PTOToEmitC.cpp b/lib/PTO/Transforms/PTOToEmitC.cpp index c0a99f13ba..007ce9d69f 100644 --- a/lib/PTO/Transforms/PTOToEmitC.cpp +++ b/lib/PTO/Transforms/PTOToEmitC.cpp @@ -7494,9 +7494,10 @@ struct PTOEventIdArrayGetToEmitC return rewriter.notifyMatchFailure(op, "failed to map eventid_array get result type"); - auto load = - rewriter.create(op.getLoc(), resultTy, array, index); - rewriter.replaceOp(op, load.getResult()); + auto subscript = rewriter.create( + op.getLoc(), emitc::LValueType::get(resultTy), array, + ValueRange{index}); + rewriter.replaceOpWithNewOp(op, resultTy, subscript); return success(); } }; @@ -7513,9 +7514,12 @@ struct PTOEventIdArraySetToEmitC Value index = peelUnrealized(adaptor.getIndex()); Value value = peelUnrealized(adaptor.getValue()); - rewriter.create( - op.getLoc(), TypeRange{}, "PTOAS__EVENTID_ARRAY_STORE", - ArrayAttr{}, ArrayAttr{}, ValueRange{array, index, value}); + Value slot = rewriter + .create( + op.getLoc(), emitc::LValueType::get(value.getType()), + array, ValueRange{index}) + .getResult(); + rewriter.create(op.getLoc(), slot, value); rewriter.eraseOp(op); return success(); } @@ -7572,16 +7576,8 @@ struct PTOLocalArrayGetToEmitC indices.push_back(peelUnrealized(index)); auto sub = rewriter.create( - op.getLoc(), resultTy, array, indices); - auto snapshot = - rewriter - .create( - op.getLoc(), getEmitCVariableResultType(resultTy), - emitc::OpaqueAttr::get(rewriter.getContext(), "")) - .getResult(); - rewriter.create(op.getLoc(), snapshot, sub.getResult()); - rewriter.replaceOp( - op, loadEmitCVariableIfNeeded(rewriter, op.getLoc(), snapshot)); + op.getLoc(), emitc::LValueType::get(resultTy), array, indices); + rewriter.replaceOpWithNewOp(op, resultTy, sub); return success(); } }; @@ -7601,9 +7597,9 @@ struct PTOLocalArraySetToEmitC Type elemTy = value.getType(); Value slot = rewriter - .create(op.getLoc(), elemTy, - adaptor.getArray(), - adaptor.getIndices()) + .create( + op.getLoc(), emitc::LValueType::get(elemTy), + adaptor.getArray(), adaptor.getIndices()) .getResult(); rewriter.create(op.getLoc(), slot, value); rewriter.eraseOp(op); diff --git a/test/lit/pto/async_put_get_emitc.pto b/test/lit/pto/async_put_get_emitc.pto index b191dbd0bb..30998f7c3e 100644 --- a/test/lit/pto/async_put_get_emitc.pto +++ b/test/lit/pto/async_put_get_emitc.pto @@ -16,10 +16,13 @@ module { // A3-LABEL: AICORE void async_put_get( // A3: Tile [[SCRATCH:v[0-9]+]]; -// A3: TASSIGN([[SCRATCH]], [[SCRATCH_ADDR:v[0-9]+]]); +// A3: Tile [[SCRATCH_COPY:v[0-9]+]] = [[SCRATCH]]; +// A3: TASSIGN([[SCRATCH_COPY]], [[SCRATCH_ADDR:v[0-9]+]]); // A3: pto::comm::AsyncSession [[SESSION:v[0-9]+]]; +// A3: pto::comm::AsyncSession [[SESSION_COPY:v[0-9]+]] = [[SESSION]]; // A3: pto::comm::sdma::SdmaBaseConfig [[CFG:v[0-9]+]] = {32768ULL, 0ULL, 1u}; -// A3: pto::comm::BuildAsyncSession([[SCRATCH]], {{.*}}, [[SESSION]], {{.*}}, [[CFG]], {{.*}}); +// A3: pto::comm::sdma::SdmaBaseConfig [[CFG_COPY:v[0-9]+]] = [[CFG]]; +// A3: pto::comm::BuildAsyncSession([[SCRATCH_COPY]], {{.*}}, [[SESSION_COPY]], {{.*}}, [[CFG_COPY]], {{.*}}); // A3: using [[SHAPETY:.*]] = pto::Shape<1, 1, 1, 1, 128>; // A3: using [[STRIDETY:.*]] = pto::Stride<128, 128, 128, 128, 1>; // A3: constexpr pto::Layout [[LAYOUT:.*]] = pto::Layout::ND; @@ -32,5 +35,5 @@ module { // A3: [[GLTNSRTY]] [[GT1:v[0-9]+]] = [[GLTNSRTY]]({{.*}}, [[SHAPE1]], [[STRIDE1]]); // A3: pto::comm::AsyncEvent [[PUT_EVT:v[0-9]+]] = pto::comm::TPUT_ASYNC( // A3: pto::comm::AsyncEvent [[GET_EVT:v[0-9]+]] = pto::comm::TGET_ASYNC( -// A3: bool [[PUT_DONE:v[0-9]+]] = [[PUT_EVT]].Wait([[SESSION]]); -// A3: bool [[GET_DONE:v[0-9]+]] = [[GET_EVT]].Test([[SESSION]]); +// A3: bool [[PUT_DONE:v[0-9]+]] = [[PUT_EVT]].Wait([[SESSION_COPY]]); +// A3: bool [[GET_DONE:v[0-9]+]] = [[GET_EVT]].Test([[SESSION_COPY]]); diff --git a/test/lit/pto/backedge_nested_same_pipe_prune_regression.pto b/test/lit/pto/backedge_nested_same_pipe_prune_regression.pto index eccfe3596c..e54c8f9e18 100644 --- a/test/lit/pto/backedge_nested_same_pipe_prune_regression.pto +++ b/test/lit/pto/backedge_nested_same_pipe_prune_regression.pto @@ -5,7 +5,7 @@ // the wider same-pipe pair can be removed before event-id allocation. // // CHECK-LABEL: AICORE void backedge_nested_same_pipe_prune() -// CHECK: for (size_t {{v[0-9]+}} = +// CHECK: for (int64_t {{[ij][0-9]+}} = // CHECK-NEXT: wait_flag(PIPE_FIX, PIPE_MTE1, EVENT_ID[[INNER:[0-9]+]]); // CHECK-NEXT: TMOV({{v[0-9]+}}, {{v[0-9]+}}); // CHECK-NEXT: TMOV({{v[0-9]+}}, {{v[0-9]+}}); diff --git a/test/lit/pto/eventid_array_dyn_sync.pto b/test/lit/pto/eventid_array_dyn_sync.pto index b63b3e0d4b..e969b767f2 100644 --- a/test/lit/pto/eventid_array_dyn_sync.pto +++ b/test/lit/pto/eventid_array_dyn_sync.pto @@ -15,9 +15,11 @@ module { // CHECK-LABEL: AICORE void eventid_array_dyn_sync() { // CHECK: const int64_t {{v[0-9]+}} = 0; -// CHECK: PTOAS_EventIdArray<4> {{v[0-9]+}}; -// CHECK: {{v[0-9]+}}[{{v[0-9]+}}] = {{v[0-9]+}}; -// CHECK: event_t {{v[0-9]+}} = (event_t) {{v[0-9]+}}[{{v[0-9]+}}]; +// CHECK: PTOAS_EventIdArray<4> [[ARR:v[0-9]+]]; +// CHECK: PTOAS_EventIdArray<4> [[ARR_VAL:v[0-9]+]] = [[ARR]]; +// CHECK: [[ARR_VAL]][{{v[0-9]+}}] = {{v[0-9]+}}; +// CHECK: int64_t [[EID:v[0-9]+]] = [[ARR_VAL]][{{v[0-9]+}}]; +// CHECK: event_t {{v[0-9]+}} = (event_t) [[EID]]; // CHECK: set_flag(PIPE_MTE2, PIPE_MTE3, {{v[0-9]+}}); -// CHECK: event_t {{v[0-9]+}} = (event_t) {{v[0-9]+}}[{{v[0-9]+}}]; +// CHECK: event_t {{v[0-9]+}} = (event_t) [[EID]]; // CHECK: wait_flag(PIPE_MTE2, PIPE_MTE3, {{v[0-9]+}}); diff --git a/test/lit/pto/eventid_array_get_set_get.pto b/test/lit/pto/eventid_array_get_set_get.pto index 6de754b085..7578c7a403 100644 --- a/test/lit/pto/eventid_array_get_set_get.pto +++ b/test/lit/pto/eventid_array_get_set_get.pto @@ -18,9 +18,12 @@ module { // CHECK-LABEL: AICORE void eventid_array_get_set_get() { // CHECK: PTOAS_EventIdArray<4> [[ARR:v[0-9]+]]; -// CHECK: [[ARR]][{{v[0-9]+}}] = {{v[0-9]+}}; -// CHECK: [[ARR]][{{v[0-9]+}}] = {{v[0-9]+}}; -// CHECK: event_t [[FIRST:v[0-9]+]] = (event_t) [[ARR]][{{v[0-9]+}}]; +// CHECK: PTOAS_EventIdArray<4> [[ARR_VAL:v[0-9]+]] = [[ARR]]; +// CHECK: [[ARR_VAL]][{{v[0-9]+}}] = {{v[0-9]+}}; +// CHECK: int64_t [[FIRST_ID:v[0-9]+]] = [[ARR_VAL]][{{v[0-9]+}}]; +// CHECK: [[ARR_VAL]][{{v[0-9]+}}] = {{v[0-9]+}}; +// CHECK: int64_t [[SECOND_ID:v[0-9]+]] = [[ARR_VAL]][{{v[0-9]+}}]; +// CHECK: event_t [[FIRST:v[0-9]+]] = (event_t) [[FIRST_ID]]; // CHECK: set_flag(PIPE_MTE2, PIPE_MTE3, [[FIRST]]); -// CHECK: event_t [[SECOND:v[0-9]+]] = (event_t) [[ARR]][{{v[0-9]+}}]; +// CHECK: event_t [[SECOND:v[0-9]+]] = (event_t) [[SECOND_ID]]; // CHECK: wait_flag(PIPE_MTE2, PIPE_MTE3, [[SECOND]]); diff --git a/test/lit/pto/eventid_array_no_cse.pto b/test/lit/pto/eventid_array_no_cse.pto index a589b5da00..8d18b44176 100644 --- a/test/lit/pto/eventid_array_no_cse.pto +++ b/test/lit/pto/eventid_array_no_cse.pto @@ -19,6 +19,8 @@ module { // CHECK-LABEL: AICORE void eventid_array_no_cse() { // CHECK: PTOAS_EventIdArray<4> [[ARR0:v[0-9]+]]; +// CHECK: PTOAS_EventIdArray<4> [[ARR0_VAL:v[0-9]+]] = [[ARR0]]; // CHECK: PTOAS_EventIdArray<4> [[ARR1:v[0-9]+]]; -// CHECK: [[ARR0]][{{v[0-9]+}}] = {{v[0-9]+}}; -// CHECK: [[ARR1]][{{v[0-9]+}}] = {{v[0-9]+}}; +// CHECK: PTOAS_EventIdArray<4> [[ARR1_VAL:v[0-9]+]] = [[ARR1]]; +// CHECK: [[ARR0_VAL]][{{v[0-9]+}}] = {{v[0-9]+}}; +// CHECK: [[ARR1_VAL]][{{v[0-9]+}}] = {{v[0-9]+}}; diff --git a/test/lit/pto/issue428_cube_sync_regression.pto b/test/lit/pto/issue428_cube_sync_regression.pto index e5ed4f7917..335c2c9e26 100644 --- a/test/lit/pto/issue428_cube_sync_regression.pto +++ b/test/lit/pto/issue428_cube_sync_regression.pto @@ -5,7 +5,7 @@ // - Preserve the tail drain before the auto tail barrier helper. // // CHECK-LABEL: tri_inv_block2x2_fp16( -// CHECK: for (size_t [[LOOP0:v[0-9]+]] = +// CHECK: for (int64_t [[LOOP0:[ij][0-9]+]] = // CHECK: wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID2); // CHECK: TMOV({{v[0-9]+}}, {{v[0-9]+}}); // CHECK: TMOV({{v[0-9]+}}, {{v[0-9]+}}); @@ -37,7 +37,7 @@ // CHECK: set_flag(PIPE_M, PIPE_MTE1, EVENT_ID2); // CHECK: set_flag(PIPE_FIX, PIPE_M, EVENT_ID2); // CHECK: } -// CHECK: for (size_t [[LOOP1:v[0-9]+]] = +// CHECK: for (int64_t [[LOOP1:[ij][0-9]+]] = // CHECK: wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID7); // CHECK: TMOV({{v[0-9]+}}, {{v[0-9]+}}); // CHECK: TMOV({{v[0-9]+}}, {{v[0-9]+}}); diff --git a/test/lit/pto/issue454_loop_if_else_loop_carried_sync_regression.pto b/test/lit/pto/issue454_loop_if_else_loop_carried_sync_regression.pto index cd8c9528b0..7748f8cba8 100644 --- a/test/lit/pto/issue454_loop_if_else_loop_carried_sync_regression.pto +++ b/test/lit/pto/issue454_loop_if_else_loop_carried_sync_regression.pto @@ -16,7 +16,7 @@ // CHECK: TMOV({{v[0-9]+}}, {{v[0-9]+}}); // CHECK: set_flag(PIPE_FIX, PIPE_MTE1, EVENT_ID0); // CHECK: wait_flag(PIPE_FIX, PIPE_MTE1, EVENT_ID0); -// CHECK: for (size_t [[IV:v[0-9]+]] = +// CHECK: for (int64_t [[IV:[ij][0-9]+]] = // CHECK: wait_flag(PIPE_FIX, PIPE_MTE1, EVENT_ID1); // CHECK: TMOV({{v[0-9]+}}, {{v[0-9]+}}); // CHECK: set_flag(PIPE_MTE1, PIPE_M, EVENT_ID1); @@ -26,7 +26,7 @@ // CHECK: set_flag(PIPE_M, PIPE_FIX, EVENT_ID1); // CHECK: wait_flag(PIPE_M, PIPE_FIX, EVENT_ID1); // CHECK: wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); -// CHECK: if ((int64_t) [[IV]] == +// CHECK: if ([[IV]] == // CHECK: } else { // CHECK: set_flag(PIPE_FIX, PIPE_MTE1, EVENT_ID1); // CHECK: ptoas_auto_sync_tail(PTOAutoSyncTailMode::kBarrierAll); diff --git a/test/lit/pto/issue454_loop_if_else_loop_carried_sync_regression_gss.pto b/test/lit/pto/issue454_loop_if_else_loop_carried_sync_regression_gss.pto index 019a12654b..323d7de1b0 100644 --- a/test/lit/pto/issue454_loop_if_else_loop_carried_sync_regression_gss.pto +++ b/test/lit/pto/issue454_loop_if_else_loop_carried_sync_regression_gss.pto @@ -5,8 +5,8 @@ // // CHECK-DAG: pipe_barrier(PIPE_ALL) // CHECK-LABEL: AICORE void loop_if_else_loop_carried_sync() -// CHECK: for (size_t -// CHECK: if ((int64_t) +// CHECK: for (int64_t [[IV:[ij][0-9]+]] = +// CHECK: if ([[IV]] == // CHECK: } else { // CHECK: #endif // __DAV_CUBE__ diff --git a/test/lit/pto/issue454_nested_loop_same_pipe_pair_regression.pto b/test/lit/pto/issue454_nested_loop_same_pipe_pair_regression.pto index b673740f52..e7f74de69c 100644 --- a/test/lit/pto/issue454_nested_loop_same_pipe_pair_regression.pto +++ b/test/lit/pto/issue454_nested_loop_same_pipe_pair_regression.pto @@ -5,9 +5,9 @@ // - Inner-loop and outer-loop event chains must both keep their set/wait handshake. // // CHECK-LABEL: AICORE void nested_loop_same_pipe_pair() -// CHECK: for (size_t +// CHECK: for (int64_t // CHECK: wait_flag(PIPE_FIX, PIPE_MTE1, EVENT_ID[[OUT:[0-9]+]]); -// CHECK: for (size_t +// CHECK: for (int64_t // CHECK: wait_flag(PIPE_FIX, PIPE_MTE1, EVENT_ID[[IN:[0-9]+]]); // CHECK: set_flag(PIPE_FIX, PIPE_MTE1, EVENT_ID[[IN]]); // CHECK: set_flag(PIPE_FIX, PIPE_MTE1, EVENT_ID[[OUT]]); diff --git a/test/lit/pto/issue454_nested_loop_same_pipe_pair_regression_gss.pto b/test/lit/pto/issue454_nested_loop_same_pipe_pair_regression_gss.pto index c6aef83ce1..2139d89bf6 100644 --- a/test/lit/pto/issue454_nested_loop_same_pipe_pair_regression_gss.pto +++ b/test/lit/pto/issue454_nested_loop_same_pipe_pair_regression_gss.pto @@ -8,7 +8,7 @@ // - Inner-loop and outer-loop event chains must both keep their set/wait handshake. // // CHECK-LABEL: AICORE void nested_loop_same_pipe_pair() -// CHECK-COUNT-2: for (size_t +// CHECK-COUNT-2: for (int64_t // CHECK: TMATMUL( // CHECK: #endif // __DAV_CUBE__ diff --git a/test/lit/pto/issue533_loop_zero_trip_sync_regression.pto b/test/lit/pto/issue533_loop_zero_trip_sync_regression.pto index 7ad2387278..58e1c7b739 100644 --- a/test/lit/pto/issue533_loop_zero_trip_sync_regression.pto +++ b/test/lit/pto/issue533_loop_zero_trip_sync_regression.pto @@ -4,7 +4,7 @@ // a post-loop vector consumer must wait on MTE2->V before TROWEXPANDDIV. // // CHECK-LABEL: AICORE void qwen3_scope2_incore_5 -// CHECK: for (size_t +// CHECK: for (int64_t // CHECK: } // CHECK: wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID[[POST:[0-9]+]]); // CHECK: TROWEXPANDDIV diff --git a/test/lit/pto/issue533_loop_zero_trip_sync_regression_gss.pto b/test/lit/pto/issue533_loop_zero_trip_sync_regression_gss.pto index b8d73b19dc..56e75b410b 100644 --- a/test/lit/pto/issue533_loop_zero_trip_sync_regression_gss.pto +++ b/test/lit/pto/issue533_loop_zero_trip_sync_regression_gss.pto @@ -4,7 +4,7 @@ // a post-loop vector consumer must wait on MTE2->V before TROWEXPANDDIV. // // CHECK-LABEL: AICORE void qwen3_scope2_incore_5 -// CHECK: for (size_t +// CHECK: for (int64_t // CHECK: } // CHECK: wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID[[POST:[0-9]+]]); // CHECK: TROWEXPANDDIV diff --git a/test/lit/pto/issue556_tpop_live_values_no_alias.pto b/test/lit/pto/issue556_tpop_live_values_no_alias.pto index de1734a5c2..ddd179e75b 100644 --- a/test/lit/pto/issue556_tpop_live_values_no_alias.pto +++ b/test/lit/pto/issue556_tpop_live_values_no_alias.pto @@ -58,8 +58,10 @@ module { // CHECK-LABEL: AICORE void issue556_tpop_live_values_no_alias( // CHECK: Tile [[POP0:v[0-9]+]]; -// CHECK: TPOP, Tile, TileSplitAxis::TILE_UP_DOWN>({{v[0-9]+}}, [[POP0]]); +// CHECK: Tile [[POP0_COPY:v[0-9]+]] = [[POP0]]; +// CHECK: TPOP, Tile, TileSplitAxis::TILE_UP_DOWN>({{v[0-9]+}}, [[POP0_COPY]]); // CHECK: Tile [[POP1:v[0-9]+]]; -// CHECK: TPOP, Tile, TileSplitAxis::TILE_UP_DOWN>({{v[0-9]+}}, [[POP1]]); -// CHECK: TMOV({{v[0-9]+}}, [[POP0]]); -// CHECK: TMOV({{v[0-9]+}}, [[POP1]]); +// CHECK: Tile [[POP1_COPY:v[0-9]+]] = [[POP1]]; +// CHECK: TPOP, Tile, TileSplitAxis::TILE_UP_DOWN>({{v[0-9]+}}, [[POP1_COPY]]); +// CHECK: TMOV({{v[0-9]+}}, [[POP0_COPY]]); +// CHECK: TMOV({{v[0-9]+}}, [[POP1_COPY]]); diff --git a/test/lit/pto/issue564_k_loop_mte1_mte2_wait_regression.pto b/test/lit/pto/issue564_k_loop_mte1_mte2_wait_regression.pto index 9c5b00dd1f..3690ca6114 100644 --- a/test/lit/pto/issue564_k_loop_mte1_mte2_wait_regression.pto +++ b/test/lit/pto/issue564_k_loop_mte1_mte2_wait_regression.pto @@ -10,7 +10,7 @@ // CHECK-NEXT: set_flag(PIPE_M, PIPE_MTE1, EVENT_ID[[CARRY:[0-9]+]]); // CHECK-NEXT: wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID[[PRE0:[0-9]+]]); // CHECK-NEXT: wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID[[CARRY]]); -// CHECK-NEXT: for (size_t +// CHECK-NEXT: for (int64_t // CHECK: wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID[[LOAD0:[0-9]+]]); // CHECK-NEXT: TLOAD( // CHECK: wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID[[LOAD1:[0-9]+]]); diff --git a/test/lit/pto/issue686_a5_tmov_treshape_dynamic_valid_shape.pto b/test/lit/pto/issue686_a5_tmov_treshape_dynamic_valid_shape.pto index 3696e084b6..a14b12d9a0 100644 --- a/test/lit/pto/issue686_a5_tmov_treshape_dynamic_valid_shape.pto +++ b/test/lit/pto/issue686_a5_tmov_treshape_dynamic_valid_shape.pto @@ -28,9 +28,11 @@ module attributes {pto.target_arch = "a5"} { // CHECK: Tile [[SRC_ORIG:v[0-9]+]] = Tile // CHECK: Tile [[DST_ORIG:v[0-9]+]] = Tile -// CHECK: Tile [[SRC:v[0-9]+]]; +// CHECK: Tile [[SRC_STORAGE:v[0-9]+]]; +// CHECK-NEXT: Tile [[SRC:v[0-9]+]] = [[SRC_STORAGE]]; // CHECK-NEXT: TRESHAPE([[SRC]], [[SRC_ORIG]]); -// CHECK: Tile [[DST:v[0-9]+]]; +// CHECK: Tile [[DST_STORAGE:v[0-9]+]]; +// CHECK-NEXT: Tile [[DST:v[0-9]+]] = [[DST_STORAGE]]; // CHECK-NEXT: TRESHAPE([[DST]], [[DST_ORIG]]); // CHECK: int64_t [[SRC_ROW:v[0-9]+]] = [[SRC_ORIG]].GetValidRow(); // CHECK-NEXT: int64_t [[SRC_COL:v[0-9]+]] = [[SRC_ORIG]].GetValidCol(); diff --git a/test/lit/pto/issue686_a5_tmov_treshape_dynamic_valid_shape_level2.pto b/test/lit/pto/issue686_a5_tmov_treshape_dynamic_valid_shape_level2.pto index 140f06b665..e3c3864d05 100644 --- a/test/lit/pto/issue686_a5_tmov_treshape_dynamic_valid_shape_level2.pto +++ b/test/lit/pto/issue686_a5_tmov_treshape_dynamic_valid_shape_level2.pto @@ -27,9 +27,11 @@ module attributes {pto.target_arch = "a5"} { // CHECK-LABEL: AICORE void a5_tmov_treshape_dynamic_valid_shape_level2() // CHECK: Tile [[SRC_ORIG:v[0-9]+]] // CHECK: Tile [[DST_ORIG:v[0-9]+]] -// CHECK: Tile [[SRC:v[0-9]+]]; +// CHECK: Tile [[SRC_STORAGE:v[0-9]+]]; +// CHECK-NEXT: Tile [[SRC:v[0-9]+]] = [[SRC_STORAGE]]; // CHECK-NEXT: TRESHAPE([[SRC]], [[SRC_ORIG]]); -// CHECK: Tile [[DST:v[0-9]+]]; +// CHECK: Tile [[DST_STORAGE:v[0-9]+]]; +// CHECK-NEXT: Tile [[DST:v[0-9]+]] = [[DST_STORAGE]]; // CHECK-NEXT: TRESHAPE([[DST]], [[DST_ORIG]]); // CHECK: int64_t [[SRC_ROW:v[0-9]+]] = [[SRC_ORIG]].GetValidRow(); // CHECK-NEXT: int64_t [[SRC_COL:v[0-9]+]] = [[SRC_ORIG]].GetValidCol(); diff --git a/test/lit/pto/issue713_local_array_get_snapshot.pto b/test/lit/pto/issue713_local_array_get_snapshot.pto index 28ffc34b90..c549c233d7 100644 --- a/test/lit/pto/issue713_local_array_get_snapshot.pto +++ b/test/lit/pto/issue713_local_array_get_snapshot.pto @@ -37,11 +37,8 @@ module { // CPP: int32_t [[ARR:v[0-9]+]][2]; // CPP: [[ARR]][{{v[0-9]+}}] = {{v[0-9]+}}; // CPP: [[ARR]][{{v[0-9]+}}] = {{v[0-9]+}}; -// CPP: int32_t [[CUR:v[0-9]+]]; -// CPP-NEXT: [[CUR]] = [[ARR]][{{v[0-9]+}}]; -// CPP-NEXT: uint32_t [[NEW_U:v[0-9]+]] = (uint32_t) [[CUR]]; -// CPP-NEXT: [[ARR]][{{v[0-9]+}}] = {{.*}}[[NEW_U]]{{.*}}; +// CPP: int32_t [[CUR:v[0-9]+]] = [[ARR]][{{v[0-9]+}}]; +// CPP-NEXT: [[ARR]][{{v[0-9]+}}] = {{.*}}(uint32_t) [[CUR]]{{.*}}; // CPP-NOT: [[ARR]][ -// CPP: uint32_t [[SUM_U:v[0-9]+]] = (uint32_t) [[CUR]]; -// CPP-NEXT: int32_t [[SUM:v[0-9]+]] = {{.*}}[[SUM_U]]{{.*}}; +// CPP: int32_t [[SUM:v[0-9]+]] = {{.*}}(uint32_t) [[CUR]]{{.*}}; // CPP-NEXT: [[OUT]][{{v[0-9]+}}] = [[SUM]]; diff --git a/test/lit/pto/local_array_1d_emitc.pto b/test/lit/pto/local_array_1d_emitc.pto index d38c2ca437..c90943f247 100644 --- a/test/lit/pto/local_array_1d_emitc.pto +++ b/test/lit/pto/local_array_1d_emitc.pto @@ -25,6 +25,5 @@ module { // CHECK-LABEL: local_array_1d // CHECK: int32_t [[A:v[0-9]+]][16]; // CHECK: [[A]][{{v[0-9]+}}] = {{v[0-9]+}}; -// CHECK: int32_t [[R:v[0-9]+]]; -// CHECK: [[R]] = [[A]][{{v[0-9]+}}]; +// CHECK: int32_t [[R:v[0-9]+]] = [[A]][{{v[0-9]+}}]; // CHECK: [[A]][{{v[0-9]+}}] = [[R]]; diff --git a/test/lit/pto/local_array_2d_emitc.pto b/test/lit/pto/local_array_2d_emitc.pto index 222f2d71de..584ed6efa7 100644 --- a/test/lit/pto/local_array_2d_emitc.pto +++ b/test/lit/pto/local_array_2d_emitc.pto @@ -25,6 +25,5 @@ module { // CHECK-LABEL: local_array_2d // CHECK: float [[M:v[0-9]+]][8][8]; // CHECK: [[M]][{{v[0-9]+}}][{{v[0-9]+}}] = {{v[0-9]+}}; -// CHECK: float [[R:v[0-9]+]]; -// CHECK: [[R]] = [[M]][{{v[0-9]+}}][{{v[0-9]+}}]; +// CHECK: float [[R:v[0-9]+]] = [[M]][{{v[0-9]+}}][{{v[0-9]+}}]; // CHECK: [[M]][{{v[0-9]+}}][{{v[0-9]+}}] = [[R]]; diff --git a/test/lit/pto/local_array_get_rvalue_emitc.pto b/test/lit/pto/local_array_get_rvalue_emitc.pto index 91140eec83..132e14ca47 100644 --- a/test/lit/pto/local_array_get_rvalue_emitc.pto +++ b/test/lit/pto/local_array_get_rvalue_emitc.pto @@ -24,7 +24,5 @@ module { // CHECK-LABEL: local_array_get_rvalue // CHECK: int32_t [[A:v[0-9]+]][16]; -// CHECK: int32_t [[R:v[0-9]+]]; -// CHECK: [[R]] = [[A]][{{v[0-9]+}}]; -// CHECK: uint32_t [[R_U:v[0-9]+]] = (uint32_t) [[R]]; -// CHECK: [[A]][{{v[0-9]+}}] = {{.*}}[[R_U]]{{.*}}; +// CHECK: int32_t [[R:v[0-9]+]] = [[A]][{{v[0-9]+}}]; +// CHECK: [[A]][{{v[0-9]+}}] = {{.*}}(uint32_t) [[R]]{{.*}}; diff --git a/test/lit/pto/syncfinder_zero_loop_if_probe.pto b/test/lit/pto/syncfinder_zero_loop_if_probe.pto index 9f04615a22..e68216fa35 100644 --- a/test/lit/pto/syncfinder_zero_loop_if_probe.pto +++ b/test/lit/pto/syncfinder_zero_loop_if_probe.pto @@ -8,8 +8,8 @@ // CHECK: TLOAD // CHECK: set_flag(PIPE_MTE2, PIPE_V, EVENT_ID[[LOOP:[0-9]+]]); // CHECK-NEXT: wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID[[LOOP]]); -// CHECK-NEXT: for (size_t -// CHECK: if ((int64_t) +// CHECK-NEXT: for (int64_t [[IV:[ij][0-9]+]] = +// CHECK: if ([[IV]] == // CHECK: TMOV // CHECK: } // CHECK: } diff --git a/test/lit/pto/syncfinder_zero_loop_if_probe_gss.pto b/test/lit/pto/syncfinder_zero_loop_if_probe_gss.pto index 9d5fe9dd91..f174d48893 100644 --- a/test/lit/pto/syncfinder_zero_loop_if_probe_gss.pto +++ b/test/lit/pto/syncfinder_zero_loop_if_probe_gss.pto @@ -8,8 +8,8 @@ // CHECK: TLOAD // CHECK: set_flag(PIPE_MTE2, PIPE_V, EVENT_ID[[LOOP:[0-9]+]]); // CHECK-NEXT: wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID[[LOOP]]); -// CHECK-NEXT: for (size_t -// CHECK: if ((int64_t) +// CHECK-NEXT: for (int64_t [[IV:[ij][0-9]+]] = +// CHECK: if ([[IV]] == // CHECK: TMOV // CHECK: } // CHECK: } diff --git a/test/lit/pto/tassign_level3_loop_rebind.pto b/test/lit/pto/tassign_level3_loop_rebind.pto index a8d4a732b7..4246386de7 100644 --- a/test/lit/pto/tassign_level3_loop_rebind.pto +++ b/test/lit/pto/tassign_level3_loop_rebind.pto @@ -35,7 +35,8 @@ module { } // CHECK-LABEL: __global__ AICORE void tassign_loop_rebind() { -// CHECK: Tile [[T:v[0-9]+]]; +// CHECK: Tile [[T_STORAGE:v[0-9]+]]; +// CHECK: Tile [[T:v[0-9]+]] = [[T_STORAGE]]; // CHECK: for ( // CHECK: TASSIGN([[T]], // CHECK: TPRINT{{(<.*>)?}}([[T]]); diff --git a/test/lit/pto/tassign_level3_loop_rebind_gss.pto b/test/lit/pto/tassign_level3_loop_rebind_gss.pto index bc4d2c9480..7a5f5affa1 100644 --- a/test/lit/pto/tassign_level3_loop_rebind_gss.pto +++ b/test/lit/pto/tassign_level3_loop_rebind_gss.pto @@ -35,7 +35,8 @@ module { } // CHECK-LABEL: __global__ AICORE void tassign_loop_rebind() { -// CHECK: Tile [[T:v[0-9]+]]; +// CHECK: Tile [[T_STORAGE:v[0-9]+]]; +// CHECK: Tile [[T:v[0-9]+]] = [[T_STORAGE]]; // CHECK: for ( // CHECK: TASSIGN([[T]], // CHECK: TPRINT{{(<.*>)?}}([[T]]); diff --git a/test/lit/pto/tci_i16_emitc.pto b/test/lit/pto/tci_i16_emitc.pto index 4e255b7ad2..b4b072427c 100644 --- a/test/lit/pto/tci_i16_emitc.pto +++ b/test/lit/pto/tci_i16_emitc.pto @@ -4,9 +4,7 @@ module { func.func @tci_i16_kernel(%dst: memref<16xi16, #pto.address_space>) { %c0_i16 = arith.constant 0 : i16 %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c16 = arith.constant 16 : index - %src = memref.reinterpret_cast %dst to offset: [%c0], sizes: [%c1, %c16], strides: [%c16, %c1] {layout = #pto.layout} : memref<16xi16, #pto.address_space> to memref<1x16xi16, strided<[16, 1], offset: ?>, #pto.address_space> + %src = memref.reinterpret_cast %dst to offset: [0], sizes: [1, 16], strides: [16, 1] {layout = #pto.layout} : memref<16xi16, #pto.address_space> to memref<1x16xi16, strided<[16, 1], offset: ?>, #pto.address_space> %tile = pto.alloc_tile : !pto.tile_buf pto.tci ins(%c0_i16 : i16) outs(%tile : !pto.tile_buf) pto.tstore ins(%tile : !pto.tile_buf) outs(%src : memref<1x16xi16, strided<[16, 1], offset: ?>, #pto.address_space>) {layout = #pto.layout, pto.inferred_layout = true} diff --git a/test/lit/pto/tci_ui32_emitc.pto b/test/lit/pto/tci_ui32_emitc.pto index b08c4c1181..e36175500f 100644 --- a/test/lit/pto/tci_ui32_emitc.pto +++ b/test/lit/pto/tci_ui32_emitc.pto @@ -5,9 +5,7 @@ module { %c5_i32 = arith.constant 5 : i32 %c5_ui32 = builtin.unrealized_conversion_cast %c5_i32 : i32 to ui32 %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c32 = arith.constant 32 : index - %src = memref.reinterpret_cast %dst to offset: [%c0], sizes: [%c1, %c32], strides: [%c32, %c1] {layout = #pto.layout} : memref<32xui32, #pto.address_space> to memref<1x32xui32, strided<[32, 1], offset: ?>, #pto.address_space> + %src = memref.reinterpret_cast %dst to offset: [0], sizes: [1, 32], strides: [32, 1] {layout = #pto.layout} : memref<32xui32, #pto.address_space> to memref<1x32xui32, strided<[32, 1], offset: ?>, #pto.address_space> %tile = pto.alloc_tile : !pto.tile_buf pto.tci ins(%c5_ui32 : ui32) outs(%tile : !pto.tile_buf) pto.tstore ins(%tile : !pto.tile_buf) outs(%src : memref<1x32xui32, strided<[32, 1], offset: ?>, #pto.address_space>) {layout = #pto.layout, pto.inferred_layout = true} diff --git a/test/lit/pto/tprint_alloc_tile_no_rebind.pto b/test/lit/pto/tprint_alloc_tile_no_rebind.pto index c7ec3b2742..0ebf9332ed 100644 --- a/test/lit/pto/tprint_alloc_tile_no_rebind.pto +++ b/test/lit/pto/tprint_alloc_tile_no_rebind.pto @@ -14,8 +14,9 @@ module { // CHECK-LABEL: __global__ AICORE void print_kernel() { // CHECK: Tile [[TILE:v[0-9]+]]; -// CHECK: TASSIGN([[TILE]], [[ADDR:v[0-9]+]]); +// CHECK: Tile [[TILE_COPY:v[0-9]+]] = [[TILE]]; +// CHECK: TASSIGN([[TILE_COPY]], [[ADDR:v[0-9]+]]); // CHECK-NOT: TASSIGN( // CHECK-NOT: .data() // CHECK-NOT: reinterpret_cast -// CHECK: TPRINT{{(<.*>)?}}([[TILE]]); +// CHECK: TPRINT{{(<.*>)?}}([[TILE_COPY]]); diff --git a/test/lit/pto/tpush_tpop_globaltensor_frontend_a3.pto b/test/lit/pto/tpush_tpop_globaltensor_frontend_a3.pto index 9ad19ec8f5..cb315e14f8 100644 --- a/test/lit/pto/tpush_tpop_globaltensor_frontend_a3.pto +++ b/test/lit/pto/tpush_tpop_globaltensor_frontend_a3.pto @@ -67,13 +67,15 @@ module { // CHECK-SAME: (__gm__ float* [[CUBE_GM:v[0-9]+]], // CHECK: TPipe<0, Direction::DIR_C2V, 1024, 8, 8, true>([[CUBE_GM]], {{.*}}, {{.*}}); // CHECK: GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND> [[CUBE_ENTRY:v[0-9]+]](nullptr); -// CHECK: TALLOC, GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>({{.*}}, [[CUBE_ENTRY]]); +// CHECK: GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND> [[CUBE_ENTRY_VAL:v[0-9]+]] = [[CUBE_ENTRY]]; +// CHECK: TALLOC, GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>({{.*}}, [[CUBE_ENTRY_VAL]]); // CHECK: TSTORE -// CHECK: TPUSH, GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>({{.*}}, [[CUBE_ENTRY]]); +// CHECK: TPUSH, GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>({{.*}}, [[CUBE_ENTRY_VAL]]); // CHECK-LABEL: AICORE void vector_kernel // CHECK-SAME: (__gm__ float* [[VEC_GM:v[0-9]+]], // CHECK: TPipe<0, Direction::DIR_C2V, 1024, 8, 8, true>([[VEC_GM]], {{.*}}, {{.*}}); // CHECK: GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND> [[VEC_ENTRY:v[0-9]+]](nullptr); +// CHECK: GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND> [[VEC_ENTRY_VAL:v[0-9]+]] = [[VEC_ENTRY]]; // CHECK: TPOP, GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>( // CHECK: TLOAD // CHECK: TFREE, GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>( diff --git a/test/lit/pto/treshape_static_valid_shape_emitc.pto b/test/lit/pto/treshape_static_valid_shape_emitc.pto index 9673ec0c17..701874e5f0 100644 --- a/test/lit/pto/treshape_static_valid_shape_emitc.pto +++ b/test/lit/pto/treshape_static_valid_shape_emitc.pto @@ -32,9 +32,11 @@ module attributes {pto.target_arch = "a5"} { } // CHECK-LABEL: AICORE void treshape_static_valid_shape_emitc() -// CHECK: Tile [[SRC:v[0-9]+]]; +// CHECK: Tile [[SRC_STORAGE:v[0-9]+]]; +// CHECK: Tile [[SRC:v[0-9]+]] = [[SRC_STORAGE]]; // CHECK: TASSIGN([[SRC]], -// CHECK: Tile [[RESHAPED:v[0-9]+]]; +// CHECK: Tile [[RESHAPED_STORAGE:v[0-9]+]]; +// CHECK-NEXT: Tile [[RESHAPED:v[0-9]+]] = [[RESHAPED_STORAGE]]; // CHECK-NEXT: TRESHAPE([[RESHAPED]], [[SRC]]); // CHECK-NOT: SetValidShape // CHECK-NOT: TRESHAPE diff --git a/test/lit/tile_fusion/mark_last_use_slot_mask_level2.pto b/test/lit/tile_fusion/mark_last_use_slot_mask_level2.pto index 7f28f9f3da..1a47f47ff5 100644 --- a/test/lit/tile_fusion/mark_last_use_slot_mask_level2.pto +++ b/test/lit/tile_fusion/mark_last_use_slot_mask_level2.pto @@ -62,11 +62,11 @@ module { // CHECK-NEXT: pto.tsub ins(%{{.*}}, %{{.*}} : !pto.tile_buf, !pto.tile_buf) outs(%{{.*}} : !pto.tile_buf) {pto.fusion.group_id = 0 : i64, pto.fusion.order = 3 : i64, pto.last_use = array} // CHECK-NEXT: pto.tsubs ins(%{{.*}}, %{{.*}} : !pto.tile_buf, f32) outs(%{{.*}} : !pto.tile_buf) {pto.fusion.group_id = 0 : i64, pto.fusion.order = 4 : i64, pto.last_use = array} -// MARKER: emitc.call_opaque "PTOAS__LAST_USE__TEXPANDS__0" -// MARKER: emitc.call_opaque "PTOAS__LAST_USE__TADDS__0__1" -// MARKER: emitc.call_opaque "PTOAS__LAST_USE__TADD__0__1__0" -// MARKER: emitc.call_opaque "PTOAS__LAST_USE__TSUB__0__1__1" -// MARKER: emitc.call_opaque "PTOAS__LAST_USE__TSUBS__0__1" +// MARKER: call_opaque "PTOAS__LAST_USE__TEXPANDS__0" +// MARKER: call_opaque "PTOAS__LAST_USE__TADDS__0__1" +// MARKER: call_opaque "PTOAS__LAST_USE__TADD__0__1__0" +// MARKER: call_opaque "PTOAS__LAST_USE__TSUB__0__1__1" +// MARKER: call_opaque "PTOAS__LAST_USE__TSUBS__0__1" // MARKER-NOT: pto::last_use // CPP-LABEL: AICORE void mark_last_use_slot_mask_level2( From f7c3bc9131bb8ac5f47a600028733f8b5d88c46f Mon Sep 17 00:00:00 2001 From: TaoTao-real Date: Thu, 11 Jun 2026 16:50:48 +0800 Subject: [PATCH 04/51] fix: materialize EmitC branch operands as lvalues --- tools/ptoas/ptoas.cpp | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/tools/ptoas/ptoas.cpp b/tools/ptoas/ptoas.cpp index 6443f1ca72..38589c693d 100644 --- a/tools/ptoas/ptoas.cpp +++ b/tools/ptoas/ptoas.cpp @@ -841,6 +841,12 @@ static Attribute getDefaultEmitCVariableInitAttr(OpBuilder &builder, Type type) return Attribute{}; } +static Type getEmitCVariableStorageType(Type valueType) { + if (isa(valueType)) + return valueType; + return emitc::LValueType::get(valueType); +} + // FormExpressions may inline conditions into emitc.expression, but the C++ // emitter prints cf.br/cf.cond_br operands by variable name rather than by // recursively emitting an expression. Materialize such operands so CFG-based @@ -866,12 +872,21 @@ static void materializeControlFlowOperands(Operation *rootOp) { if (!initAttr) continue; - Value tmp = - builder.create(op->getLoc(), value.getType(), - initAttr) - .getResult(); + Value tmp = builder + .create( + op->getLoc(), getEmitCVariableStorageType(value.getType()), + initAttr) + .getResult(); builder.create(op->getLoc(), tmp, value); - operand.set(tmp); + if (auto lvalueTy = dyn_cast(tmp.getType())) { + Value loaded = builder + .create(op->getLoc(), + lvalueTy.getValueType(), tmp) + .getResult(); + operand.set(loaded); + } else { + operand.set(tmp); + } } } } From 04f6db9c1ab698cc94ef6c28848a832f34b545e6 Mon Sep 17 00:00:00 2001 From: TaoTao-real Date: Thu, 11 Jun 2026 19:26:56 +0800 Subject: [PATCH 05/51] fix: rebuild control-flow tile results for LLVM 21 --- .../Transforms/PTOMaterializeTileHandles.cpp | 128 ++++++++++++++++-- 1 file changed, 115 insertions(+), 13 deletions(-) diff --git a/lib/PTO/Transforms/PTOMaterializeTileHandles.cpp b/lib/PTO/Transforms/PTOMaterializeTileHandles.cpp index 7b1a85c3f4..2c4b1d8de6 100644 --- a/lib/PTO/Transforms/PTOMaterializeTileHandles.cpp +++ b/lib/PTO/Transforms/PTOMaterializeTileHandles.cpp @@ -24,6 +24,7 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/IRMapping.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" @@ -603,6 +604,17 @@ static Value lookupMaterializedTileHandle( return it->second; } +static void cloneBlockWithoutTerminator(Block *from, Block *to, + IRMapping &mapping) { + OpBuilder builder(to, to->end()); + Operation *terminator = from->getTerminator(); + for (Operation &op : *from) { + if (&op == terminator) + break; + builder.clone(op, mapping); + } +} + static FailureOr materializeSCFIfResults(ModuleOp module, DenseMap &tileHandles) { bool changed = false; @@ -619,6 +631,19 @@ materializeSCFIfResults(ModuleOp module, DenseMap &tileHandles) { if (!thenYield || !elseYield) continue; + if (thenYield.getNumOperands() != ifOp.getNumResults() || + elseYield.getNumOperands() != ifOp.getNumResults()) { + ifOp.emitOpError("result count does not match branch yield operands"); + return failure(); + } + + SmallVector resultTypes(ifOp->getResultTypes()); + SmallVector thenYieldOperands(thenYield.getOperands().begin(), + thenYield.getOperands().end()); + SmallVector elseYieldOperands(elseYield.getOperands().begin(), + elseYield.getOperands().end()); + SmallVector materializedResults; + for (auto [idx, result] : llvm::enumerate(ifOp.getResults())) { if (!isLocalTileMemRef(result.getType())) continue; @@ -639,12 +664,47 @@ materializeSCFIfResults(ModuleOp module, DenseMap &tileHandles) { } Type tileTy = thenTile.getType(); - thenYield->setOperand(idx, thenTile); - elseYield->setOperand(idx, elseTile); - result.setType(tileTy); - tileHandles[result] = result; - changed = true; + resultTypes[idx] = tileTy; + thenYieldOperands[idx] = thenTile; + elseYieldOperands[idx] = elseTile; + materializedResults.push_back(idx); + } + + if (materializedResults.empty()) + continue; + + OpBuilder builder(ifOp); + auto newIf = builder.create( + ifOp.getLoc(), TypeRange(resultTypes), ifOp.getCondition(), + /*addThenBlock=*/true, /*addElseBlock=*/true); + newIf->setAttrs(ifOp->getAttrs()); + + IRMapping thenMapping; + cloneBlockWithoutTerminator(ifOp.thenBlock(), newIf.thenBlock(), + thenMapping); + builder.setInsertionPointToEnd(newIf.thenBlock()); + for (Value &operand : thenYieldOperands) + operand = thenMapping.lookupOrDefault(operand); + builder.create(thenYield.getLoc(), thenYieldOperands); + + IRMapping elseMapping; + cloneBlockWithoutTerminator(ifOp.elseBlock(), newIf.elseBlock(), + elseMapping); + builder.setInsertionPointToEnd(newIf.elseBlock()); + for (Value &operand : elseYieldOperands) + operand = elseMapping.lookupOrDefault(operand); + builder.create(elseYield.getLoc(), elseYieldOperands); + + for (auto [oldResult, newResult] : + llvm::zip_equal(ifOp.getResults(), newIf.getResults())) + oldResult.replaceAllUsesWith(newResult); + + for (unsigned idx : materializedResults) { + tileHandles[newIf.getResult(idx)] = newIf.getResult(idx); } + + ifOp.erase(); + changed = true; } return changed; @@ -665,6 +725,18 @@ materializeSCFForResults(ModuleOp module, DenseMap &tileHandles) { if (!yield) continue; + if (yield.getNumOperands() != forOp.getNumResults() || + forOp.getInitArgs().size() != forOp.getNumResults()) { + forOp.emitOpError("result count does not match iter/yield operands"); + return failure(); + } + + SmallVector initArgs(forOp.getInitArgs().begin(), + forOp.getInitArgs().end()); + SmallVector yieldOperands(yield.getOperands().begin(), + yield.getOperands().end()); + SmallVector materializedResults; + for (auto [idx, result] : llvm::enumerate(forOp.getResults())) { if (!isLocalTileMemRef(result.getType())) continue; @@ -692,15 +764,45 @@ materializeSCFForResults(ModuleOp module, DenseMap &tileHandles) { return failure(); } - Type tileTy = initTile.getType(); - forOp->setOperand(forOp.getNumControlOperands() + idx, initTile); - iterArg.setType(tileTy); - yield->setOperand(idx, yieldTile); - result.setType(tileTy); - tileHandles[iterArg] = iterArg; - tileHandles[result] = result; - changed = true; + initArgs[idx] = initTile; + yieldOperands[idx] = yieldTile; + materializedResults.push_back(idx); } + + if (materializedResults.empty()) + continue; + + OpBuilder builder(forOp); + auto newFor = builder.create( + forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), + forOp.getStep(), initArgs); + newFor->setAttrs(forOp->getAttrs()); + + IRMapping mapping; + mapping.map(forOp.getInductionVar(), newFor.getInductionVar()); + for (auto [oldArg, newArg] : + llvm::zip_equal(forOp.getBody()->getArguments().drop_front(), + newFor.getBody()->getArguments().drop_front())) + mapping.map(oldArg, newArg); + + cloneBlockWithoutTerminator(forOp.getBody(), newFor.getBody(), mapping); + builder.setInsertionPointToEnd(newFor.getBody()); + for (Value &operand : yieldOperands) + operand = mapping.lookupOrDefault(operand); + builder.create(yield.getLoc(), yieldOperands); + + for (auto [oldResult, newResult] : + llvm::zip_equal(forOp.getResults(), newFor.getResults())) + oldResult.replaceAllUsesWith(newResult); + + for (unsigned idx : materializedResults) { + BlockArgument newIterArg = newFor.getRegionIterArg(idx); + tileHandles[newIterArg] = newIterArg; + tileHandles[newFor.getResult(idx)] = newFor.getResult(idx); + } + + forOp.erase(); + changed = true; } return changed; From abad9494e25efddb3fa2c422bc1e669dddf592dd Mon Sep 17 00:00:00 2001 From: TaoTao-real Date: Thu, 11 Jun 2026 20:04:50 +0800 Subject: [PATCH 06/51] test: isolate tile materialization IR check --- test/lit/pto/materialize_tile_handles_control_flow_result.pto | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/lit/pto/materialize_tile_handles_control_flow_result.pto b/test/lit/pto/materialize_tile_handles_control_flow_result.pto index 4fba96e31b..7a4b313548 100644 --- a/test/lit/pto/materialize_tile_handles_control_flow_result.pto +++ b/test/lit/pto/materialize_tile_handles_control_flow_result.pto @@ -1,4 +1,4 @@ -// RUN: ptoas --pto-arch=a3 --mlir-print-ir-after=pto-materialize-tile-handles %s -o /dev/null 2>&1 | FileCheck %s --check-prefix=IR +// RUN: ptoas --pto-arch=a3 --stop-after=pto-materialize-tile-handles --mlir-print-ir-after=pto-materialize-tile-handles %s -o /dev/null 2>&1 | FileCheck %s --check-prefix=IR // RUN: ptoas --pto-arch=a3 %s -o - 2>&1 | FileCheck %s --check-prefix=EMITC module { From 11852c6a2c03d8d83c248e7dfbe4af80a85f9ec5 Mon Sep 17 00:00:00 2001 From: TaoTao-real Date: Thu, 11 Jun 2026 20:45:38 +0800 Subject: [PATCH 07/51] fix: split PTOAS seam and EmitC pipelines --- ...alize_tile_handles_control_flow_result.pto | 2 +- tools/ptoas/ptoas.cpp | 40 ++++++++++++++----- 2 files changed, 32 insertions(+), 10 deletions(-) diff --git a/test/lit/pto/materialize_tile_handles_control_flow_result.pto b/test/lit/pto/materialize_tile_handles_control_flow_result.pto index 7a4b313548..ee015bdba1 100644 --- a/test/lit/pto/materialize_tile_handles_control_flow_result.pto +++ b/test/lit/pto/materialize_tile_handles_control_flow_result.pto @@ -1,4 +1,4 @@ -// RUN: ptoas --pto-arch=a3 --stop-after=pto-materialize-tile-handles --mlir-print-ir-after=pto-materialize-tile-handles %s -o /dev/null 2>&1 | FileCheck %s --check-prefix=IR +// RUN: ptoas --pto-arch=a3 --pto-print-seam-ir %s -o /dev/null 2>&1 | FileCheck %s --check-prefix=IR // RUN: ptoas --pto-arch=a3 %s -o - 2>&1 | FileCheck %s --check-prefix=EMITC module { diff --git a/tools/ptoas/ptoas.cpp b/tools/ptoas/ptoas.cpp index 38589c693d..635b202a52 100644 --- a/tools/ptoas/ptoas.cpp +++ b/tools/ptoas/ptoas.cpp @@ -586,6 +586,11 @@ static LogicalResult emitSharedPreBackendSeamIR(ModuleOp module, return success(); } +static void printSharedPreBackendSeamIR(ModuleOp module) { + module->print(llvm::errs()); + llvm::errs() << "\n"; +} + static bool hasUnexpandedTileOps(ModuleOp module) { bool found = false; module.walk([&](Operation *op) { @@ -1942,10 +1947,8 @@ int mlir::pto::compilePTOASModule( return 1; } - if (ptoPrintSeamIR) { - module->print(llvm::errs()); - llvm::errs() << "\n"; - } + if (ptoPrintSeamIR) + printSharedPreBackendSeamIR(*module); if (failed(emitSharedPreBackendSeamIR(*module, ptoSeamIRFile))) return 1; @@ -1955,15 +1958,34 @@ int mlir::pto::compilePTOASModule( context.getCANNVersionOrDefault()); } + if (failed(pm.run(*module))) { + llvm::errs() << "Error: Pass execution failed.\n"; + return 1; + } + + if (ptoPrintSeamIR) + printSharedPreBackendSeamIR(*module); + if (failed(emitSharedPreBackendSeamIR(*module, ptoSeamIRFile))) + return 1; + if (ptoPrintSeamIR || !ptoSeamIRFile.empty()) { + result.kind = PTOASCompileResultKind::Text; + return 0; + } + + PassManager emitcPM(module->getContext()); + emitcPM.enableVerifier(); if (arch == "a3") { - pm.addPass(pto::createEmitPTOManualPass(pto::PTOArch::A3)); + emitcPM.addPass(pto::createEmitPTOManualPass(pto::PTOArch::A3)); } else { - pm.addPass(pto::createEmitPTOManualPass(pto::PTOArch::A5)); + emitcPM.addPass(pto::createEmitPTOManualPass(pto::PTOArch::A5)); } - pm.addPass(emitc::createFormExpressionsPass()); - pm.addPass(mlir::createCSEPass()); + emitcPM.addPass(emitc::createFormExpressionsPass()); + emitcPM.addPass(mlir::createCSEPass()); + if (failed(applyConfiguredPassManagerCLOptions( + emitcPM, "EmitC backend pipeline"))) + return 1; - if (failed(pm.run(*module))) { + if (failed(emitcPM.run(*module))) { llvm::errs() << "Error: Pass execution failed.\n"; return 1; } From 229492f324437f03b46a723d5ad8582e082953bd Mon Sep 17 00:00:00 2001 From: TaoTao-real Date: Thu, 11 Jun 2026 22:42:35 +0800 Subject: [PATCH 08/51] fix: support LLVM 21 VPTO and EmitC lowering --- lib/PTO/Transforms/PTOToEmitC.cpp | 15 ++ lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp | 201 +++++++++++++++-- lib/PTO/Transforms/VPTOLLVMEmitter.cpp | 203 ++++++++++++++++-- tools/ptoas/ObjectEmission.cpp | 115 ++++++++-- 4 files changed, 478 insertions(+), 56 deletions(-) diff --git a/lib/PTO/Transforms/PTOToEmitC.cpp b/lib/PTO/Transforms/PTOToEmitC.cpp index 007ce9d69f..208ba76392 100644 --- a/lib/PTO/Transforms/PTOToEmitC.cpp +++ b/lib/PTO/Transforms/PTOToEmitC.cpp @@ -14105,6 +14105,10 @@ static AICORE inline void ptoas_auto_sync_tail( StringRef value = opaqueTy.getValue(); return value.contains("Tile<") || value.contains("ConvTile<"); }; + auto isLoweredIndexType = [](Type ty) { + auto opaqueTy = dyn_cast(ty); + return opaqueTy && opaqueTy.getValue() == "int64_t"; + }; llvm::SmallVector castsToErase; bool castCleanupFailed = false; @@ -14134,6 +14138,17 @@ static AICORE inline void ptoas_auto_sync_tail( return; } + // IndexType is lowered to int64_t for EmitC. SCF structural conversion + // can still materialize temporary index<->int64_t bridges; keeping them + // as emitc.cast leaves illegal index-typed EmitC IR for LLVM 21's C++ + // emitter, so fold the bridge back to the lowered value. + if ((isa(inTy) && isLoweredIndexType(outTy)) || + (isLoweredIndexType(inTy) && isa(outTy))) { + output.replaceAllUsesWith(input); + castsToErase.push_back(cast); + return; + } + // SCF/CFG type conversion can transiently materialize pointer->memref // bridge casts. At this stage, the producing value is already in the // lowered EmitC pointer form; keep it and drop the bridge cast. diff --git a/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp b/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp index fc57a57622..1ed2263f99 100644 --- a/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp +++ b/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp @@ -192,7 +192,13 @@ static unsigned getNaturalByteAlignment(Type type) { } static bool hasVPTOConvertibleType(Type type) { - return isa(type); + if (isa(type)) + return true; + if (pto::isPTOLowPrecisionType(type)) + return true; + if (Type elementType = getElementTypeFromVectorLike(type)) + return pto::isPTOLowPrecisionType(elementType); + return false; } static bool hasVPTOConvertibleType(TypeRange types) { @@ -604,6 +610,55 @@ static Value castFromPayloadABI( return rewriter.create(loc, convertedType, value); } +static Type getPackedLowpScalarMemoryType(Type semanticType, + MLIRContext *context) { + if (pto::isPTOHiFloat8x2Type(semanticType)) + return IntegerType::get(context, 16); + + auto vecType = dyn_cast(semanticType); + if (!vecType || vecType.getRank() != 1 || vecType.getDimSize(0) != 2 || + llvm::is_contained(vecType.getScalableDims(), true)) + return {}; + if (!isLowpPayloadABIElementType(vecType.getElementType())) + return {}; + return IntegerType::get(context, 16); +} + +static Type getScalarAccessGEPElementType(Type semanticType, + Builder &builder) { + if (Type memoryType = + getPackedLowpScalarMemoryType(semanticType, builder.getContext())) + return memoryType; + return normalizeGEPElementTypeForLLVMLowering(semanticType, builder); +} + +static Type getScalarAccessLoadStoreType(Type semanticType, + Type convertedType, + MLIRContext *context) { + if (Type memoryType = getPackedLowpScalarMemoryType(semanticType, context)) + return memoryType; + return convertedType; +} + +static Value castToScalarAccessMemoryType(Location loc, Value value, + Type semanticType, + ConversionPatternRewriter &rewriter) { + Type memoryType = + getPackedLowpScalarMemoryType(semanticType, rewriter.getContext()); + if (!memoryType || memoryType == value.getType()) + return value; + return rewriter.create(loc, memoryType, value); +} + +static Value castFromScalarAccessMemoryType( + Location loc, Value value, Type semanticType, Type convertedType, + ConversionPatternRewriter &rewriter) { + if (!getPackedLowpScalarMemoryType(semanticType, rewriter.getContext()) || + value.getType() == convertedType) + return value; + return rewriter.create(loc, convertedType, value); +} + static std::string getAtomicElementTypeFragment(Type type, Attribute signednessAttr) { if (auto vecType = dyn_cast(type)) { @@ -9434,6 +9489,8 @@ class ConvertPtoLoadScalarOp final if (!convertedValueType) return rewriter.notifyMatchFailure(op, "could not convert load_scalar result type"); + Type loadValueType = getScalarAccessLoadStoreType( + op.getValue().getType(), convertedValueType, rewriter.getContext()); Value offset = adaptor.getOffset(); if (offset.getType().isIndex()) @@ -9443,19 +9500,42 @@ class ConvertPtoLoadScalarOp final Value elemPtr = adaptor.getPtr(); if (!matchPattern(offset, m_Zero())) { elemPtr = rewriter.create(op.getLoc(), llvmPtrType, - normalizeGEPElementTypeForLLVMLowering( - convertedValueType, rewriter), + getScalarAccessGEPElementType( + op.getValue().getType(), + rewriter), adaptor.getPtr(), ValueRange{offset}); } - rewriter.replaceOpWithNewOp( - op, convertedValueType, elemPtr, - getNaturalByteAlignment(convertedValueType)); + auto loaded = rewriter.create( + op.getLoc(), loadValueType, elemPtr, + getNaturalByteAlignment(loadValueType)); + Value result = castFromScalarAccessMemoryType( + op.getLoc(), loaded.getResult(), op.getValue().getType(), + convertedValueType, rewriter); + rewriter.replaceOp(op, result); return success(); } }; +static FailureOr recoverConvertedValue(Value value, Type sourceType, + const TypeConverter &converter) { + Type convertedType = converter.convertType(sourceType); + if (!convertedType) + return failure(); + + for (unsigned depth = 0; depth < 4; ++depth) { + if (value.getType() == convertedType) + return value; + auto castOp = value.getDefiningOp(); + if (!castOp || castOp->getNumOperands() != 1 || + castOp->getNumResults() != 1) + break; + value = castOp.getOperand(0); + } + return failure(); +} + class ConvertPtoStoreScalarOp final : public OpConversionPattern { public: @@ -9476,14 +9556,22 @@ class ConvertPtoStoreScalarOp final Value elemPtr = adaptor.getPtr(); if (!matchPattern(offset, m_Zero())) { elemPtr = rewriter.create(op.getLoc(), llvmPtrType, - normalizeGEPElementTypeForLLVMLowering( - adaptor.getValue().getType(), + getScalarAccessGEPElementType( + op.getValue().getType(), rewriter), adaptor.getPtr(), ValueRange{offset}); } - rewriter.create(op.getLoc(), adaptor.getValue(), elemPtr, - getNaturalByteAlignment(adaptor.getValue().getType())); + FailureOr value = recoverConvertedValue( + adaptor.getValue(), op.getValue().getType(), *getTypeConverter()); + if (failed(value)) + return rewriter.notifyMatchFailure(op, "could not convert store value"); + + Value storedValue = castToScalarAccessMemoryType( + op.getLoc(), *value, op.getValue().getType(), rewriter); + rewriter.create( + op.getLoc(), storedValue, elemPtr, + getNaturalByteAlignment(storedValue.getType())); rewriter.eraseOp(op); return success(); } @@ -9506,6 +9594,8 @@ class ConvertPtoLoadOp final : public OpConversionPattern { getTypeConverter()->convertType(op.getValue().getType()); if (!convertedValueType) return rewriter.notifyMatchFailure(op, "could not convert load result type"); + Type loadValueType = getScalarAccessLoadStoreType( + op.getValue().getType(), convertedValueType, rewriter.getContext()); Value offset = adaptor.getOffset(); if (offset.getType().isIndex()) @@ -9515,14 +9605,20 @@ class ConvertPtoLoadOp final : public OpConversionPattern { Value elemPtr = adaptor.getPtr(); if (!matchPattern(offset, m_Zero())) { elemPtr = rewriter.create(op.getLoc(), llvmPtrType, - convertedValueType, + getScalarAccessGEPElementType( + op.getValue().getType(), + rewriter), adaptor.getPtr(), ValueRange{offset}); } - rewriter.replaceOpWithNewOp( - op, convertedValueType, elemPtr, - getNaturalByteAlignment(convertedValueType)); + auto loaded = rewriter.create( + op.getLoc(), loadValueType, elemPtr, + getNaturalByteAlignment(loadValueType)); + Value result = castFromScalarAccessMemoryType( + op.getLoc(), loaded.getResult(), op.getValue().getType(), + convertedValueType, rewriter); + rewriter.replaceOp(op, result); return success(); } }; @@ -9604,7 +9700,9 @@ class ConvertPtoLdgOp final : public OpConversionPattern { ValueRange{offset}); } - auto ptrTy = cast(op.getPtr().getType()); + auto ptrTy = dyn_cast(op.getPtr().getType()); + if (!ptrTy) + return rewriter.notifyMatchFailure(op, "expected PTO pointer source type"); FailureOr ptr = reinterpretPointerToAddrSpace( op, elemPtr, static_cast(ptrTy.getMemorySpace().getAddressSpace())); @@ -9665,13 +9763,22 @@ class ConvertPtoStoreOp final : public OpConversionPattern { Value elemPtr = adaptor.getPtr(); if (!matchPattern(offset, m_Zero())) { elemPtr = rewriter.create(op.getLoc(), llvmPtrType, - adaptor.getValue().getType(), + getScalarAccessGEPElementType( + op.getValue().getType(), + rewriter), adaptor.getPtr(), ValueRange{offset}); } + FailureOr value = recoverConvertedValue( + adaptor.getValue(), op.getValue().getType(), *getTypeConverter()); + if (failed(value)) + return rewriter.notifyMatchFailure(op, "could not convert store value"); + + Value storedValue = castToScalarAccessMemoryType( + op.getLoc(), *value, op.getValue().getType(), rewriter); rewriter.replaceOpWithNewOp( - op, adaptor.getValue(), elemPtr, - getNaturalByteAlignment(adaptor.getValue().getType())); + op, storedValue, elemPtr, + getNaturalByteAlignment(storedValue.getType())); return success(); } }; @@ -9729,7 +9836,9 @@ class ConvertPtoStgOp final : public OpConversionPattern { adaptor.getPtr(), ValueRange{offset}); } - auto ptrTy = cast(op.getPtr().getType()); + auto ptrTy = dyn_cast(op.getPtr().getType()); + if (!ptrTy) + return rewriter.notifyMatchFailure(op, "expected PTO pointer source type"); FailureOr ptr = reinterpretPointerToAddrSpace( op, elemPtr, static_cast(ptrTy.getMemorySpace().getAddressSpace())); @@ -9749,8 +9858,12 @@ class ConvertPtoStgOp final : public OpConversionPattern { : pto::StL2Cache::NMFV; Value modeValue = getI32Constant(rewriter, op.getLoc(), static_cast(mode)); + FailureOr convertedValue = recoverConvertedValue( + adaptor.getValue(), op.getValue().getType(), *getTypeConverter()); + if (failed(convertedValue)) + return rewriter.notifyMatchFailure(op, "could not convert stg value"); Value storedValue = convertStgValue(op.getLoc(), op.getValue().getType(), - adaptor.getValue(), rewriter); + *convertedValue, rewriter); auto funcType = rewriter.getFunctionType(TypeRange{ptr->getType(), storedValue.getType(), rewriter.getI32Type()}, @@ -9775,7 +9888,9 @@ class ConvertVPTOTypedCarrierOp final : public ConversionPattern { LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - if (isa(op)) + if (isa(op)) return failure(); if (!hasVPTOConvertibleType(op->getOperandTypes()) && !hasVPTOConvertibleType(op->getResultTypes())) @@ -10234,6 +10349,49 @@ static Type normalizeTypeForOfficialLLVMLowering(Type type, Builder &builder) { return type; } +static bool isI8VectorToLowpVectorMaterialization( + UnrealizedConversionCastOp castOp) { + if (castOp->getNumOperands() != 1 || castOp->getNumResults() != 1) + return false; + + auto sourceVec = dyn_cast(castOp.getOperand(0).getType()); + auto resultVec = dyn_cast(castOp.getResult(0).getType()); + if (!sourceVec || !resultVec || sourceVec.getShape() != resultVec.getShape() || + sourceVec.getScalableDims() != resultVec.getScalableDims()) + return false; + + auto sourceElement = dyn_cast(sourceVec.getElementType()); + return sourceElement && sourceElement.getWidth() == 8 && + pto::isPTOLowPrecisionType(resultVec.getElementType()); +} + +static void foldLowpVectorMaterializationCastsForLLVMExport(ModuleOp module) { + SmallVector casts; + module.walk([&](UnrealizedConversionCastOp castOp) { + if (isI8VectorToLowpVectorMaterialization(castOp)) + casts.push_back(castOp); + }); + + for (UnrealizedConversionCastOp castOp : casts) { + if (!castOp) + continue; + SmallVector users(castOp->getUsers()); + for (Operation *user : users) { + auto bitcastOp = dyn_cast(user); + if (!bitcastOp) + continue; + OpBuilder builder(bitcastOp); + Value replacement = builder.create( + bitcastOp.getLoc(), bitcastOp.getResult().getType(), + castOp.getOperand(0)); + bitcastOp.getResult().replaceAllUsesWith(replacement); + bitcastOp.erase(); + } + if (castOp->use_empty()) + castOp.erase(); + } +} + static void normalizeFuncSignaturesForOfficialLLVMLowering(ModuleOp module) { Builder builder(module.getContext()); @@ -10551,6 +10709,7 @@ static LogicalResult runPipeline(ModuleOp module, llvm::raw_ostream &diagOS, diagOS << "VPTO LLVM emission failed: official lowering pipeline failed\n"; return failure(); } + foldLowpVectorMaterializationCastsForLLVMExport(clonedModule); return emit(clonedModule); } diff --git a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp index bd432c62b4..fbd0298c6a 100644 --- a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp +++ b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp @@ -193,7 +193,13 @@ static unsigned getNaturalByteAlignment(Type type) { } static bool hasVPTOConvertibleType(Type type) { - return isa(type); + if (isa(type)) + return true; + if (pto::isPTOLowPrecisionType(type)) + return true; + if (Type elementType = getElementTypeFromVectorLike(type)) + return pto::isPTOLowPrecisionType(elementType); + return false; } static bool hasVPTOConvertibleType(TypeRange types) { @@ -559,6 +565,55 @@ static Value castFromPayloadABI( return rewriter.create(loc, convertedType, value); } +static Type getPackedLowpScalarMemoryType(Type semanticType, + MLIRContext *context) { + if (pto::isPTOHiFloat8x2Type(semanticType)) + return IntegerType::get(context, 16); + + auto vecType = dyn_cast(semanticType); + if (!vecType || vecType.getRank() != 1 || vecType.getDimSize(0) != 2 || + llvm::is_contained(vecType.getScalableDims(), true)) + return {}; + if (!isLowpPayloadABIElementType(vecType.getElementType())) + return {}; + return IntegerType::get(context, 16); +} + +static Type getScalarAccessGEPElementType(Type semanticType, + Builder &builder) { + if (Type memoryType = + getPackedLowpScalarMemoryType(semanticType, builder.getContext())) + return memoryType; + return normalizeGEPElementTypeForLLVMLowering(semanticType, builder); +} + +static Type getScalarAccessLoadStoreType(Type semanticType, + Type convertedType, + MLIRContext *context) { + if (Type memoryType = getPackedLowpScalarMemoryType(semanticType, context)) + return memoryType; + return convertedType; +} + +static Value castToScalarAccessMemoryType(Location loc, Value value, + Type semanticType, + ConversionPatternRewriter &rewriter) { + Type memoryType = + getPackedLowpScalarMemoryType(semanticType, rewriter.getContext()); + if (!memoryType || memoryType == value.getType()) + return value; + return rewriter.create(loc, memoryType, value); +} + +static Value castFromScalarAccessMemoryType( + Location loc, Value value, Type semanticType, Type convertedType, + ConversionPatternRewriter &rewriter) { + if (!getPackedLowpScalarMemoryType(semanticType, rewriter.getContext()) || + value.getType() == convertedType) + return value; + return rewriter.create(loc, convertedType, value); +} + static std::string getAtomicElementTypeFragment(Type type, Attribute signednessAttr) { if (auto vecType = dyn_cast(type)) { @@ -9378,6 +9433,8 @@ class ConvertPtoLoadScalarOp final if (!convertedValueType) return rewriter.notifyMatchFailure(op, "could not convert load_scalar result type"); + Type loadValueType = getScalarAccessLoadStoreType( + op.getValue().getType(), convertedValueType, rewriter.getContext()); Value offset = adaptor.getOffset(); if (offset.getType().isIndex()) @@ -9387,19 +9444,42 @@ class ConvertPtoLoadScalarOp final Value elemPtr = adaptor.getPtr(); if (!matchPattern(offset, m_Zero())) { elemPtr = rewriter.create(op.getLoc(), llvmPtrType, - normalizeGEPElementTypeForLLVMLowering( - convertedValueType, rewriter), + getScalarAccessGEPElementType( + op.getValue().getType(), + rewriter), adaptor.getPtr(), ValueRange{offset}); } - rewriter.replaceOpWithNewOp( - op, convertedValueType, elemPtr, - getNaturalByteAlignment(convertedValueType)); + auto loaded = rewriter.create( + op.getLoc(), loadValueType, elemPtr, + getNaturalByteAlignment(loadValueType)); + Value result = castFromScalarAccessMemoryType( + op.getLoc(), loaded.getResult(), op.getValue().getType(), + convertedValueType, rewriter); + rewriter.replaceOp(op, result); return success(); } }; +static FailureOr recoverConvertedValue(Value value, Type sourceType, + const TypeConverter &converter) { + Type convertedType = converter.convertType(sourceType); + if (!convertedType) + return failure(); + + for (unsigned depth = 0; depth < 4; ++depth) { + if (value.getType() == convertedType) + return value; + auto castOp = value.getDefiningOp(); + if (!castOp || castOp->getNumOperands() != 1 || + castOp->getNumResults() != 1) + break; + value = castOp.getOperand(0); + } + return failure(); +} + class ConvertPtoStoreScalarOp final : public OpConversionPattern { public: @@ -9420,14 +9500,22 @@ class ConvertPtoStoreScalarOp final Value elemPtr = adaptor.getPtr(); if (!matchPattern(offset, m_Zero())) { elemPtr = rewriter.create(op.getLoc(), llvmPtrType, - normalizeGEPElementTypeForLLVMLowering( - adaptor.getValue().getType(), + getScalarAccessGEPElementType( + op.getValue().getType(), rewriter), adaptor.getPtr(), ValueRange{offset}); } - rewriter.create(op.getLoc(), adaptor.getValue(), elemPtr, - getNaturalByteAlignment(adaptor.getValue().getType())); + FailureOr value = recoverConvertedValue( + adaptor.getValue(), op.getValue().getType(), *getTypeConverter()); + if (failed(value)) + return rewriter.notifyMatchFailure(op, "could not convert store value"); + + Value storedValue = castToScalarAccessMemoryType( + op.getLoc(), *value, op.getValue().getType(), rewriter); + rewriter.create( + op.getLoc(), storedValue, elemPtr, + getNaturalByteAlignment(storedValue.getType())); rewriter.eraseOp(op); return success(); } @@ -9450,6 +9538,8 @@ class ConvertPtoLoadOp final : public OpConversionPattern { getTypeConverter()->convertType(op.getValue().getType()); if (!convertedValueType) return rewriter.notifyMatchFailure(op, "could not convert load result type"); + Type loadValueType = getScalarAccessLoadStoreType( + op.getValue().getType(), convertedValueType, rewriter.getContext()); Value offset = adaptor.getOffset(); if (offset.getType().isIndex()) @@ -9459,14 +9549,20 @@ class ConvertPtoLoadOp final : public OpConversionPattern { Value elemPtr = adaptor.getPtr(); if (!matchPattern(offset, m_Zero())) { elemPtr = rewriter.create(op.getLoc(), llvmPtrType, - convertedValueType, + getScalarAccessGEPElementType( + op.getValue().getType(), + rewriter), adaptor.getPtr(), ValueRange{offset}); } - rewriter.replaceOpWithNewOp( - op, convertedValueType, elemPtr, - getNaturalByteAlignment(convertedValueType)); + auto loaded = rewriter.create( + op.getLoc(), loadValueType, elemPtr, + getNaturalByteAlignment(loadValueType)); + Value result = castFromScalarAccessMemoryType( + op.getLoc(), loaded.getResult(), op.getValue().getType(), + convertedValueType, rewriter); + rewriter.replaceOp(op, result); return success(); } @@ -9549,7 +9645,9 @@ class ConvertPtoLdgOp final : public OpConversionPattern { ValueRange{offset}); } - auto ptrTy = cast(op.getPtr().getType()); + auto ptrTy = dyn_cast(op.getPtr().getType()); + if (!ptrTy) + return rewriter.notifyMatchFailure(op, "expected PTO pointer source type"); FailureOr ptr = reinterpretPointerToAddrSpace( op, elemPtr, static_cast(ptrTy.getMemorySpace().getAddressSpace())); @@ -9610,13 +9708,22 @@ class ConvertPtoStoreOp final : public OpConversionPattern { Value elemPtr = adaptor.getPtr(); if (!matchPattern(offset, m_Zero())) { elemPtr = rewriter.create(op.getLoc(), llvmPtrType, - adaptor.getValue().getType(), + getScalarAccessGEPElementType( + op.getValue().getType(), + rewriter), adaptor.getPtr(), ValueRange{offset}); } + FailureOr value = recoverConvertedValue( + adaptor.getValue(), op.getValue().getType(), *getTypeConverter()); + if (failed(value)) + return rewriter.notifyMatchFailure(op, "could not convert store value"); + + Value storedValue = castToScalarAccessMemoryType( + op.getLoc(), *value, op.getValue().getType(), rewriter); rewriter.replaceOpWithNewOp( - op, adaptor.getValue(), elemPtr, - getNaturalByteAlignment(adaptor.getValue().getType())); + op, storedValue, elemPtr, + getNaturalByteAlignment(storedValue.getType())); return success(); } @@ -9670,12 +9777,14 @@ class ConvertPtoStgOp final : public OpConversionPattern { if (!matchPattern(offset, m_Zero())) { elemPtr = rewriter.create(op.getLoc(), llvmPtrType, normalizeGEPElementTypeForLLVMLowering( - adaptor.getValue().getType(), + op.getValue().getType(), rewriter), adaptor.getPtr(), ValueRange{offset}); } - auto ptrTy = cast(op.getPtr().getType()); + auto ptrTy = dyn_cast(op.getPtr().getType()); + if (!ptrTy) + return rewriter.notifyMatchFailure(op, "expected PTO pointer source type"); FailureOr ptr = reinterpretPointerToAddrSpace( op, elemPtr, static_cast(ptrTy.getMemorySpace().getAddressSpace())); @@ -9695,8 +9804,12 @@ class ConvertPtoStgOp final : public OpConversionPattern { : pto::StL2Cache::NMFV; Value modeValue = getI32Constant(rewriter, op.getLoc(), static_cast(mode)); + FailureOr convertedValue = recoverConvertedValue( + adaptor.getValue(), op.getValue().getType(), *getTypeConverter()); + if (failed(convertedValue)) + return rewriter.notifyMatchFailure(op, "could not convert stg value"); Value storedValue = convertStgValue(op.getLoc(), op.getValue().getType(), - adaptor.getValue(), rewriter); + *convertedValue, rewriter); auto funcType = rewriter.getFunctionType(TypeRange{ptr->getType(), storedValue.getType(), rewriter.getI32Type()}, @@ -9721,7 +9834,9 @@ class ConvertVPTOTypedCarrierOp final : public ConversionPattern { LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - if (isa(op)) + if (isa(op)) return failure(); if (!hasVPTOConvertibleType(op->getOperandTypes()) && !hasVPTOConvertibleType(op->getResultTypes())) @@ -10180,6 +10295,49 @@ static Type normalizeTypeForOfficialLLVMLowering(Type type, Builder &builder) { return type; } +static bool isI8VectorToLowpVectorMaterialization( + UnrealizedConversionCastOp castOp) { + if (castOp->getNumOperands() != 1 || castOp->getNumResults() != 1) + return false; + + auto sourceVec = dyn_cast(castOp.getOperand(0).getType()); + auto resultVec = dyn_cast(castOp.getResult(0).getType()); + if (!sourceVec || !resultVec || sourceVec.getShape() != resultVec.getShape() || + sourceVec.getScalableDims() != resultVec.getScalableDims()) + return false; + + auto sourceElement = dyn_cast(sourceVec.getElementType()); + return sourceElement && sourceElement.getWidth() == 8 && + pto::isPTOLowPrecisionType(resultVec.getElementType()); +} + +static void foldLowpVectorMaterializationCastsForLLVMExport(ModuleOp module) { + SmallVector casts; + module.walk([&](UnrealizedConversionCastOp castOp) { + if (isI8VectorToLowpVectorMaterialization(castOp)) + casts.push_back(castOp); + }); + + for (UnrealizedConversionCastOp castOp : casts) { + if (!castOp) + continue; + SmallVector users(castOp->getUsers()); + for (Operation *user : users) { + auto bitcastOp = dyn_cast(user); + if (!bitcastOp) + continue; + OpBuilder builder(bitcastOp); + Value replacement = builder.create( + bitcastOp.getLoc(), bitcastOp.getResult().getType(), + castOp.getOperand(0)); + bitcastOp.getResult().replaceAllUsesWith(replacement); + bitcastOp.erase(); + } + if (castOp->use_empty()) + castOp.erase(); + } +} + static void normalizeFuncSignaturesForOfficialLLVMLowering(ModuleOp module) { Builder builder(module.getContext()); @@ -10512,6 +10670,7 @@ static LogicalResult runPipeline(ModuleOp module, llvm::raw_ostream &diagOS, diagOS << "VPTO LLVM emission failed: official lowering pipeline failed\n"; return failure(); } + foldLowpVectorMaterializationCastsForLLVMExport(clonedModule); return emit(clonedModule); } diff --git a/tools/ptoas/ObjectEmission.cpp b/tools/ptoas/ObjectEmission.cpp index cd85645cd6..41230335e5 100644 --- a/tools/ptoas/ObjectEmission.cpp +++ b/tools/ptoas/ObjectEmission.cpp @@ -72,24 +72,113 @@ static void stripUnsupportedBishengAttrs(llvm::Module &module) { } } +static std::optional findVectorTypeStart(StringRef text, + size_t typeEnd) { + unsigned depth = 0; + for (size_t index = typeEnd; index > 0; --index) { + char c = text[index - 1]; + if (c == '\n' || c == '\r') + return std::nullopt; + if (c == '>') + ++depth; + else if (c == '<') { + if (depth == 0) + return std::nullopt; + --depth; + if (depth == 0) + return index - 1; + } + } + return std::nullopt; +} + +static std::optional getFixedVectorElementCount(StringRef vectorType) { + vectorType = vectorType.trim(); + if (!vectorType.consume_front("<") || !vectorType.consume_back(">")) + return std::nullopt; + vectorType = vectorType.trim(); + if (vectorType.starts_with("vscale")) + return std::nullopt; + + auto [countText, elementType] = vectorType.split(" x "); + unsigned count = 0; + if (countText.empty() || elementType.empty() || + countText.getAsInteger(10, count) || count == 0) + return std::nullopt; + return count; +} + +static std::string expandFixedVectorSplatConstants(StringRef input) { + constexpr StringRef marker = "splat ("; + constexpr unsigned maxExpandedElements = 4096; + + std::string output; + size_t cursor = 0; + size_t searchFrom = 0; + + while (true) { + size_t splatPos = input.find(marker, searchFrom); + if (splatPos == StringRef::npos) + break; + + size_t typeEnd = splatPos; + while (typeEnd > 0 && + std::isspace(static_cast(input[typeEnd - 1]))) + --typeEnd; + if (typeEnd == 0 || input[typeEnd - 1] != '>') { + searchFrom = splatPos + marker.size(); + continue; + } + + std::optional typeStart = findVectorTypeStart(input, typeEnd); + if (!typeStart) { + searchFrom = splatPos + marker.size(); + continue; + } + + std::optional elementCount = + getFixedVectorElementCount(input.slice(*typeStart, typeEnd)); + if (!elementCount || *elementCount > maxExpandedElements) { + searchFrom = splatPos + marker.size(); + continue; + } + + size_t valueStart = splatPos + marker.size(); + size_t valueEnd = input.find(')', valueStart); + size_t lineEnd = input.find('\n', valueStart); + if (valueEnd == StringRef::npos || + (lineEnd != StringRef::npos && valueEnd > lineEnd)) { + searchFrom = splatPos + marker.size(); + continue; + } + + StringRef element = input.slice(valueStart, valueEnd); + output.append(input.data() + cursor, typeEnd - cursor); + output.append(" <"); + for (unsigned index = 0; index < *elementCount; ++index) { + if (index != 0) + output.append(", "); + output.append(element.data(), element.size()); + } + output.append(">"); + + cursor = valueEnd + 1; + searchFrom = cursor; + } + + output.append(input.data() + cursor, input.size() - cursor); + return output; +} + static bool writeLLVMModuleFile(llvm::Module &module, StringRef path, llvm::raw_ostream &diagOS) { - std::error_code ec; - llvm::raw_fd_ostream os(path, ec, llvm::sys::fs::OF_Text); - if (ec) { - diagOS << "Error: failed to open " << path << " for write: " - << ec.message() << "\n"; - return false; - } stripUnsupportedBishengAttrs(module); + std::string llvmIR; + llvm::raw_string_ostream os(llvmIR); module.print(os, nullptr); os.flush(); - if (os.has_error()) { - diagOS << "Error: failed to write LLVM module to " << path << "\n"; - os.clear_error(); - return false; - } - return true; + llvmIR = expandFixedVectorSplatConstants(llvmIR); + return writeTextFile(path, llvmIR, diagOS); } static std::string sanitizeModuleId(llvm::StringRef raw) { From 0e8ff202a9587a50b4bb9c171663c771cb6a9791 Mon Sep 17 00:00:00 2001 From: TaoTao-real Date: Thu, 11 Jun 2026 23:36:47 +0800 Subject: [PATCH 09/51] fix: avoid LLVM 21 EmitC integer attr assertion --- tools/ptoas/ptoas.cpp | 174 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 174 insertions(+) diff --git a/tools/ptoas/ptoas.cpp b/tools/ptoas/ptoas.cpp index 635b202a52..545909cd83 100644 --- a/tools/ptoas/ptoas.cpp +++ b/tools/ptoas/ptoas.cpp @@ -834,6 +834,179 @@ static void dropEmptyEmitCExpressions(Operation *rootOp) { expr.erase(); } +static void appendEmitCIntegerAttrLiteral(std::string &storage, + const APInt &value, bool isUnsigned) { + if (value.getBitWidth() == 0) { + storage.append("0"); + return; + } + if (value.getBitWidth() == 1) { + storage.append(value.getBoolValue() ? "true" : "false"); + return; + } + + SmallString<128> strValue; + value.toString(strValue, 10, !isUnsigned, false); + storage.append(strValue.data(), strValue.size()); +} + +static bool shouldPrintEmitCIntegerAttrAsUnsigned(IntegerAttr attr) { + auto intTy = dyn_cast(attr.getType()); + return intTy && intTy.getSignedness() == IntegerType::Unsigned; +} + +static std::string getEmitCIntegerAttrLiteral(IntegerAttr attr) { + std::string literal; + appendEmitCIntegerAttrLiteral(literal, attr.getValue(), + shouldPrintEmitCIntegerAttrAsUnsigned(attr)); + return literal; +} + +static std::optional +getEmitCDenseIntElementsAttrLiteral(DenseIntElementsAttr attr) { + auto tensorTy = dyn_cast(attr.getType()); + if (!tensorTy) + return std::nullopt; + + Type elementType = tensorTy.getElementType(); + bool isUnsigned = false; + if (auto intTy = dyn_cast(elementType)) { + isUnsigned = intTy.getSignedness() == IntegerType::Unsigned; + } else if (!isa(elementType)) { + return std::nullopt; + } + + std::string literal; + literal.push_back('{'); + bool first = true; + for (const APInt &value : attr) { + if (!first) + literal.append(", "); + first = false; + appendEmitCIntegerAttrLiteral(literal, value, isUnsigned); + } + literal.push_back('}'); + return literal; +} + +static Attribute normalizeEmitCPrintedIntAttrForCppEmission(MLIRContext *ctx, + Attribute attr) { + if (auto intAttr = dyn_cast(attr)) + return emitc::OpaqueAttr::get(ctx, getEmitCIntegerAttrLiteral(intAttr)); + + if (auto denseAttr = dyn_cast(attr)) { + if (std::optional literal = + getEmitCDenseIntElementsAttrLiteral(denseAttr)) + return emitc::OpaqueAttr::get(ctx, *literal); + } + + return attr; +} + +static IntegerAttr normalizeEmitCIndexPlaceholderAttr(MLIRContext *ctx, + IntegerAttr attr) { + const APInt &value = attr.getValue(); + int64_t index = value.getBitWidth() == 0 ? 0 : value.getSExtValue(); + return IntegerAttr::get(IndexType::get(ctx), APInt(64, index)); +} + +static ArrayAttr normalizeEmitCCallArgsForCppEmission(MLIRContext *ctx, + ArrayAttr args) { + SmallVector normalized; + normalized.reserve(args.size()); + bool changed = false; + + for (Attribute attr : args) { + if (auto intAttr = dyn_cast(attr)) { + if (isa(intAttr.getType())) { + Attribute normalizedAttr = + normalizeEmitCIndexPlaceholderAttr(ctx, intAttr); + changed |= normalizedAttr != attr; + normalized.push_back(normalizedAttr); + continue; + } + + Attribute normalizedAttr = + normalizeEmitCPrintedIntAttrForCppEmission(ctx, attr); + changed |= normalizedAttr != attr; + normalized.push_back(normalizedAttr); + continue; + } + + Attribute normalizedAttr = + normalizeEmitCPrintedIntAttrForCppEmission(ctx, attr); + changed |= normalizedAttr != attr; + normalized.push_back(normalizedAttr); + } + + return changed ? ArrayAttr::get(ctx, normalized) : args; +} + +static ArrayAttr normalizeEmitCTemplateArgsForCppEmission(MLIRContext *ctx, + ArrayAttr args) { + SmallVector normalized; + normalized.reserve(args.size()); + bool changed = false; + + for (Attribute attr : args) { + Attribute normalizedAttr = + normalizeEmitCPrintedIntAttrForCppEmission(ctx, attr); + changed |= normalizedAttr != attr; + normalized.push_back(normalizedAttr); + } + + return changed ? ArrayAttr::get(ctx, normalized) : args; +} + +static void normalizeEmitCIntegerAttrsForCppEmission(Operation *rootOp) { + MLIRContext *ctx = rootOp->getContext(); + rootOp->walk([&](Operation *op) { + if (auto constant = dyn_cast(op)) { + Attribute value = constant.getValue(); + Attribute normalized = + normalizeEmitCPrintedIntAttrForCppEmission(ctx, value); + if (normalized != value) + constant.getProperties().setValue(normalized); + return; + } + + if (auto variable = dyn_cast(op)) { + Attribute value = variable.getValue(); + Attribute normalized = + normalizeEmitCPrintedIntAttrForCppEmission(ctx, value); + if (normalized != value) + variable.getProperties().setValue(normalized); + return; + } + + if (auto global = dyn_cast(op)) { + std::optional initialValue = global.getInitialValue(); + if (!initialValue) + return; + Attribute normalized = + normalizeEmitCPrintedIntAttrForCppEmission(ctx, *initialValue); + if (normalized != *initialValue) + global.getProperties().setInitialValue(normalized); + return; + } + + if (auto call = dyn_cast(op)) { + if (std::optional args = call.getArgs()) { + ArrayAttr normalized = normalizeEmitCCallArgsForCppEmission(ctx, *args); + if (normalized != *args) + call.getProperties().setArgs(normalized); + } + if (std::optional templateArgs = call.getTemplateArgs()) { + ArrayAttr normalized = + normalizeEmitCTemplateArgsForCppEmission(ctx, *templateArgs); + if (normalized != *templateArgs) + call.getProperties().setTemplateArgs(normalized); + } + return; + } + }); +} + static Attribute getDefaultEmitCVariableInitAttr(OpBuilder &builder, Type type) { if (auto intTy = dyn_cast(type)) return builder.getIntegerAttr(intTy, 0); @@ -1992,6 +2165,7 @@ int mlir::pto::compilePTOASModule( dropEmptyEmitCExpressions(module.get()); materializeControlFlowOperands(module.get()); + normalizeEmitCIntegerAttrsForCppEmission(module.get()); if (failed(reorderEmitCFunctions(module.get()))) { llvm::errs() << "Error: Failed to order emitted functions for C++ emission.\n"; return 1; From df57b7e430a442579f8a7084b455c3a5b2ae680d Mon Sep 17 00:00:00 2001 From: TaoTao-real Date: Thu, 11 Jun 2026 23:41:57 +0800 Subject: [PATCH 10/51] ci: use system compiler for VPTO sim build --- .github/workflows/ci_sim.yml | 32 ++++++++++++++++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci_sim.yml b/.github/workflows/ci_sim.yml index cd41b22ba8..835e1f6703 100644 --- a/.github/workflows/ci_sim.yml +++ b/.github/workflows/ci_sim.yml @@ -162,7 +162,7 @@ jobs: run: | set -euo pipefail missing_tools=() - for tool in python3 git cmake ninja make; do + for tool in python3 git cmake ninja make cc c++; do if ! command -v "${tool}" >/dev/null 2>&1; then missing_tools+=("${tool}") fi @@ -171,7 +171,9 @@ jobs: if [[ "${#missing_tools[@]}" -gt 0 ]]; then if command -v sudo >/dev/null 2>&1 && command -v apt-get >/dev/null 2>&1; then sudo apt-get update - sudo apt-get install -y python3 python3-pip git cmake ninja-build make + sudo apt-get install -y \ + python3 python3-pip git cmake ninja-build make \ + build-essential zlib1g-dev libzstd-dev else echo "ERROR: missing required tools on self-hosted runner: ${missing_tools[*]}" >&2 echo "ERROR: automatic installation requires sudo + apt-get" >&2 @@ -179,6 +181,11 @@ jobs: fi fi + if command -v sudo >/dev/null 2>&1 && command -v apt-get >/dev/null 2>&1; then + sudo apt-get update + sudo apt-get install -y zlib1g-dev libzstd-dev + fi + python3 -m pip --version >/dev/null 2>&1 || { if command -v sudo >/dev/null 2>&1 && command -v apt-get >/dev/null 2>&1; then sudo apt-get update @@ -201,6 +208,25 @@ jobs: python3 -m pip install setuptools wheel 'pybind11<3' nanobind numpy ml-dtypes fi + if [[ -x /usr/bin/cc ]]; then + c_compiler=/usr/bin/cc + elif [[ -n "${CC:-}" && -x "${CC}" ]]; then + c_compiler="${CC}" + else + c_compiler="$(command -v cc)" + fi + if [[ -x /usr/bin/c++ ]]; then + cxx_compiler=/usr/bin/c++ + elif [[ -n "${CXX:-}" && -x "${CXX}" ]]; then + cxx_compiler="${CXX}" + else + cxx_compiler="$(command -v c++)" + fi + echo "PTOAS_CMAKE_C_COMPILER=${c_compiler}" >> "${GITHUB_ENV}" + echo "PTOAS_CMAKE_CXX_COMPILER=${cxx_compiler}" >> "${GITHUB_ENV}" + "${c_compiler}" --version | head -n 1 + "${cxx_compiler}" --version | head -n 1 + - name: Clean CI work dirs shell: bash run: | @@ -278,6 +304,8 @@ jobs: export CXX=g++ # LLVM_BUILD_DIR is the env var read by the build backend (_ptoas_build_backend.py). LLVM_BUILD_DIR="${LLVM_DIR}" \ + CMAKE_C_COMPILER="${PTOAS_CMAKE_C_COMPILER}" \ + CMAKE_CXX_COMPILER="${PTOAS_CMAKE_CXX_COMPILER}" \ PTO_INSTALL_DIR="${PTO_INSTALL_DIR}" \ python3 -m pip install . --no-build-isolation --no-deps --ignore-installed --prefix "${PTO_INSTALL_DIR}" From 6d3447717aeaac89700d31c54b829c4020cc5633 Mon Sep 17 00:00:00 2001 From: TaoTao-real Date: Thu, 11 Jun 2026 23:56:56 +0800 Subject: [PATCH 11/51] ci: install VPTO sim runner packages on demand --- .github/workflows/ci_sim.yml | 58 +++++++++++++++++++++++------------- 1 file changed, 37 insertions(+), 21 deletions(-) diff --git a/.github/workflows/ci_sim.yml b/.github/workflows/ci_sim.yml index 835e1f6703..5067cdb577 100644 --- a/.github/workflows/ci_sim.yml +++ b/.github/workflows/ci_sim.yml @@ -161,6 +161,29 @@ jobs: shell: bash run: | set -euo pipefail + + apt_packages=() + add_apt_package() { + local package="$1" + local existing + for existing in "${apt_packages[@]}"; do + [[ "${existing}" == "${package}" ]] && return 0 + done + apt_packages+=("${package}") + } + install_apt_packages() { + if [[ "$#" -eq 0 ]]; then + return 0 + fi + if ! command -v sudo >/dev/null 2>&1 || ! command -v apt-get >/dev/null 2>&1; then + echo "ERROR: missing required tools/packages and automatic installation requires sudo + apt-get" >&2 + exit 1 + fi + sudo apt-get -o DPkg::Lock::Timeout=60 update + sudo env DEBIAN_FRONTEND=noninteractive \ + apt-get -o DPkg::Lock::Timeout=60 install -y "$@" + } + missing_tools=() for tool in python3 git cmake ninja make cc c++; do if ! command -v "${tool}" >/dev/null 2>&1; then @@ -169,31 +192,24 @@ jobs: done if [[ "${#missing_tools[@]}" -gt 0 ]]; then - if command -v sudo >/dev/null 2>&1 && command -v apt-get >/dev/null 2>&1; then - sudo apt-get update - sudo apt-get install -y \ - python3 python3-pip git cmake ninja-build make \ - build-essential zlib1g-dev libzstd-dev - else - echo "ERROR: missing required tools on self-hosted runner: ${missing_tools[*]}" >&2 - echo "ERROR: automatic installation requires sudo + apt-get" >&2 - exit 1 - fi + for tool in "${missing_tools[@]}"; do + case "${tool}" in + python3) add_apt_package python3 ;; + git) add_apt_package git ;; + cmake) add_apt_package cmake ;; + ninja) add_apt_package ninja-build ;; + make) add_apt_package make ;; + cc|c++) add_apt_package build-essential ;; + esac + done fi - if command -v sudo >/dev/null 2>&1 && command -v apt-get >/dev/null 2>&1; then - sudo apt-get update - sudo apt-get install -y zlib1g-dev libzstd-dev - fi + [[ -r /usr/include/zlib.h ]] || add_apt_package zlib1g-dev + [[ -r /usr/include/zstd.h ]] || add_apt_package libzstd-dev + install_apt_packages "${apt_packages[@]}" python3 -m pip --version >/dev/null 2>&1 || { - if command -v sudo >/dev/null 2>&1 && command -v apt-get >/dev/null 2>&1; then - sudo apt-get update - sudo apt-get install -y python3-pip - else - echo "ERROR: python3-pip is required on self-hosted runner" >&2 - exit 1 - fi + install_apt_packages python3-pip } need_pip_install=0 From 2063df8ee3cc7d8b324c3a1014e4c4999bcc3a45 Mon Sep 17 00:00:00 2001 From: TaoTao-real Date: Fri, 12 Jun 2026 00:09:23 +0800 Subject: [PATCH 12/51] ci: avoid apt installs in VPTO sim dependency check --- .github/workflows/ci_sim.yml | 55 +++++++++++------------------------- 1 file changed, 16 insertions(+), 39 deletions(-) diff --git a/.github/workflows/ci_sim.yml b/.github/workflows/ci_sim.yml index 5067cdb577..b20507ab89 100644 --- a/.github/workflows/ci_sim.yml +++ b/.github/workflows/ci_sim.yml @@ -161,55 +161,32 @@ jobs: shell: bash run: | set -euo pipefail - - apt_packages=() - add_apt_package() { - local package="$1" - local existing - for existing in "${apt_packages[@]}"; do - [[ "${existing}" == "${package}" ]] && return 0 - done - apt_packages+=("${package}") - } - install_apt_packages() { - if [[ "$#" -eq 0 ]]; then - return 0 - fi - if ! command -v sudo >/dev/null 2>&1 || ! command -v apt-get >/dev/null 2>&1; then - echo "ERROR: missing required tools/packages and automatic installation requires sudo + apt-get" >&2 - exit 1 - fi - sudo apt-get -o DPkg::Lock::Timeout=60 update - sudo env DEBIAN_FRONTEND=noninteractive \ - apt-get -o DPkg::Lock::Timeout=60 install -y "$@" - } - missing_tools=() - for tool in python3 git cmake ninja make cc c++; do + for tool in python3 git cmake ninja make; do if ! command -v "${tool}" >/dev/null 2>&1; then missing_tools+=("${tool}") fi done if [[ "${#missing_tools[@]}" -gt 0 ]]; then - for tool in "${missing_tools[@]}"; do - case "${tool}" in - python3) add_apt_package python3 ;; - git) add_apt_package git ;; - cmake) add_apt_package cmake ;; - ninja) add_apt_package ninja-build ;; - make) add_apt_package make ;; - cc|c++) add_apt_package build-essential ;; - esac - done + if command -v sudo >/dev/null 2>&1 && command -v apt-get >/dev/null 2>&1; then + sudo apt-get update + sudo apt-get install -y python3 python3-pip git cmake ninja-build make + else + echo "ERROR: missing required tools on self-hosted runner: ${missing_tools[*]}" >&2 + echo "ERROR: automatic installation requires sudo + apt-get" >&2 + exit 1 + fi fi - [[ -r /usr/include/zlib.h ]] || add_apt_package zlib1g-dev - [[ -r /usr/include/zstd.h ]] || add_apt_package libzstd-dev - install_apt_packages "${apt_packages[@]}" - python3 -m pip --version >/dev/null 2>&1 || { - install_apt_packages python3-pip + if command -v sudo >/dev/null 2>&1 && command -v apt-get >/dev/null 2>&1; then + sudo apt-get update + sudo apt-get install -y python3-pip + else + echo "ERROR: python3-pip is required on self-hosted runner" >&2 + exit 1 + fi } need_pip_install=0 From 1b46b0a086289291bdb180a62dedc84314dc3582 Mon Sep 17 00:00:00 2001 From: TaoTao-real Date: Fri, 12 Jun 2026 00:49:59 +0800 Subject: [PATCH 13/51] fix: guard EmitC integer APInt handling --- lib/PTO/Transforms/PTOToEmitC.cpp | 13 +++++++++---- tools/ptoas/ptoas.cpp | 30 ++++++++++++++++++++++-------- 2 files changed, 31 insertions(+), 12 deletions(-) diff --git a/lib/PTO/Transforms/PTOToEmitC.cpp b/lib/PTO/Transforms/PTOToEmitC.cpp index 208ba76392..4b004eb99d 100644 --- a/lib/PTO/Transforms/PTOToEmitC.cpp +++ b/lib/PTO/Transforms/PTOToEmitC.cpp @@ -2827,6 +2827,11 @@ static Value makeEmitCIntConstant(ConversionPatternRewriter &rewriter, return makeEmitCOpaqueConstant(rewriter, loc, type, std::to_string(value)); } +static int64_t getIntegerAttrSignedValue(IntegerAttr attr) { + const APInt &value = attr.getValue(); + return value.getBitWidth() == 0 ? 0 : value.getSExtValue(); +} + static FailureOr buildEmitCOpaqueConstantLiteral(Type targetType, Attribute valueAttr) { auto opaqueTy = dyn_cast(targetType); @@ -3128,7 +3133,7 @@ struct ArithConstantToEmitC : public OpConversionPattern { } if (auto intAttr = dyn_cast_or_null(valueAttr)) { - std::string valStr = std::to_string(intAttr.getValue().getSExtValue()); + std::string valStr = std::to_string(getIntegerAttrSignedValue(intAttr)); auto constAttr = emitc::OpaqueAttr::get(rewriter.getContext(), valStr); rewriter.replaceOpWithNewOp(op, newType, constAttr); return success(); @@ -3565,7 +3570,7 @@ struct SubviewToEmitCPattern : public OpConversionPattern { } if (auto attr = ofr.dyn_cast()) { if (auto ia = dyn_cast(attr)) - return mkIndex(ia.getValue().getSExtValue()); + return mkIndex(getIntegerAttrSignedValue(ia)); } return mkIndex(0); }; @@ -4455,7 +4460,7 @@ struct PointerCastConversion : public OpConversionPattern { static bool getIndexConst(Value v, int64_t &out) { if (auto cst = v.getDefiningOp()) { if (auto ia = dyn_cast(cst.getValue())) { - out = ia.getValue().getSExtValue(); + out = getIntegerAttrSignedValue(ia); return true; } } @@ -11819,7 +11824,7 @@ struct PTOBindTileToEmitC : public OpConversionPattern { return false; if (auto cst = v.getDefiningOp()) { if (auto ia = dyn_cast(cst.getValue())) { - out = ia.getValue().getSExtValue(); + out = getIntegerAttrSignedValue(ia); return true; } } diff --git a/tools/ptoas/ptoas.cpp b/tools/ptoas/ptoas.cpp index 545909cd83..7f98d28122 100644 --- a/tools/ptoas/ptoas.cpp +++ b/tools/ptoas/ptoas.cpp @@ -889,8 +889,8 @@ getEmitCDenseIntElementsAttrLiteral(DenseIntElementsAttr attr) { return literal; } -static Attribute normalizeEmitCPrintedIntAttrForCppEmission(MLIRContext *ctx, - Attribute attr) { +static Attribute normalizeEmitCPrintedAttrForCppEmission(MLIRContext *ctx, + Attribute attr) { if (auto intAttr = dyn_cast(attr)) return emitc::OpaqueAttr::get(ctx, getEmitCIntegerAttrLiteral(intAttr)); @@ -900,6 +900,20 @@ static Attribute normalizeEmitCPrintedIntAttrForCppEmission(MLIRContext *ctx, return emitc::OpaqueAttr::get(ctx, *literal); } + if (auto arrayAttr = dyn_cast(attr)) { + SmallVector normalized; + normalized.reserve(arrayAttr.size()); + bool changed = false; + for (Attribute element : arrayAttr) { + Attribute normalizedElement = + normalizeEmitCPrintedAttrForCppEmission(ctx, element); + changed |= normalizedElement != element; + normalized.push_back(normalizedElement); + } + if (changed) + return ArrayAttr::get(ctx, normalized); + } + return attr; } @@ -927,14 +941,14 @@ static ArrayAttr normalizeEmitCCallArgsForCppEmission(MLIRContext *ctx, } Attribute normalizedAttr = - normalizeEmitCPrintedIntAttrForCppEmission(ctx, attr); + normalizeEmitCPrintedAttrForCppEmission(ctx, attr); changed |= normalizedAttr != attr; normalized.push_back(normalizedAttr); continue; } Attribute normalizedAttr = - normalizeEmitCPrintedIntAttrForCppEmission(ctx, attr); + normalizeEmitCPrintedAttrForCppEmission(ctx, attr); changed |= normalizedAttr != attr; normalized.push_back(normalizedAttr); } @@ -950,7 +964,7 @@ static ArrayAttr normalizeEmitCTemplateArgsForCppEmission(MLIRContext *ctx, for (Attribute attr : args) { Attribute normalizedAttr = - normalizeEmitCPrintedIntAttrForCppEmission(ctx, attr); + normalizeEmitCPrintedAttrForCppEmission(ctx, attr); changed |= normalizedAttr != attr; normalized.push_back(normalizedAttr); } @@ -964,7 +978,7 @@ static void normalizeEmitCIntegerAttrsForCppEmission(Operation *rootOp) { if (auto constant = dyn_cast(op)) { Attribute value = constant.getValue(); Attribute normalized = - normalizeEmitCPrintedIntAttrForCppEmission(ctx, value); + normalizeEmitCPrintedAttrForCppEmission(ctx, value); if (normalized != value) constant.getProperties().setValue(normalized); return; @@ -973,7 +987,7 @@ static void normalizeEmitCIntegerAttrsForCppEmission(Operation *rootOp) { if (auto variable = dyn_cast(op)) { Attribute value = variable.getValue(); Attribute normalized = - normalizeEmitCPrintedIntAttrForCppEmission(ctx, value); + normalizeEmitCPrintedAttrForCppEmission(ctx, value); if (normalized != value) variable.getProperties().setValue(normalized); return; @@ -984,7 +998,7 @@ static void normalizeEmitCIntegerAttrsForCppEmission(Operation *rootOp) { if (!initialValue) return; Attribute normalized = - normalizeEmitCPrintedIntAttrForCppEmission(ctx, *initialValue); + normalizeEmitCPrintedAttrForCppEmission(ctx, *initialValue); if (normalized != *initialValue) global.getProperties().setInitialValue(normalized); return; From 7c5f8c91580541a4de9ef00eef520ce2f4e07a3f Mon Sep 17 00:00:00 2001 From: TaoTao-real Date: Fri, 12 Jun 2026 01:02:41 +0800 Subject: [PATCH 14/51] fix: lower VPTO low-precision carriers for LLVM 21 --- lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp | 23 +++++++++---------- lib/PTO/Transforms/VPTOLLVMEmitter.cpp | 23 +++++++++---------- .../simt_lowlevel_ldst_policy_vpto_llvm.pto | 8 ++++++- 3 files changed, 29 insertions(+), 25 deletions(-) diff --git a/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp b/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp index 1ed2263f99..d196f7b26e 100644 --- a/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp +++ b/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp @@ -67,16 +67,11 @@ static Type getElementTypeFromVectorLike(Type type); static std::optional getElementCountFromVectorLike(Type type); static Type getLowPrecisionLLVMType(Type type, MLIRContext *context) { - if (pto::isPTOHiFloat8Type(type)) - return Float8E4M3FNType::get(context); - if (isa(type)) - return IntegerType::get(context, 8); - if (isa(type)) + if (pto::isPTOHiFloat8Type(type) || isa(type) || + isa(type) || + pto::isPTOFloat8E4M3LikeType(type) || + pto::isPTOFloat8E5M2LikeType(type)) return IntegerType::get(context, 8); - if (pto::isPTOFloat8E4M3LikeType(type)) - return Float8E4M3Type::get(context); - if (pto::isPTOFloat8E5M2LikeType(type)) - return Float8E5M2Type::get(context); return {}; } @@ -101,7 +96,7 @@ static Type getLowpPayloadABIElementType(Type elementType, static Type normalizePayloadTypeForLLVMLowering(Type type, Builder &builder) { if (pto::isPTOHiFloat8x2Type(type)) return getLLVMCompatibleVectorType( - {2}, Float8E4M3FNType::get(builder.getContext())); + {2}, builder.getI8Type()); if (Type lowpType = getLowPrecisionLLVMType(type, builder.getContext())) return lowpType; @@ -9662,6 +9657,8 @@ static Value convertLdgCallResult(Location loc, Type valueType, if (pto::isPTOFloat8Type(valueType) || pto::isPTOHiFloat8Type(valueType)) { Value payload = rewriter.create(loc, rewriter.getI8Type(), callResult); + if (payload.getType() == convertedValueType) + return payload; return rewriter.create(loc, convertedValueType, payload); } return callResult; @@ -9795,8 +9792,10 @@ static Value convertStgValue(Location loc, Type valueType, Value value, } if (pto::isPTOFloat8Type(valueType) || pto::isPTOHiFloat8Type(valueType)) { - Value payload = - rewriter.create(loc, rewriter.getI8Type(), value); + Value payload = value; + if (payload.getType() != rewriter.getI8Type()) + payload = + rewriter.create(loc, rewriter.getI8Type(), value); return rewriter.create(loc, rewriter.getI32Type(), payload); } if (valueType.isBF16()) diff --git a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp index fbd0298c6a..9850b6e6c8 100644 --- a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp +++ b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp @@ -68,16 +68,11 @@ static Type getElementTypeFromVectorLike(Type type); static std::optional getElementCountFromVectorLike(Type type); static Type getLowPrecisionLLVMType(Type type, MLIRContext *context) { - if (pto::isPTOHiFloat8Type(type)) - return Float8E4M3FNType::get(context); - if (isa(type)) - return IntegerType::get(context, 8); - if (isa(type)) + if (pto::isPTOHiFloat8Type(type) || isa(type) || + isa(type) || + pto::isPTOFloat8E4M3LikeType(type) || + pto::isPTOFloat8E5M2LikeType(type)) return IntegerType::get(context, 8); - if (pto::isPTOFloat8E4M3LikeType(type)) - return Float8E4M3Type::get(context); - if (pto::isPTOFloat8E5M2LikeType(type)) - return Float8E5M2Type::get(context); return {}; } @@ -102,7 +97,7 @@ static Type getLowpPayloadABIElementType(Type elementType, static Type normalizePayloadTypeForLLVMLowering(Type type, Builder &builder) { if (pto::isPTOHiFloat8x2Type(type)) return getLLVMCompatibleVectorType( - {2}, Float8E4M3FNType::get(builder.getContext())); + {2}, builder.getI8Type()); if (Type lowpType = getLowPrecisionLLVMType(type, builder.getContext())) return lowpType; @@ -9607,6 +9602,8 @@ static Value convertLdgCallResult(Location loc, Type valueType, if (pto::isPTOFloat8Type(valueType) || pto::isPTOHiFloat8Type(valueType)) { Value payload = rewriter.create(loc, rewriter.getI8Type(), callResult); + if (payload.getType() == convertedValueType) + return payload; return rewriter.create(loc, convertedValueType, payload); } return callResult; @@ -9741,8 +9738,10 @@ static Value convertStgValue(Location loc, Type valueType, Value value, } if (pto::isPTOFloat8Type(valueType) || pto::isPTOHiFloat8Type(valueType)) { - Value payload = - rewriter.create(loc, rewriter.getI8Type(), value); + Value payload = value; + if (payload.getType() != rewriter.getI8Type()) + payload = + rewriter.create(loc, rewriter.getI8Type(), value); return rewriter.create(loc, rewriter.getI32Type(), payload); } if (valueType.isBF16()) diff --git a/test/lit/vpto/simt_lowlevel_ldst_policy_vpto_llvm.pto b/test/lit/vpto/simt_lowlevel_ldst_policy_vpto_llvm.pto index b2b01574fa..5e2c07e054 100644 --- a/test/lit/vpto/simt_lowlevel_ldst_policy_vpto_llvm.pto +++ b/test/lit/vpto/simt_lowlevel_ldst_policy_vpto_llvm.pto @@ -9,7 +9,7 @@ // RUN: ( mkdir -p %T && ptoas --pto-arch=a5 --pto-backend=vpto %s -o %t --mlir-print-ir-after=convert-func-to-llvm 2>&1 || true ) | FileCheck %s module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { - func.func @ldst_policy_kernel(%gm_i8: !pto.ptr, %gm_i16: !pto.ptr, %gm_i32: !pto.ptr, %gm_i64: !pto.ptr, %gm_f16: !pto.ptr, %gm_bf16: !pto.ptr, %gm_f32: !pto.ptr, %gm_f64: !pto.ptr, %dst_i32: !pto.ptr, %dst_i64: !pto.ptr) attributes {pto.aicore} { + func.func @ldst_policy_kernel(%gm_i8: !pto.ptr, %gm_i16: !pto.ptr, %gm_i32: !pto.ptr, %gm_i64: !pto.ptr, %gm_f16: !pto.ptr, %gm_bf16: !pto.ptr, %gm_f32: !pto.ptr, %gm_f64: !pto.ptr, %gm_f8: !pto.ptr, %gm_hif8: !pto.ptr, %dst_i32: !pto.ptr, %dst_i64: !pto.ptr) attributes {pto.aicore} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c2 = arith.constant 2 : index @@ -28,6 +28,8 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind -> bf16 %load_f32 = pto.ldg %gm_f32[%c0] : !pto.ptr -> f32 %load_f64 = pto.ldg %gm_f64[%c0] : !pto.ptr -> f64 + %load_f8 = pto.ldg %gm_f8[%c0] l1cache(cache) l2cache(nmfv) : !pto.ptr -> f8E4M3FN + %load_hif8 = pto.ldg %gm_hif8[%c0] l1cache(uncache) l2cache(nmfv) : !pto.ptr -> !pto.hif8 pto.stg %i8, %gm_i8[%c1] l1cache(cache) l2cache(nmfv) : !pto.ptr, i8 pto.stg %i16, %gm_i16[%c1] l1cache(cache) l2cache(nmfv) : !pto.ptr, i16 pto.stg %load_i8, %gm_i8[%c2] l1cache(uncache) l2cache(nmfv) : !pto.ptr, i8 @@ -42,6 +44,8 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind, f64 pto.stg %load_f16, %gm_f16[%c1] : !pto.ptr, f16 pto.stg %load_bf16, %gm_bf16[%c2] : !pto.ptr, bf16 + pto.stg %load_f8, %gm_f8[%c1] l1cache(cache) l2cache(nmfv) : !pto.ptr, f8E4M3FN + pto.stg %load_hif8, %gm_hif8[%c2] l1cache(uncache) l2cache(nmfv) : !pto.ptr, !pto.hif8 pto.store %load_cache, %dst_i32[%c0] : !pto.ptr, i32 pto.store %load_uncache, %dst_i64[%c0] : !pto.ptr, i64 @@ -58,10 +62,12 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind Date: Fri, 12 Jun 2026 01:49:40 +0800 Subject: [PATCH 15/51] fix: guard EmitC integer attr reads --- .../Transforms/PTOMaterializeTileHandles.cpp | 13 ++- lib/PTO/Transforms/PTOToEmitC.cpp | 104 +++++++++++------- 2 files changed, 73 insertions(+), 44 deletions(-) diff --git a/lib/PTO/Transforms/PTOMaterializeTileHandles.cpp b/lib/PTO/Transforms/PTOMaterializeTileHandles.cpp index 2c4b1d8de6..6c2a0e2d36 100644 --- a/lib/PTO/Transforms/PTOMaterializeTileHandles.cpp +++ b/lib/PTO/Transforms/PTOMaterializeTileHandles.cpp @@ -54,6 +54,11 @@ namespace { static constexpr llvm::StringLiteral kForceDynamicValidShapeAttrName = "__pto.force_dynamic_valid_shape"; +static int64_t getIntegerAttrSignedValue(IntegerAttr attr) { + const APInt &value = attr.getValue(); + return value.getBitWidth() == 0 ? 0 : value.getSExtValue(); +} + struct TileHandleMetadata { Value source; Value validRow; @@ -270,13 +275,13 @@ static bool getTilePointerStrides(TileBufConfigAttr configAttr, Type elemTy, if (auto blAttr = dyn_cast(configAttr.getBLayout())) blVal = static_cast(blAttr.getValue()); else if (auto intAttr = dyn_cast(configAttr.getBLayout())) - blVal = static_cast(intAttr.getInt()); + blVal = static_cast(getIntegerAttrSignedValue(intAttr)); int32_t slVal = 0; if (auto slAttr = dyn_cast(configAttr.getSLayout())) slVal = static_cast(slAttr.getValue()); else if (auto intAttr = dyn_cast(configAttr.getSLayout())) - slVal = static_cast(intAttr.getInt()); + slVal = static_cast(getIntegerAttrSignedValue(intAttr)); bool boxed = slVal != 0; int64_t innerRows = 1; @@ -284,7 +289,7 @@ static bool getTilePointerStrides(TileBufConfigAttr configAttr, Type elemTy, if (boxed) { int32_t fractal = 512; if (auto frAttr = dyn_cast(configAttr.getSFractalSize())) - fractal = static_cast(frAttr.getInt()); + fractal = static_cast(getIntegerAttrSignedValue(frAttr)); unsigned elemBytes = pto::getPTOStorageElemByteSize(elemTy); if (elemBytes == 0) @@ -432,7 +437,7 @@ static Value materializeOffset(OpFoldResult ofr, OpBuilder &builder, Location loc) { if (auto attr = ofr.dyn_cast()) { if (auto intAttr = dyn_cast(attr)) - return makeI64Constant(builder, loc, intAttr.getInt()); + return makeI64Constant(builder, loc, getIntegerAttrSignedValue(intAttr)); return Value(); } return ensureI64(cast(ofr), builder, loc); diff --git a/lib/PTO/Transforms/PTOToEmitC.cpp b/lib/PTO/Transforms/PTOToEmitC.cpp index 4b004eb99d..42f6caa286 100644 --- a/lib/PTO/Transforms/PTOToEmitC.cpp +++ b/lib/PTO/Transforms/PTOToEmitC.cpp @@ -157,6 +157,18 @@ enum TNotifyMteDrainMask : unsigned { static constexpr llvm::StringLiteral kLastUseAttrName = "pto.last_use"; static constexpr llvm::StringLiteral kLastUseMarkerPrefix = "PTOAS__LAST_USE__"; +static int64_t getAPIntSignedValue(const APInt &value) { + return value.getBitWidth() == 0 ? 0 : value.getSExtValue(); +} + +static uint64_t getAPIntUnsignedValue(const APInt &value) { + return value.getBitWidth() == 0 ? 0 : value.getZExtValue(); +} + +static int64_t getIntegerAttrSignedValue(IntegerAttr attr) { + return getAPIntSignedValue(attr.getValue()); +} + static SmallVector collectTileOperandNumbers(Operation *op) { SmallVector tileOperandNumbers; for (OpOperand &operand : op->getOpOperands()) { @@ -845,7 +857,7 @@ static std::optional getEmitCTileTypeString(pto::TileBufType type) int32_t fractal = 512; if (auto frAttr = dyn_cast(configAttr.getSFractalSize())) - fractal = frAttr.getInt(); + fractal = static_cast(getIntegerAttrSignedValue(frAttr)); return std::string("Tile<") + tileRoleToken(type.getMemorySpace(), elemTy, type.getConfigAttr()) + ", " + @@ -1157,14 +1169,15 @@ static FailureOr buildTPipeTokenFromInitOp(Operation *op, getTPipeDirectionToken(/*isL2G2L=*/true, initOp.getDirMask(), targetArch); if (failed(dirTok)) return failure(); - int32_t localSlotNum = initOp.getLocalSlotNumAttr() - ? initOp.getLocalSlotNumAttr().getInt() - : initOp.getSlotNum(); - return buildTPipeToken(initOp.getFlagBaseAttr().getInt(), *dirTok, - initOp.getSlotSize(), initOp.getSlotNum(), - localSlotNum, - initOp.getNosplitAttr() && - initOp.getNosplitAttr().getValue()); + int32_t localSlotNum = + initOp.getLocalSlotNumAttr() + ? static_cast( + getIntegerAttrSignedValue(initOp.getLocalSlotNumAttr())) + : initOp.getSlotNum(); + return buildTPipeToken( + static_cast(getIntegerAttrSignedValue(initOp.getFlagBaseAttr())), + *dirTok, initOp.getSlotSize(), initOp.getSlotNum(), localSlotNum, + initOp.getNosplitAttr() && initOp.getNosplitAttr().getValue()); } if (auto initOp = dyn_cast(op)) { @@ -1174,10 +1187,10 @@ static FailureOr buildTPipeTokenFromInitOp(Operation *op, getTPipeDirectionToken(/*isL2G2L=*/false, initOp.getDirMask(), targetArch); if (failed(dirTok)) return failure(); - return buildTPipeToken(initOp.getFlagBaseAttr().getInt(), *dirTok, - initOp.getSlotSize(), initOp.getSlotNum(), 2, - initOp.getNosplitAttr() && - initOp.getNosplitAttr().getValue()); + return buildTPipeToken( + static_cast(getIntegerAttrSignedValue(initOp.getFlagBaseAttr())), + *dirTok, initOp.getSlotSize(), initOp.getSlotNum(), 2, + initOp.getNosplitAttr() && initOp.getNosplitAttr().getValue()); } return failure(); @@ -1348,7 +1361,8 @@ static InterCoreSyncCallDesc buildInterCoreSyncSetCall( if (targetArch == PTOArch::A3) { auto indexTy = emitc::OpaqueType::get(ctx, "int64_t"); Value eventVal = - makeEmitCIntConstant(rewriter, loc, indexTy, eventIdAttr.getInt()); + makeEmitCIntConstant(rewriter, loc, indexTy, + getIntegerAttrSignedValue(eventIdAttr)); Value msgVal = createFFTSMsg(rewriter, loc, eventVal, fftsMode); InterCoreSyncCallDesc desc; @@ -2827,11 +2841,6 @@ static Value makeEmitCIntConstant(ConversionPatternRewriter &rewriter, return makeEmitCOpaqueConstant(rewriter, loc, type, std::to_string(value)); } -static int64_t getIntegerAttrSignedValue(IntegerAttr attr) { - const APInt &value = attr.getValue(); - return value.getBitWidth() == 0 ? 0 : value.getSExtValue(); -} - static FailureOr buildEmitCOpaqueConstantLiteral(Type targetType, Attribute valueAttr) { auto opaqueTy = dyn_cast(targetType); @@ -2856,7 +2865,7 @@ static FailureOr buildEmitCOpaqueConstantLiteral(Type targetType, if (!first) os << ", "; first = false; - os << elem.getZExtValue(); + os << getAPIntUnsignedValue(elem); } os << "}"; os.flush(); @@ -3491,12 +3500,12 @@ struct SubviewToEmitCPattern : public OpConversionPattern { std::optional extractStaticInt(OpFoldResult ofr) const { if (auto attr = ofr.dyn_cast()) { if (auto intAttr = dyn_cast(attr)) - return intAttr.getInt(); + return getIntegerAttrSignedValue(intAttr); } else { Value v = ofr.dyn_cast(); if (auto cOp = v.getDefiningOp()) { if (auto iAttr = dyn_cast(cOp.getValue())) - return iAttr.getInt(); + return getIntegerAttrSignedValue(iAttr); } else if (auto idxOp = v.getDefiningOp()) { return idxOp.value(); } @@ -4346,7 +4355,7 @@ static FailureOr buildAsyncScratchTileValue( int32_t fractal = 512; if (auto frAttr = dyn_cast(configAttr.getSFractalSize())) - fractal = frAttr.getInt(); + fractal = static_cast(getIntegerAttrSignedValue(frAttr)); Type elemTy = memTy.getElementType(); pto::BLayout blayout = getTileBufBLayoutValue(configAttr); @@ -4556,7 +4565,7 @@ struct PointerCastConversion : public OpConversionPattern { if (!v) return false; if (auto cst = v.getDefiningOp()) { if (auto attr = dyn_cast(cst.getValue())) { - outVal = attr.getInt(); + outVal = getIntegerAttrSignedValue(attr); return true; } } @@ -4623,7 +4632,8 @@ struct PointerCastConversion : public OpConversionPattern { std::string slStr = (slVal == 1) ? "SLayout::RowMajor" : (slVal == 2) ? "SLayout::ColMajor" : "SLayout::NoneBox"; int32_t frVal = 0; - if (auto attr = dyn_cast(config.getSFractalSize())) frVal = attr.getInt(); + if (auto attr = dyn_cast(config.getSFractalSize())) + frVal = static_cast(getIntegerAttrSignedValue(attr)); int32_t padVal = 0; if (auto attr = dyn_cast(config.getPad())) @@ -6008,7 +6018,7 @@ struct PTOSyncSetToEmitC : public OpConversionPattern { Value eventIdDyn = adaptor.getEventIdDyn(); int64_t fftsMode = 2; if (IntegerAttr fftsModeAttr = op.getFftsModeAttr()) - fftsMode = fftsModeAttr.getInt(); + fftsMode = getIntegerAttrSignedValue(fftsModeAttr); if ((eventIdAttr != nullptr) == static_cast(eventIdDyn)) return rewriter.notifyMatchFailure( @@ -6048,7 +6058,7 @@ struct PTOSyncSetToEmitC : public OpConversionPattern { emitSet(Value{}, eventIdAttr, /*isDynamic=*/false); if (needsMirrorPlus16) { auto plus16 = IntegerAttr::get(eventIdAttr.getType(), - eventIdAttr.getInt() + 16); + getIntegerAttrSignedValue(eventIdAttr) + 16); emitSet(Value{}, plus16, /*isDynamic=*/false); } } else { @@ -6368,7 +6378,7 @@ struct PTOHistogramToEmitC : public OpConversionPattern { int64_t byte = 1; auto byteAttr = op.getByteAttr(); if (byteAttr) - byte = byteAttr.getInt(); + byte = getIntegerAttrSignedValue(byteAttr); if (auto legacyIsMSB = op->getAttrOfType("isMSB")) { int64_t legacyByte = legacyIsMSB.getValue() ? 1 : 0; if (byteAttr && byte != legacyByte) @@ -6889,14 +6899,28 @@ struct PTOBuildAsyncSessionToEmitC return makeEmitCOpaqueConstant(rewriter, loc, u32Ty, std::to_string(value) + "u"); }; - uint64_t syncId = op.getSyncIdAttr() ? op.getSyncIdAttr().getInt() : 0; + uint64_t syncId = op.getSyncIdAttr() + ? static_cast( + getIntegerAttrSignedValue(op.getSyncIdAttr())) + : 0; uint64_t blockBytes = - op.getBlockBytesAttr() ? op.getBlockBytesAttr().getInt() : 32 * 1024; + op.getBlockBytesAttr() + ? static_cast( + getIntegerAttrSignedValue(op.getBlockBytesAttr())) + : 32 * 1024; uint64_t commBlockOffset = - op.getCommBlockOffsetAttr() ? op.getCommBlockOffsetAttr().getInt() : 0; - uint64_t queueNum = op.getQueueNumAttr() ? op.getQueueNumAttr().getInt() : 1; + op.getCommBlockOffsetAttr() + ? static_cast( + getIntegerAttrSignedValue(op.getCommBlockOffsetAttr())) + : 0; + uint64_t queueNum = op.getQueueNumAttr() + ? static_cast( + getIntegerAttrSignedValue(op.getQueueNumAttr())) + : 1; uint64_t channelGroupIdx = op.getChannelGroupIdxAttr() - ? op.getChannelGroupIdxAttr().getInt() + ? static_cast( + getIntegerAttrSignedValue( + op.getChannelGroupIdxAttr())) : UINT32_MAX; Value syncIdVal = makeU32Const(syncId); @@ -7621,7 +7645,7 @@ static std::optional getStaticIndexLikeValue(Value value) { return cst.value(); if (auto cst = value.getDefiningOp()) { if (auto intAttr = dyn_cast(cst.getValue())) - return intAttr.getInt(); + return getIntegerAttrSignedValue(intAttr); } return std::nullopt; } @@ -9445,7 +9469,7 @@ struct PTOGatherToEmitC : public OpConversionPattern { std::string cmpTok = cmpAttr ? cmpModeTok(cmpAttr) : "CmpMode::EQ"; int64_t offset = 0; if (auto offsetAttr = op.getOffsetAttr()) - offset = offsetAttr.getInt(); + offset = getIntegerAttrSignedValue(offsetAttr); auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); Value offsetVal = makeEmitCIntConstant(rewriter, loc, i32Ty, offset); @@ -11727,7 +11751,7 @@ struct PTOXORSToEmitC : public OpConversionPattern { SmallVector templateArgVec; int64_t printFormat = 0; if (auto formatAttr = op.getPrintFormatAttr()) - printFormat = formatAttr.getInt(); + printFormat = getIntegerAttrSignedValue(formatAttr); if (printFormat != 0) { templateArgVec.push_back( emitc::OpaqueAttr::get(ctx, printFormatTok(printFormat))); @@ -11842,13 +11866,13 @@ struct PTOBindTileToEmitC : public OpConversionPattern { if (auto blAttr = dyn_cast(configAttr.getBLayout())) blVal = static_cast(blAttr.getValue()); else if (auto intAttr = dyn_cast(configAttr.getBLayout())) - blVal = static_cast(intAttr.getInt()); + blVal = static_cast(getIntegerAttrSignedValue(intAttr)); int32_t slVal = 0; if (auto slAttr = dyn_cast(configAttr.getSLayout())) slVal = static_cast(slAttr.getValue()); else if (auto intAttr = dyn_cast(configAttr.getSLayout())) - slVal = static_cast(intAttr.getInt()); + slVal = static_cast(getIntegerAttrSignedValue(intAttr)); bool boxed = slVal != 0; int64_t innerRows = 1; @@ -11856,7 +11880,7 @@ struct PTOBindTileToEmitC : public OpConversionPattern { if (boxed) { int32_t fractal = 512; if (auto frAttr = dyn_cast(configAttr.getSFractalSize())) - fractal = static_cast(frAttr.getInt()); + fractal = static_cast(getIntegerAttrSignedValue(frAttr)); unsigned elemBytes = pto::getPTOStorageElemByteSize(elemTy); if (elemBytes == 0) @@ -12116,7 +12140,7 @@ struct PTOBindTileToEmitC : public OpConversionPattern { int32_t fractal = 512; if (auto frAttr = dyn_cast(configAttr.getSFractalSize())) - fractal = frAttr.getInt(); + fractal = static_cast(getIntegerAttrSignedValue(frAttr)); std::string padTok = "PadValue::Null"; if (auto padAttr = dyn_cast(configAttr.getPad())) { From db759283a92120ddc13984ed92886a65493309d9 Mon Sep 17 00:00:00 2001 From: TaoTao-real Date: Fri, 12 Jun 2026 03:16:38 +0800 Subject: [PATCH 16/51] fix: guard EmitC zero-width variable init --- tools/ptoas/ptoas.cpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tools/ptoas/ptoas.cpp b/tools/ptoas/ptoas.cpp index 7f98d28122..cd812d3d8c 100644 --- a/tools/ptoas/ptoas.cpp +++ b/tools/ptoas/ptoas.cpp @@ -1022,8 +1022,11 @@ static void normalizeEmitCIntegerAttrsForCppEmission(Operation *rootOp) { } static Attribute getDefaultEmitCVariableInitAttr(OpBuilder &builder, Type type) { - if (auto intTy = dyn_cast(type)) + if (auto intTy = dyn_cast(type)) { + if (intTy.getWidth() == 0) + return emitc::OpaqueAttr::get(builder.getContext(), "0"); return builder.getIntegerAttr(intTy, 0); + } if (isa(type)) return builder.getIndexAttr(0); if (auto floatTy = dyn_cast(type)) From 4875c070eda37a8a574c3d3edd691103d92e0162 Mon Sep 17 00:00:00 2001 From: TaoTao-real Date: Fri, 12 Jun 2026 05:35:17 +0800 Subject: [PATCH 17/51] fix: remove unused PTOToEmitC dataflow analysis --- lib/PTO/Transforms/PTOToEmitC.cpp | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/lib/PTO/Transforms/PTOToEmitC.cpp b/lib/PTO/Transforms/PTOToEmitC.cpp index 42f6caa286..f73de38a62 100644 --- a/lib/PTO/Transforms/PTOToEmitC.cpp +++ b/lib/PTO/Transforms/PTOToEmitC.cpp @@ -20,10 +20,6 @@ #include "PTO/IR/PTOSyncUtils.h" #include "PTO/Transforms/Passes.h" -#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" -#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" -#include "mlir/Analysis/DataFlowFramework.h" - #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/EmitC/IR/EmitC.h" @@ -13539,9 +13535,7 @@ struct CFSwitchToCondBr : public OpRewritePattern { static void populatePTOToEmitCPatterns(RewritePatternSet &patterns, TypeConverter &typeConverter, MLIRContext *ctx, - DataFlowSolver &solver, PTOArch targetArch) { - (void)solver; patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); @@ -14102,14 +14096,8 @@ static AICORE inline void ptoas_auto_sync_tail( target.addLegalDialect(); target.addLegalOp(); - auto solver = std::make_unique(); - solver->load(); - solver->load(); - if (failed(solver->initializeAndRun(getOperation()))) - return signalPassFailure(); - RewritePatternSet patterns(ctx); - populatePTOToEmitCPatterns(patterns, typeConverter, ctx, *solver, targetArch); + populatePTOToEmitCPatterns(patterns, typeConverter, ctx, targetArch); // 4. 执行转换 if (failed(applyPartialConversion(mop, target, std::move(patterns)))) { From 0248b07d95ec678bd1858a1736a55345151076fb Mon Sep 17 00:00:00 2001 From: TaoTao-real Date: Fri, 12 Jun 2026 07:05:11 +0800 Subject: [PATCH 18/51] fix: lower low-precision HIVM intrinsic ABI --- lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp | 78 +++++++++++-- lib/PTO/Transforms/VPTOLLVMEmitter.cpp | 78 +++++++++++-- .../vpto/low_precision_hivm_llvm_ir_abi.pto | 107 ++++++++++++++++++ 3 files changed, 241 insertions(+), 22 deletions(-) create mode 100644 test/lit/vpto/low_precision_hivm_llvm_ir_abi.pto diff --git a/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp b/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp index d196f7b26e..d2ef1414b8 100644 --- a/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp +++ b/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp @@ -619,6 +619,30 @@ static Type getPackedLowpScalarMemoryType(Type semanticType, return IntegerType::get(context, 16); } +static Type getLowpIntrinsicCarrierType(Type semanticType, Type convertedType, + MLIRContext *context) { + if (auto vregType = dyn_cast(semanticType)) { + if (!isLowpPayloadABIElementType(vregType.getElementType())) + return convertedType; + int64_t elementCount = vregType.getElementCount(); + if (elementCount <= 0 || elementCount % 4 != 0) + return {}; + return VectorType::get({elementCount / 4}, IntegerType::get(context, 32)); + } + + if (Type packedScalarType = + getPackedLowpScalarMemoryType(semanticType, context)) + return packedScalarType; + return convertedType; +} + +static Value bitcastToType(Location loc, Value value, Type targetType, + ConversionPatternRewriter &rewriter) { + if (!targetType || targetType == value.getType()) + return value; + return rewriter.create(loc, targetType, value); +} + static Type getScalarAccessGEPElementType(Type semanticType, Builder &builder) { if (Type memoryType = @@ -7686,10 +7710,24 @@ class LowerVcvtOpPattern final : public OpConversionPattern { if (!resultType) return rewriter.notifyMatchFailure(op, "failed to convert vcvt result type"); + Type inputCallType = getLowpIntrinsicCarrierType( + op.getInput().getType(), adaptor.getInput().getType(), + rewriter.getContext()); + if (!inputCallType) + return rewriter.notifyMatchFailure(op, + "unsupported vcvt input carrier type"); + Type resultCallType = getLowpIntrinsicCarrierType( + op.getResult().getType(), resultType, rewriter.getContext()); + if (!resultCallType) + return rewriter.notifyMatchFailure(op, + "unsupported vcvt result carrier type"); + Value input = bitcastToType(op.getLoc(), adaptor.getInput(), inputCallType, + rewriter); + SmallVector callArgs; SmallVector argTypes; - callArgs.push_back(adaptor.getInput()); - argTypes.push_back(adaptor.getInput().getType()); + callArgs.push_back(input); + argTypes.push_back(input.getType()); callArgs.push_back(adaptor.getMask()); argTypes.push_back(adaptor.getMask().getType()); @@ -7737,12 +7775,15 @@ class LowerVcvtOpPattern final : public OpConversionPattern { argTypes.push_back(partValue.getType()); } - auto funcType = rewriter.getFunctionType(argTypes, TypeRange{resultType}); + auto funcType = rewriter.getFunctionType(argTypes, TypeRange{resultCallType}); auto call = rewriter.create( - op.getLoc(), StringRef((*contract).intrinsic), TypeRange{resultType}, callArgs); + op.getLoc(), StringRef((*contract).intrinsic), + TypeRange{resultCallType}, callArgs); state.plannedDecls.push_back( PlannedDecl{std::string((*contract).intrinsic), funcType}); - rewriter.replaceOp(op, call.getResults()); + Value result = + bitcastToType(op.getLoc(), call.getResult(0), resultType, rewriter); + rewriter.replaceOp(op, result); return success(); } @@ -9267,15 +9308,30 @@ class LowerConvertOpPattern final : public OpConversionPattern { Value saturation = getI32Constant( rewriter, op.getLoc(), static_cast(op.getSaturation())); + Type srcCallType = getLowpIntrinsicCarrierType( + op.getSrc().getType(), adaptor.getSrc().getType(), + rewriter.getContext()); + if (!srcCallType) + return rewriter.notifyMatchFailure( + op, "unsupported convert input carrier type"); + Type resultCallType = getLowpIntrinsicCarrierType( + op.getDst().getType(), resultType, rewriter.getContext()); + if (!resultCallType) + return rewriter.notifyMatchFailure( + op, "unsupported convert result carrier type"); + Value src = + bitcastToType(op.getLoc(), adaptor.getSrc(), srcCallType, rewriter); + auto funcType = rewriter.getFunctionType( - TypeRange{adaptor.getSrc().getType(), rewriter.getI32Type(), - rewriter.getI32Type()}, - TypeRange{resultType}); + TypeRange{src.getType(), rewriter.getI32Type(), rewriter.getI32Type()}, + TypeRange{resultCallType}); auto call = rewriter.create( - op.getLoc(), *calleeName, TypeRange{resultType}, - ValueRange{adaptor.getSrc(), rounding, saturation}); + op.getLoc(), *calleeName, TypeRange{resultCallType}, + ValueRange{src, rounding, saturation}); state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); - rewriter.replaceOp(op, call.getResults()); + Value result = + bitcastToType(op.getLoc(), call.getResult(0), resultType, rewriter); + rewriter.replaceOp(op, result); return success(); } diff --git a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp index 9850b6e6c8..61e5bd28c7 100644 --- a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp +++ b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp @@ -574,6 +574,30 @@ static Type getPackedLowpScalarMemoryType(Type semanticType, return IntegerType::get(context, 16); } +static Type getLowpIntrinsicCarrierType(Type semanticType, Type convertedType, + MLIRContext *context) { + if (auto vregType = dyn_cast(semanticType)) { + if (!isLowpPayloadABIElementType(vregType.getElementType())) + return convertedType; + int64_t elementCount = vregType.getElementCount(); + if (elementCount <= 0 || elementCount % 4 != 0) + return {}; + return VectorType::get({elementCount / 4}, IntegerType::get(context, 32)); + } + + if (Type packedScalarType = + getPackedLowpScalarMemoryType(semanticType, context)) + return packedScalarType; + return convertedType; +} + +static Value bitcastToType(Location loc, Value value, Type targetType, + ConversionPatternRewriter &rewriter) { + if (!targetType || targetType == value.getType()) + return value; + return rewriter.create(loc, targetType, value); +} + static Type getScalarAccessGEPElementType(Type semanticType, Builder &builder) { if (Type memoryType = @@ -7628,10 +7652,24 @@ class LowerVcvtOpPattern final : public OpConversionPattern { if (!resultType) return rewriter.notifyMatchFailure(op, "failed to convert vcvt result type"); + Type inputCallType = getLowpIntrinsicCarrierType( + op.getInput().getType(), adaptor.getInput().getType(), + rewriter.getContext()); + if (!inputCallType) + return rewriter.notifyMatchFailure(op, + "unsupported vcvt input carrier type"); + Type resultCallType = getLowpIntrinsicCarrierType( + op.getResult().getType(), resultType, rewriter.getContext()); + if (!resultCallType) + return rewriter.notifyMatchFailure(op, + "unsupported vcvt result carrier type"); + Value input = bitcastToType(op.getLoc(), adaptor.getInput(), inputCallType, + rewriter); + SmallVector callArgs; SmallVector argTypes; - callArgs.push_back(adaptor.getInput()); - argTypes.push_back(adaptor.getInput().getType()); + callArgs.push_back(input); + argTypes.push_back(input.getType()); callArgs.push_back(adaptor.getMask()); argTypes.push_back(adaptor.getMask().getType()); @@ -7679,12 +7717,15 @@ class LowerVcvtOpPattern final : public OpConversionPattern { argTypes.push_back(partValue.getType()); } - auto funcType = rewriter.getFunctionType(argTypes, TypeRange{resultType}); + auto funcType = rewriter.getFunctionType(argTypes, TypeRange{resultCallType}); auto call = rewriter.create( - op.getLoc(), StringRef((*contract).intrinsic), TypeRange{resultType}, callArgs); + op.getLoc(), StringRef((*contract).intrinsic), + TypeRange{resultCallType}, callArgs); state.plannedDecls.push_back( PlannedDecl{std::string((*contract).intrinsic), funcType}); - rewriter.replaceOp(op, call.getResults()); + Value result = + bitcastToType(op.getLoc(), call.getResult(0), resultType, rewriter); + rewriter.replaceOp(op, result); return success(); } @@ -9211,15 +9252,30 @@ class LowerConvertOpPattern final : public OpConversionPattern { Value saturation = getI32Constant( rewriter, op.getLoc(), static_cast(op.getSaturation())); + Type srcCallType = getLowpIntrinsicCarrierType( + op.getSrc().getType(), adaptor.getSrc().getType(), + rewriter.getContext()); + if (!srcCallType) + return rewriter.notifyMatchFailure( + op, "unsupported convert input carrier type"); + Type resultCallType = getLowpIntrinsicCarrierType( + op.getDst().getType(), resultType, rewriter.getContext()); + if (!resultCallType) + return rewriter.notifyMatchFailure( + op, "unsupported convert result carrier type"); + Value src = + bitcastToType(op.getLoc(), adaptor.getSrc(), srcCallType, rewriter); + auto funcType = rewriter.getFunctionType( - TypeRange{adaptor.getSrc().getType(), rewriter.getI32Type(), - rewriter.getI32Type()}, - TypeRange{resultType}); + TypeRange{src.getType(), rewriter.getI32Type(), rewriter.getI32Type()}, + TypeRange{resultCallType}); auto call = rewriter.create( - op.getLoc(), *calleeName, TypeRange{resultType}, - ValueRange{adaptor.getSrc(), rounding, saturation}); + op.getLoc(), *calleeName, TypeRange{resultCallType}, + ValueRange{src, rounding, saturation}); state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); - rewriter.replaceOp(op, call.getResults()); + Value result = + bitcastToType(op.getLoc(), call.getResult(0), resultType, rewriter); + rewriter.replaceOp(op, result); return success(); } diff --git a/test/lit/vpto/low_precision_hivm_llvm_ir_abi.pto b/test/lit/vpto/low_precision_hivm_llvm_ir_abi.pto new file mode 100644 index 0000000000..6463fe2903 --- /dev/null +++ b/test/lit/vpto/low_precision_hivm_llvm_ir_abi.pto @@ -0,0 +1,107 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: mkdir -p %T/fake-cann/bin %T/fake-cann/tools/bisheng_compiler/bin +// RUN: touch %T/fake-cann/bin/bisheng %T/fake-cann/bin/cce-ld %T/fake-cann/bin/ld.lld %T/fake-cann/tools/bisheng_compiler/bin/bisheng +// RUN: chmod +x %T/fake-cann/bin/bisheng %T/fake-cann/bin/cce-ld %T/fake-cann/bin/ld.lld %T/fake-cann/tools/bisheng_compiler/bin/bisheng +// RUN: env ASCEND_HOME_PATH=%T/fake-cann PATH=%T/fake-cann/bin:%T/fake-cann/tools/bisheng_compiler/bin:%PATH% ptoas --pto-arch=a5 --pto-backend=vpto --vpto-emit-hivm-llvm %s -o %t.ll +// RUN: FileCheck %s --check-prefix=LLVMIR < %t.ll + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @lowp_vector_kernel(%f32_seed: f32, + %bf16_seed: bf16, + %ub_f8: !pto.ptr, + %ub_hif8: !pto.ptr, + %ub_f4: !pto.ptr, + %dst_f32: !pto.ptr, + %dst_bf16: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + pto.vecscope { + %mask_b8 = pto.pset_b8 "PAT_ALL" : !pto.mask + %mask_b16 = pto.pset_b16 "PAT_ALL" : !pto.mask + %mask_b32 = pto.pset_b32 "PAT_ALL" : !pto.mask + + %src_f32 = pto.vdup %f32_seed, %mask_b32 : f32, !pto.mask -> !pto.vreg<64xf32> + %f8 = pto.vcvt %src_f32, %mask_b32 {rnd = "R", sat = "NOSAT", part = "P0"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8E4M3FN> + pto.vsts %f8, %ub_f8[%c0], %mask_b8 : !pto.vreg<256xf8E4M3FN>, !pto.ptr, !pto.mask + %f8_loaded = pto.vlds %ub_f8[%c0] : !pto.ptr -> !pto.vreg<256xf8E4M3FN> + %f8_back = pto.vcvt %f8_loaded, %mask_b8 {part = "P0"} : !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %f8_back, %dst_f32[%c0], %mask_b32 : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + + %hif8 = pto.vcvt %src_f32, %mask_b32 {rnd = "A", sat = "NOSAT", part = "P0"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256x!pto.hif8> + pto.vsts %hif8, %ub_hif8[%c0], %mask_b8 : !pto.vreg<256x!pto.hif8>, !pto.ptr, !pto.mask + %hif8_loaded = pto.vlds %ub_hif8[%c0] : !pto.ptr -> !pto.vreg<256x!pto.hif8> + %hif8_back = pto.vcvt %hif8_loaded, %mask_b8 {part = "P0"} : !pto.vreg<256x!pto.hif8>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %hif8_back, %dst_f32[%c0], %mask_b32 : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + + %src_bf16 = pto.vdup %bf16_seed, %mask_b16 : bf16, !pto.mask -> !pto.vreg<128xbf16> + %f4 = pto.vcvt %src_bf16, %mask_b16 {rnd = "R", part = "P0"} : !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<256x!pto.f4E1M2x2> + pto.vsts %f4, %ub_f4[%c0], %mask_b8 : !pto.vreg<256x!pto.f4E1M2x2>, !pto.ptr, !pto.mask + %f4_loaded = pto.vlds %ub_f4[%c0] : !pto.ptr -> !pto.vreg<256x!pto.f4E1M2x2> + %f4_back = pto.vcvt %f4_loaded, %mask_b8 {part = "P0"} : !pto.vreg<256x!pto.f4E1M2x2>, !pto.mask -> !pto.vreg<128xbf16> + pto.vsts %f4_back, %dst_bf16[%c0], %mask_b16 : !pto.vreg<128xbf16>, !pto.ptr, !pto.mask + } + return + } + + func.func @lowp_simt_kernel(%gm: !pto.ptr) attributes {pto.aicore} { + %c0_i64 = arith.constant 0 : i64 + %c128_i64 = arith.constant 128 : i64 + %c1_i64 = arith.constant 1 : i64 + %dim = arith.constant 1 : i32 + %threads = arith.constant 32 : i32 + %ub_i32 = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_i64 = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_hif8 = pto.castptr %c0_i64 : i64 -> !pto.ptr + pto.store_vfsimt_info %dim, %dim, %threads : i32, i32, i32 + func.call @lowp_simt_body(%ub_i32, %ub_i64, %ub_hif8) : (!pto.ptr, !pto.ptr, !pto.ptr) -> () + pto.mte_ub_gm %ub_i32, %gm, %c128_i64 + nburst(%c1_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + return + } + + func.func @lowp_simt_body(%dst: !pto.ptr, %dst64: !pto.ptr, %hif8_dst: !pto.ptr) attributes {pto.simt_entry} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %f2 = arith.constant dense<[1.000000e+00, 2.000000e+00]> : vector<2xf32> + %f8 = pto.convert %f2 round(r) nosat : vector<2xf32> -> vector<2xf8E4M3FN> + %hif8 = pto.convert %f2 round(a) nosat : vector<2xf32> -> !pto.hif8x2 + %f8_back = pto.convert %f8 round(r) nosat : vector<2xf8E4M3FN> -> vector<2xf32> + %hif8_back = pto.convert %hif8 round(a) nosat : !pto.hif8x2 -> vector<2xf32> + %f8_bits = llvm.bitcast %f8 : vector<2xf8E4M3FN> to i16 + %f8_back_bits = llvm.bitcast %f8_back : vector<2xf32> to i64 + %hif8_back_bits = llvm.bitcast %hif8_back : vector<2xf32> to i64 + %f8_bits_i32 = arith.extui %f8_bits : i16 to i32 + pto.store %f8_bits_i32, %dst[%c0] : !pto.ptr, i32 + pto.store %hif8, %hif8_dst[%c1] : !pto.ptr, !pto.hif8x2 + pto.store %f8_back_bits, %dst64[%c0] : !pto.ptr, i64 + pto.store %hif8_back_bits, %dst64[%c1] : !pto.ptr, i64 + return + } +} + +// LLVMIR: declare <64 x i32> @llvm.hivm.vcvtff.f322f8e4m3.x +// LLVMIR: declare void @llvm.hivm.vstsx1.v256f8e4m3(<256 x i8> +// LLVMIR: declare <256 x i8> @llvm.hivm.vldsx1.v256f8e4m3 +// LLVMIR: declare <64 x float> @llvm.hivm.vcvtff.f8e4m32f32.x(<64 x i32> +// LLVMIR: declare <64 x i32> @llvm.hivm.vcvtff.f322hif8.x +// LLVMIR: declare void @llvm.hivm.vstsx1.v256hif8(<256 x i8> +// LLVMIR: declare <64 x i32> @llvm.hivm.vcvtff2.bf162f4e1m2x2.x +// LLVMIR: declare <128 x bfloat> @llvm.hivm.vcvtff2.f4e1m2x22bf16.x(<64 x i32> +// LLVMIR: declare i16 @llvm.hivm.f32x2.to.f8e4m3x2 +// LLVMIR: declare i16 @llvm.hivm.f32x2.to.hif8x2 +// LLVMIR: declare <2 x float> @llvm.hivm.f8e4m3x2.to.f32x2(i16 +// LLVMIR: declare <2 x float> @llvm.hivm.hif8x2.to.f32x2(i16 +// LLVMIR: bitcast <64 x i32> {{%[0-9]+}} to <256 x i8> +// LLVMIR: bitcast <256 x i8> {{%[0-9]+}} to <64 x i32> +// LLVMIR: bitcast i16 {{%[0-9]+}} to <2 x i8> +// LLVMIR-NOT: declare <256 x f8e4m3> +// LLVMIR-NOT: declare <2 x f8e4m3> +// LLVMIR-NOT: declare <256 x i8> @llvm.hivm.vcvtff.f322f8e4m3.x +// LLVMIR-NOT: declare <2 x i8> @llvm.hivm.f32x2.to.f8e4m3x2 From 1b1897e358ff4cce66ccafbfa6628aa02f21a228 Mon Sep 17 00:00:00 2001 From: TaoTao-real Date: Fri, 12 Jun 2026 08:29:16 +0800 Subject: [PATCH 19/51] fix: gate known CANN beta VPTO SIM gaps --- .github/workflows/ci_sim.yml | 2 + .../cann-9.0.0-beta.1-sim.txt | 28 +++++ .../run_host_vpto_validation_parallel.sh | 100 +++++++++++++++--- 3 files changed, 114 insertions(+), 16 deletions(-) create mode 100644 test/vpto/known_unsupported/cann-9.0.0-beta.1-sim.txt diff --git a/.github/workflows/ci_sim.yml b/.github/workflows/ci_sim.yml index b20507ab89..03d81f9adf 100644 --- a/.github/workflows/ci_sim.yml +++ b/.github/workflows/ci_sim.yml @@ -508,6 +508,8 @@ jobs: PTOAS_BIN="${PTOAS_BIN}" \ DEVICE=SIM \ JOBS="${JOBS:-32}" \ + VPTO_SIM_ENABLE_KNOWN_UNSUPPORTED_SKIP=1 \ + VPTO_SIM_KNOWN_UNSUPPORTED_CASES_FILE="${GITHUB_WORKSPACE}/test/vpto/known_unsupported/cann-9.0.0-beta.1-sim.txt" \ bash test/vpto/scripts/run_host_vpto_validation_parallel.sh - name: Run TileLang DSL CI diff --git a/test/vpto/known_unsupported/cann-9.0.0-beta.1-sim.txt b/test/vpto/known_unsupported/cann-9.0.0-beta.1-sim.txt new file mode 100644 index 0000000000..580d79c74e --- /dev/null +++ b/test/vpto/known_unsupported/cann-9.0.0-beta.1-sim.txt @@ -0,0 +1,28 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# CANN 9.0.0 beta.1 Bisheng SIM cannot currently consume these LLVM 21 +# textual-IR cases. Low-precision vcvt cases require PTOAS to emit LLVM 21 +# carrier types, but beta.1 Bisheng verifies the legacy private intrinsic ABI +# while rejecting the private low-precision type spellings in textual IR. +# The SIMT cases fail in Bisheng's HiIPU backend after valid LLVM IR is handed +# to the compiler. Keep this list exact and remove entries as the CANN compiler +# catches up. +micro-op/conversion/vcvt-low-precision-roundtrip +micro-op/conversion/vcvt-low-precision-special +micro-op/simt/simt-atomic-packed-core +micro-op/simt/simt-atomic-s-core +micro-op/simt/simt-control-typed-core +micro-op/simt/simt-float-convert-core +micro-op/simt/simt-gm-memory-core +micro-op/simt/simt-ldst-policy-core +micro-op/simt/simt-memory-core +micro-op/simt/simt-packed-math-core +micro-op/simt/simt-scalar-core +micro-op/simt/simt-scalar-math-core +micro-op/simt/simt-store-tid diff --git a/test/vpto/scripts/run_host_vpto_validation_parallel.sh b/test/vpto/scripts/run_host_vpto_validation_parallel.sh index 98be7d669d..9d4d67c3d5 100755 --- a/test/vpto/scripts/run_host_vpto_validation_parallel.sh +++ b/test/vpto/scripts/run_host_vpto_validation_parallel.sh @@ -19,6 +19,8 @@ WORK_SPACE="${WORK_SPACE:-}" CASE_NAME="${CASE_NAME:-}" CASE_PREFIX="${CASE_PREFIX:-}" JOBS="${JOBS:-}" +VPTO_SIM_ENABLE_KNOWN_UNSUPPORTED_SKIP="${VPTO_SIM_ENABLE_KNOWN_UNSUPPORTED_SKIP:-0}" +VPTO_SIM_KNOWN_UNSUPPORTED_CASES_FILE="${VPTO_SIM_KNOWN_UNSUPPORTED_CASES_FILE:-}" log() { echo "[$(date +'%F %T')] $*" @@ -69,6 +71,41 @@ fi [[ "${JOBS}" =~ ^[0-9]+$ ]] || die "JOBS must be a positive integer, got: ${JOBS}" [[ "${JOBS}" -ge 1 ]] || die "JOBS must be >= 1" +KNOWN_UNSUPPORTED_CASES=() + +trim() { + local text="$1" + text="${text#"${text%%[![:space:]]*}"}" + text="${text%"${text##*[![:space:]]}"}" + printf '%s' "${text}" +} + +load_known_unsupported_cases() { + [[ "${VPTO_SIM_ENABLE_KNOWN_UNSUPPORTED_SKIP}" == "1" ]] || return 0 + [[ -n "${VPTO_SIM_KNOWN_UNSUPPORTED_CASES_FILE}" ]] || + die "VPTO_SIM_KNOWN_UNSUPPORTED_CASES_FILE is required when known-unsupported skip is enabled" + [[ -f "${VPTO_SIM_KNOWN_UNSUPPORTED_CASES_FILE}" ]] || + die "known-unsupported cases file not found: ${VPTO_SIM_KNOWN_UNSUPPORTED_CASES_FILE}" + + local raw line + while IFS= read -r raw || [[ -n "${raw}" ]]; do + line="$(trim "${raw%%#*}")" + [[ -n "${line}" ]] || continue + KNOWN_UNSUPPORTED_CASES+=("${line}") + done < "${VPTO_SIM_KNOWN_UNSUPPORTED_CASES_FILE}" +} + +is_known_unsupported_case() { + local case_name="$1" + local known_case + for known_case in "${KNOWN_UNSUPPORTED_CASES[@]}"; do + if [[ "${known_case}" == "${case_name}" ]]; then + return 0 + fi + done + return 1 +} + mkdir -p "${WORK_SPACE}" WORK_SPACE="$(cd "${WORK_SPACE}" && pwd)" SUMMARY_FILE="${WORK_SPACE}/parallel-summary.tsv" @@ -126,13 +163,32 @@ if [[ "${DEVICE:-SIM}" == "SIM" && "${COMPILE_ONLY:-0}" != "1" && die "case ${CASE_NAME} is onboard-only and cannot run with DEVICE=SIM" fi -readarray -t CASES < <(discover_cases) -[[ "${#CASES[@]}" -gt 0 ]] || die "no cases found under ${CASES_ROOT}" +load_known_unsupported_cases + +DISCOVERED_CASES=() +while IFS= read -r case_name; do + DISCOVERED_CASES+=("${case_name}") +done < <(discover_cases) +[[ "${#DISCOVERED_CASES[@]}" -gt 0 ]] || die "no cases found under ${CASES_ROOT}" + +CASES=() +SKIPPED_CASES=() +for case_name in "${DISCOVERED_CASES[@]}"; do + if is_known_unsupported_case "${case_name}"; then + SKIPPED_CASES+=("${case_name}") + else + CASES+=("${case_name}") + fi +done + +[[ "${#CASES[@]}" -gt 0 || "${#SKIPPED_CASES[@]}" -gt 0 ]] || + die "no runnable or skipped cases found under ${CASES_ROOT}" : > "${SUMMARY_FILE}" : > "${RUNNER_LOG}" -declare -A PID_TO_CASE=() +RUNNING_PIDS=() +RUNNING_CASES=() launch_case() { local case_name="$1" @@ -143,12 +199,14 @@ launch_case() { ) & local pid=$! - PID_TO_CASE["${pid}"]="${case_name}" + RUNNING_PIDS+=("${pid}") + RUNNING_CASES+=("${case_name}") } reap_one() { - local pid="$1" - local case_name="${PID_TO_CASE[${pid}]}" + local index="$1" + local pid="${RUNNING_PIDS[${index}]}" + local case_name="${RUNNING_CASES[${index}]}" local result="FAIL" local detail="1" @@ -159,7 +217,8 @@ reap_one() { printf '%s\t%s\t%s\n' "${case_name}" "${result}" "${detail}" >> "${SUMMARY_FILE}" log "[${case_name}] ${result} (${detail})" | tee -a "${RUNNER_LOG}" - unset 'PID_TO_CASE['"${pid}"']' + unset 'RUNNING_PIDS['"${index}"']' + unset 'RUNNING_CASES['"${index}"']' } log "=== VPTO Host Validation Parallel ===" | tee -a "${RUNNER_LOG}" @@ -167,26 +226,34 @@ log "WORK_SPACE=${WORK_SPACE}" | tee -a "${RUNNER_LOG}" log "CASE_NAME=${CASE_NAME:-}" | tee -a "${RUNNER_LOG}" log "CASE_PREFIX=${CASE_PREFIX:-}" | tee -a "${RUNNER_LOG}" log "JOBS=${JOBS}" | tee -a "${RUNNER_LOG}" -log "TOTAL_CASES=${#CASES[@]}" | tee -a "${RUNNER_LOG}" +log "TOTAL_CASES=${#DISCOVERED_CASES[@]}" | tee -a "${RUNNER_LOG}" +log "RUNNABLE_CASES=${#CASES[@]}" | tee -a "${RUNNER_LOG}" +log "SKIPPED_CASES=${#SKIPPED_CASES[@]}" | tee -a "${RUNNER_LOG}" if [[ -n "${SIM_LIB_DIR:-}" ]]; then log "SIM_LIB_DIR=${SIM_LIB_DIR}" | tee -a "${RUNNER_LOG}" fi +for case_name in "${SKIPPED_CASES[@]}"; do + printf '%s\tSKIP\tknown-unsupported\n' "${case_name}" >> "${SUMMARY_FILE}" + log "[${case_name}] SKIP (known-unsupported)" | tee -a "${RUNNER_LOG}" +done + next_index=0 -while [[ "${next_index}" -lt "${#CASES[@]}" || "${#PID_TO_CASE[@]}" -gt 0 ]]; do - while [[ "${next_index}" -lt "${#CASES[@]}" && "${#PID_TO_CASE[@]}" -lt "${JOBS}" ]]; do +while [[ "${next_index}" -lt "${#CASES[@]}" || "${#RUNNING_PIDS[@]}" -gt 0 ]]; do + while [[ "${next_index}" -lt "${#CASES[@]}" && "${#RUNNING_PIDS[@]}" -lt "${JOBS}" ]]; do launch_case "${CASES[${next_index}]}" next_index="$((next_index + 1))" done - if [[ "${#PID_TO_CASE[@]}" -eq 0 ]]; then + if [[ "${#RUNNING_PIDS[@]}" -eq 0 ]]; then continue fi while true; do - for pid in "${!PID_TO_CASE[@]}"; do + for index in "${!RUNNING_PIDS[@]}"; do + pid="${RUNNING_PIDS[${index}]}" if ! kill -0 "${pid}" 2>/dev/null; then - reap_one "${pid}" + reap_one "${index}" break 2 fi done @@ -195,13 +262,14 @@ while [[ "${next_index}" -lt "${#CASES[@]}" || "${#PID_TO_CASE[@]}" -gt 0 ]]; do done pass_count="$(awk -F '\t' '$2 == "PASS" {count++} END {print count + 0}' "${SUMMARY_FILE}")" -fail_count="$(awk -F '\t' '$2 != "PASS" {count++} END {print count + 0}' "${SUMMARY_FILE}")" +skip_count="$(awk -F '\t' '$2 == "SKIP" {count++} END {print count + 0}' "${SUMMARY_FILE}")" +fail_count="$(awk -F '\t' '$2 == "FAIL" {count++} END {print count + 0}' "${SUMMARY_FILE}")" -log "PASS=${pass_count} FAIL=${fail_count}" | tee -a "${RUNNER_LOG}" +log "PASS=${pass_count} SKIP=${skip_count} FAIL=${fail_count}" | tee -a "${RUNNER_LOG}" log "summary: ${SUMMARY_FILE}" | tee -a "${RUNNER_LOG}" if [[ "${fail_count}" -ne 0 ]]; then die "parallel validation finished with ${fail_count} failing case(s)" fi -log "All ${pass_count} case(s) passed" | tee -a "${RUNNER_LOG}" +log "All ${pass_count} runnable case(s) passed; ${skip_count} known-unsupported case(s) skipped" | tee -a "${RUNNER_LOG}" From 950a20761237e644ff60e806ded7db1d2d8c5756 Mon Sep 17 00:00:00 2001 From: TaoTao-real Date: Fri, 12 Jun 2026 16:42:18 +0800 Subject: [PATCH 20/51] build: use LLVM21 VPTO dependency branch --- .github/workflows/build_wheel.yml | 12 ++++++------ .github/workflows/build_wheel_mac.yml | 12 ++++++------ .github/workflows/ci.yml | 10 +++++----- .github/workflows/ci_sim.yml | 4 ++-- README.md | 12 ++++++------ README_en.md | 12 ++++++------ ReleaseNotes.md | 2 +- docker/Dockerfile | 5 +++-- docs/build_with_installed_llvm.md | 4 +++- docs/designs/ci-board-validation-guide.md | 4 ++-- lib/PTO/Transforms/GraphSyncSolver/EventIdSolver.cpp | 4 ++-- 11 files changed, 42 insertions(+), 39 deletions(-) diff --git a/.github/workflows/build_wheel.yml b/.github/workflows/build_wheel.yml index f00d7321b1..c8f572e749 100644 --- a/.github/workflows/build_wheel.yml +++ b/.github/workflows/build_wheel.yml @@ -24,9 +24,9 @@ permissions: contents: write env: - LLVM_REPO: https://github.com/llvm/llvm-project.git - LLVM_TAG: llvmorg-21.1.8 - LLVM_CACHE_FLAVOR: llvm21-release-hardening-v1 + LLVM_REPO: https://github.com/TaoTao-real/llvm-project.git + LLVM_REF: feature-vpto-llvm21 + LLVM_CACHE_FLAVOR: llvm21-vpto-release-hardening-v1 jobs: build_wheel: @@ -98,9 +98,9 @@ jobs: - name: Resolve LLVM source SHA id: llvm-source run: | - LLVM_SOURCE_SHA="$(git ls-remote "${LLVM_REPO}" "refs/tags/${LLVM_TAG}" | awk '{print $1}')" + LLVM_SOURCE_SHA="$(git ls-remote "${LLVM_REPO}" "refs/heads/${LLVM_REF}" | awk '{print $1}')" if [ -z "${LLVM_SOURCE_SHA}" ]; then - echo "ERROR: failed to resolve ${LLVM_REPO} ${LLVM_TAG}" >&2 + echo "ERROR: failed to resolve ${LLVM_REPO} ${LLVM_REF}" >&2 exit 1 fi echo "sha=${LLVM_SOURCE_SHA}" >> "$GITHUB_OUTPUT" @@ -121,7 +121,7 @@ jobs: git remote add origin "${LLVM_REPO}" fi git remote set-url origin "${LLVM_REPO}" - git fetch --depth 1 origin "${LLVM_TAG}" + git fetch --depth 1 origin "${LLVM_REF}" git checkout --force FETCH_HEAD - name: Warn when default-branch cache is missing for PR/release runs diff --git a/.github/workflows/build_wheel_mac.yml b/.github/workflows/build_wheel_mac.yml index 2a8106f28c..d54df8dfc3 100644 --- a/.github/workflows/build_wheel_mac.yml +++ b/.github/workflows/build_wheel_mac.yml @@ -23,9 +23,9 @@ permissions: contents: write env: - LLVM_REPO: https://github.com/llvm/llvm-project.git - LLVM_TAG: llvmorg-21.1.8 - LLVM_CACHE_FLAVOR: llvm21-release-v1 + LLVM_REPO: https://github.com/TaoTao-real/llvm-project.git + LLVM_REF: feature-vpto-llvm21 + LLVM_CACHE_FLAVOR: llvm21-vpto-release-v1 jobs: build_wheel: @@ -101,9 +101,9 @@ jobs: - name: Resolve LLVM source SHA id: llvm-source run: | - LLVM_SOURCE_SHA="$(git ls-remote "${LLVM_REPO}" "refs/tags/${LLVM_TAG}" | awk '{print $1}')" + LLVM_SOURCE_SHA="$(git ls-remote "${LLVM_REPO}" "refs/heads/${LLVM_REF}" | awk '{print $1}')" if [ -z "${LLVM_SOURCE_SHA}" ]; then - echo "ERROR: failed to resolve ${LLVM_REPO} ${LLVM_TAG}" >&2 + echo "ERROR: failed to resolve ${LLVM_REPO} ${LLVM_REF}" >&2 exit 1 fi echo "sha=${LLVM_SOURCE_SHA}" >> "$GITHUB_OUTPUT" @@ -124,7 +124,7 @@ jobs: git remote add origin "${LLVM_REPO}" fi git remote set-url origin "${LLVM_REPO}" - git fetch --depth 1 origin "${LLVM_TAG}" + git fetch --depth 1 origin "${LLVM_REF}" git checkout --force FETCH_HEAD - name: Warn when default-branch cache is missing for PR/release runs diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7edcceb531..521ed88c29 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -97,8 +97,8 @@ jobs: runs-on: ubuntu-22.04 env: PTOAS_CLANG_MAJOR: "15" - LLVM_REPO: https://github.com/llvm/llvm-project.git - LLVM_TAG: llvmorg-21.1.8 + LLVM_REPO: https://github.com/TaoTao-real/llvm-project.git + LLVM_REF: feature-vpto-llvm21 LLVM_BUILD_DIR: ${{ github.workspace }}/llvm-project/llvm/build-assert LLVM_DIR: ${{ github.workspace }}/llvm-project/llvm/build-assert PTO_BUILD_DIR: ${{ github.workspace }}/build-assert @@ -159,9 +159,9 @@ jobs: shell: bash run: | set -euo pipefail - LLVM_SOURCE_SHA="$(git ls-remote "${LLVM_REPO}" "refs/tags/${LLVM_TAG}" | awk '{print $1}')" + LLVM_SOURCE_SHA="$(git ls-remote "${LLVM_REPO}" "refs/heads/${LLVM_REF}" | awk '{print $1}')" [[ -n "${LLVM_SOURCE_SHA}" ]] || { - echo "ERROR: failed to resolve ${LLVM_REPO} ${LLVM_TAG}" >&2 + echo "ERROR: failed to resolve ${LLVM_REPO} ${LLVM_REF}" >&2 exit 1 } echo "LLVM_SOURCE_SHA=${LLVM_SOURCE_SHA}" >> "${GITHUB_ENV}" @@ -187,7 +187,7 @@ jobs: fi git remote set-url origin "${LLVM_REPO}" - git fetch --depth 1 origin "${LLVM_TAG}" + git fetch --depth 1 origin "${LLVM_REF}" git checkout --force FETCH_HEAD - name: Build LLVM/MLIR (only if cache miss) diff --git a/.github/workflows/ci_sim.yml b/.github/workflows/ci_sim.yml index 03d81f9adf..2abb43821a 100644 --- a/.github/workflows/ci_sim.yml +++ b/.github/workflows/ci_sim.yml @@ -107,8 +107,8 @@ jobs: needs.detect-vpto-sim-changes.outputs.should_run == 'true' }} env: - LLVM_REPO: https://github.com/llvm/llvm-project.git - LLVM_TAG: llvmorg-21.1.8 + LLVM_REPO: https://github.com/TaoTao-real/llvm-project.git + LLVM_REF: feature-vpto-llvm21 PTO_INSTALL_DIR: ${{ github.workspace }}/install VPTO_SIM_WORKSPACE: ${{ github.workspace }}/.work/vpto-sim-ci TILELANG_DSL_WORKSPACE: ${{ github.workspace }}/.work/tilelang-dsl-ci diff --git a/README.md b/README.md index 467bf94372..41d6e47386 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ ## 1. 项目简介 (Introduction) -**ptoas** (`ptoas`) 是一个基于 **LLVM/MLIR (llvmorg-21.1.8)** 框架构建的专用编译器工具链,专为 **PTO Bytecode** (Programming Tiling Operator Bytecode) 设计。 +**ptoas** (`ptoas`) 是一个基于 **LLVM/MLIR LLVM21 VPTO 分支 (`TaoTao-real/llvm-project:feature-vpto-llvm21`)** 框架构建的专用编译器工具链,专为 **PTO Bytecode** (Programming Tiling Operator Bytecode) 设计。 作为连接上层 AI 框架与底层各类NPU/GPGPU/CPU硬件,`ptoas` 采用 **Out-of-Tree** 架构构建,提供了完整的 C++ 与 Python 接口,主要职责包括: @@ -37,7 +37,7 @@ PTOAS/ ## 3. 构建指南 (Build Instructions) -⚠️ **重要提示**:本项目严格依赖 **LLVM llvmorg-21.1.8** 版本。 +⚠️ **重要提示**:本项目严格依赖 **LLVM21 VPTO 分支 `TaoTao-real/llvm-project:feature-vpto-llvm21`**。 ### 3.0 环境变量配置 (Configuration) @@ -84,16 +84,16 @@ python3 -m pip install 'pybind11<3' nanobind numpy ### 3.2 第一步:构建 LLVM/MLIR (Dependency) -我们需要下载 LLVM 源码,切换到 `llvmorg-21.1.8` 标签,并以**动态库 (Shared Libs)** 模式编译,以确保 Python Binding 的正确链接。 +我们需要下载 VPTO 适配后的 LLVM 源码,切换到 `feature-vpto-llvm21` 分支,并以**动态库 (Shared Libs)** 模式编译,以确保 Python Binding 的正确链接。 ```bash # 1. 下载 LLVM 源码 cd $WORKSPACE_DIR -git clone https://github.com/llvm/llvm-project.git +git clone https://github.com/TaoTao-real/llvm-project.git cd $LLVM_SOURCE_DIR -# 2. [关键] 切换到 llvmorg-21.1.8 -git checkout llvmorg-21.1.8 +# 2. [关键] 切换到 VPTO 适配分支 +git checkout feature-vpto-llvm21 # 3. 配置 CMake (构建动态库并启用 Python 绑定) cmake -G Ninja -S llvm -B $LLVM_BUILD_DIR \ diff --git a/README_en.md b/README_en.md index 5528b1f688..c1a75c6a77 100644 --- a/README_en.md +++ b/README_en.md @@ -2,7 +2,7 @@ ## 1. Introduction -**ptoas** is a specialized compiler toolchain built on top of **LLVM/MLIR (llvmorg-21.1.8)** *(Commit 2078da43e25a4623cab2d0d60decddf709aaea28)*, designed specifically for **PTO Bytecode** (Programming Tiling Operator Bytecode). +**ptoas** is a specialized compiler toolchain built on top of the **LLVM21 VPTO branch (`TaoTao-real/llvm-project:feature-vpto-llvm21`)**, designed specifically for **PTO Bytecode** (Programming Tiling Operator Bytecode). Acting as the bridge between upper-level AI frameworks and underlying NPU/GPGPU/CPU hardware, `ptoas` is built in an **Out-of-Tree** architecture and provides complete C++ and Python interfaces. Its primary responsibilities include: @@ -36,7 +36,7 @@ PTOAS/ ## 3. Build Instructions -⚠️ **Important**: This project strictly requires **LLVM llvmorg-21.1.8**. +⚠️ **Important**: This project strictly requires the **LLVM21 VPTO branch `TaoTao-real/llvm-project:feature-vpto-llvm21`**. ### 3.0 Environment Variable Configuration @@ -79,16 +79,16 @@ python3 -m pip install "pybind11<3" nanobind numpy ### 3.2 Step 1: Build LLVM/MLIR (Dependency) -Download the LLVM source, check out the `llvmorg-21.1.8` tag, and build with **shared libraries** to ensure correct linking for Python bindings. +Download the VPTO-adapted LLVM source, check out the `feature-vpto-llvm21` branch, and build with **shared libraries** to ensure correct linking for Python bindings. ```bash # 1. Clone LLVM cd $WORKSPACE_DIR -git clone https://github.com/llvm/llvm-project.git +git clone https://github.com/TaoTao-real/llvm-project.git cd $LLVM_SOURCE_DIR -# 2. [Critical] Check out llvmorg-21.1.8 -git checkout llvmorg-21.1.8 +# 2. [Critical] Check out the VPTO adaptation branch +git checkout feature-vpto-llvm21 # 3. Configure CMake (build shared libs with Python bindings enabled) cmake -G Ninja -S llvm -B $LLVM_BUILD_DIR \ diff --git a/ReleaseNotes.md b/ReleaseNotes.md index 8aa8592e7c..e584f1ff5e 100644 --- a/ReleaseNotes.md +++ b/ReleaseNotes.md @@ -8,7 +8,7 @@ - PTOAS 首次发布 ## 概述 -PTOAS(PTO Assembler & Optimizer)是面向 PTO Bytecode 的编译器工具链,基于 LLVM/MLIR llvmorg-21.1.8 构建。它提供 PTO Dialect 的定义、解析、验证、优化与代码生成能力,并输出可调用 `pto-isa` 的 C++ 代码。 +PTOAS(PTO Assembler & Optimizer)是面向 PTO Bytecode 的编译器工具链,基于 LLVM/MLIR LLVM21 VPTO 分支 `TaoTao-real/llvm-project:feature-vpto-llvm21` 构建。它提供 PTO Dialect 的定义、解析、验证、优化与代码生成能力,并输出可调用 `pto-isa` 的 C++ 代码。 PTOAS很快将集成到以下框架中,敬请期待 - PyPTO diff --git a/docker/Dockerfile b/docker/Dockerfile index 7ecdebcf20..9789c57edc 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -10,7 +10,8 @@ ARG ARCH # NOTE: change $PY_VER for different Python versions (3.8 - 3.14 available) ARG PY_VER=cp311-cp311 -ARG LLVM_TAG=llvmorg-21.1.8 +ARG LLVM_REF=feature-vpto-llvm21 +ARG LLVM_REPO=https://github.com/TaoTao-real/llvm-project.git ## -- usually no need to change below -- @@ -34,7 +35,7 @@ ENV PTO_INSTALL_DIR=$PTO_SOURCE_DIR/install-release # build LLVM WORKDIR $WORKSPACE_DIR -RUN git clone --depth 1 --branch ${LLVM_TAG} https://github.com/llvm/llvm-project.git +RUN git clone --depth 1 --branch ${LLVM_REF} ${LLVM_REPO} llvm-project WORKDIR $LLVM_SOURCE_DIR RUN cmake -C /tmp/LinuxHardeningCache.cmake -G Ninja -S llvm -B $LLVM_BUILD_DIR \ diff --git a/docs/build_with_installed_llvm.md b/docs/build_with_installed_llvm.md index 316b46a8a5..8e3f14575d 100644 --- a/docs/build_with_installed_llvm.md +++ b/docs/build_with_installed_llvm.md @@ -2,7 +2,7 @@ 本文档按 [README.md](../README.md) 第 3 章的逻辑整理,适用于: -- LLVM/MLIR `21.1.8` 已经构建并安装完成。 +- LLVM/MLIR LLVM21 VPTO 分支 `TaoTao-real/llvm-project:feature-vpto-llvm21` 已经构建并安装完成。 - LLVM 安装路径固定为 `/opt/llvm`。 - `/opt/llvm` 是共享目录,不希望 `ptoas` 的安装步骤写入其中。 @@ -66,6 +66,8 @@ README 第 3.2 节是 LLVM/MLIR 的下载和编译步骤。当前场景下 LLVM 21.1.8 ``` +实际源码基线应来自 `https://github.com/TaoTao-real/llvm-project.git` 的 `feature-vpto-llvm21` 分支。 + ## 3.3 第二步:构建 ptoas 这里沿用 README 第 3.3 节的流程,但 `LLVM_DIR` 和 `MLIR_DIR` 需要改为 diff --git a/docs/designs/ci-board-validation-guide.md b/docs/designs/ci-board-validation-guide.md index 969cfc6055..0678229827 100644 --- a/docs/designs/ci-board-validation-guide.md +++ b/docs/designs/ci-board-validation-guide.md @@ -81,9 +81,9 @@ ### 3.1 构建 LLVM/MLIR ```bash -git clone https://github.com/llvm/llvm-project.git +git clone https://github.com/TaoTao-real/llvm-project.git cd llvm-project -git checkout llvmorg-21.1.8 +git checkout feature-vpto-llvm21 cmake -G Ninja -S llvm -B llvm/build-shared \ -DLLVM_ENABLE_PROJECTS="mlir;clang" \ diff --git a/lib/PTO/Transforms/GraphSyncSolver/EventIdSolver.cpp b/lib/PTO/Transforms/GraphSyncSolver/EventIdSolver.cpp index ba12ff60b7..a21ebfd316 100644 --- a/lib/PTO/Transforms/GraphSyncSolver/EventIdSolver.cpp +++ b/lib/PTO/Transforms/GraphSyncSolver/EventIdSolver.cpp @@ -32,7 +32,7 @@ int64_t EventIdSolver::getEventIdsNum(bool dontCalcEventIds) { calcEventIds(); } assert(!needRecalculateEventIds); - llvm::SmallDenseSet usedEventIds; + llvm::DenseSet usedEventIds; for (auto &node : nodes) { auto &eventIds = node->getEventIds(); assert(!eventIds.empty()); @@ -215,7 +215,7 @@ void EventIdSolver::addConflicts( llvm::SmallVector EventIdSolver::getAdjNodesUsedEventIds(EventIdNode *node) { - llvm::SmallDenseSet usedEventIds; + llvm::DenseSet usedEventIds; for (auto [otherNode, frq] : adjList[node]) { assert(frq > 0); auto &otherEventIds = otherNode->getEventIds(); From ccdab606beff1a7229d92b9f07e4dab23c1706ef Mon Sep 17 00:00:00 2001 From: TaoTao-real Date: Fri, 12 Jun 2026 20:17:40 +0800 Subject: [PATCH 21/51] test: use configured Python for ptobc maskpattern check --- tools/ptobc/tests/CMakeLists.txt | 1 + .../ptobc/tests/tscatter_maskpattern_v0_encode.sh | 14 +++++++++++++- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/tools/ptobc/tests/CMakeLists.txt b/tools/ptobc/tests/CMakeLists.txt index 2a2f78e3a4..0489351c08 100644 --- a/tools/ptobc/tests/CMakeLists.txt +++ b/tools/ptobc/tests/CMakeLists.txt @@ -117,6 +117,7 @@ add_test(NAME ptobc_tscatter_maskpattern_v0_encode COMMAND ${CMAKE_COMMAND} -E env PTOBC_BIN=$ TESTDATA_DIR=${PTObc_TESTDATA_DIR} + PYTHON_EXECUTABLE=${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/tscatter_maskpattern_v0_encode.sh ) diff --git a/tools/ptobc/tests/tscatter_maskpattern_v0_encode.sh b/tools/ptobc/tests/tscatter_maskpattern_v0_encode.sh index 482cd74eef..c85b1dadee 100755 --- a/tools/ptobc/tests/tscatter_maskpattern_v0_encode.sh +++ b/tools/ptobc/tests/tscatter_maskpattern_v0_encode.sh @@ -35,7 +35,19 @@ grep -F "pto.tscatter ins(" "${ROUNDTRIP}" >/dev/null grep -F "{maskPattern = #pto.mask_pattern}" "${ROUNDTRIP}" >/dev/null grep -E "pto\\.tscatter ins\\(%[^,]+, %[^:]+ :" "${ROUNDTRIP}" >/dev/null -python - <<'PY' "${BC}" +PYTHON_EXECUTABLE=${PYTHON_EXECUTABLE:-} +if [[ -z "${PYTHON_EXECUTABLE}" ]]; then + if command -v python3 >/dev/null 2>&1; then + PYTHON_EXECUTABLE=python3 + elif command -v python >/dev/null 2>&1; then + PYTHON_EXECUTABLE=python + else + echo "error: neither PYTHON_EXECUTABLE nor python3/python is available" >&2 + exit 2 + fi +fi + +"${PYTHON_EXECUTABLE}" - <<'PY' "${BC}" from pathlib import Path import sys From d8f0bacc36cc1f9d2baab8ea68718f040f6bbb86 Mon Sep 17 00:00:00 2001 From: TaoTao-real Date: Sat, 13 Jun 2026 01:01:44 +0800 Subject: [PATCH 22/51] fix: emit native lowp vcvt intrinsic types --- lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp | 36 ++++++++++++- lib/PTO/Transforms/VPTOLLVMEmitter.cpp | 36 ++++++++++++- .../vpto/low_precision_hivm_llvm_ir_abi.pto | 50 +++++++++++++------ 3 files changed, 104 insertions(+), 18 deletions(-) diff --git a/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp b/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp index d2ef1414b8..e21345ba85 100644 --- a/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp +++ b/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp @@ -578,6 +578,38 @@ static Type getLowpPayloadCarrierType(Type vectorLikeType, return VectorType::get({*lanes}, abi->llvmElementType); } +static Type getLowpVcvtIntrinsicElementType(Type elementType, + MLIRContext *context) { + if (pto::isPTOHiFloat8Type(elementType)) + return LLVM::LLVMHiFloat8Type::get(context); + if (isa(elementType)) + return LLVM::LLVMFloat4E1M2x2Type::get(context); + if (isa(elementType)) + return LLVM::LLVMFloat4E2M1x2Type::get(context); + if (pto::isPTOFloat8E4M3LikeType(elementType)) + return LLVM::LLVMFloat8E4M3Type::get(context); + if (pto::isPTOFloat8E5M2LikeType(elementType)) + return LLVM::LLVMFloat8E5M2Type::get(context); + return {}; +} + +static Type getLowpIntrinsicCarrierType(Type semanticType, Type convertedType, + MLIRContext *context); + +static Type getVcvtIntrinsicValueType(Type semanticType, Type convertedType, + MLIRContext *context) { + Type elementType = getElementTypeFromVectorLike(semanticType); + Type lowpElementType = + getLowpVcvtIntrinsicElementType(elementType, context); + if (!lowpElementType) + return getLowpIntrinsicCarrierType(semanticType, convertedType, context); + + auto lanes = getElementCountFromVectorLike(semanticType); + if (!lanes) + return {}; + return VectorType::get({*lanes}, lowpElementType); +} + static Type getPayloadABIType(Type semanticType, Type convertedType, MLIRContext *context) { if (Type carrierType = getLowpPayloadCarrierType(semanticType, context)) @@ -7710,13 +7742,13 @@ class LowerVcvtOpPattern final : public OpConversionPattern { if (!resultType) return rewriter.notifyMatchFailure(op, "failed to convert vcvt result type"); - Type inputCallType = getLowpIntrinsicCarrierType( + Type inputCallType = getVcvtIntrinsicValueType( op.getInput().getType(), adaptor.getInput().getType(), rewriter.getContext()); if (!inputCallType) return rewriter.notifyMatchFailure(op, "unsupported vcvt input carrier type"); - Type resultCallType = getLowpIntrinsicCarrierType( + Type resultCallType = getVcvtIntrinsicValueType( op.getResult().getType(), resultType, rewriter.getContext()); if (!resultCallType) return rewriter.notifyMatchFailure(op, diff --git a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp index 61e5bd28c7..04d37d466f 100644 --- a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp +++ b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp @@ -533,6 +533,38 @@ static Type getLowpPayloadCarrierType(Type vectorLikeType, return VectorType::get({*lanes}, abi->llvmElementType); } +static Type getLowpVcvtIntrinsicElementType(Type elementType, + MLIRContext *context) { + if (pto::isPTOHiFloat8Type(elementType)) + return LLVM::LLVMHiFloat8Type::get(context); + if (isa(elementType)) + return LLVM::LLVMFloat4E1M2x2Type::get(context); + if (isa(elementType)) + return LLVM::LLVMFloat4E2M1x2Type::get(context); + if (pto::isPTOFloat8E4M3LikeType(elementType)) + return LLVM::LLVMFloat8E4M3Type::get(context); + if (pto::isPTOFloat8E5M2LikeType(elementType)) + return LLVM::LLVMFloat8E5M2Type::get(context); + return {}; +} + +static Type getLowpIntrinsicCarrierType(Type semanticType, Type convertedType, + MLIRContext *context); + +static Type getVcvtIntrinsicValueType(Type semanticType, Type convertedType, + MLIRContext *context) { + Type elementType = getElementTypeFromVectorLike(semanticType); + Type lowpElementType = + getLowpVcvtIntrinsicElementType(elementType, context); + if (!lowpElementType) + return getLowpIntrinsicCarrierType(semanticType, convertedType, context); + + auto lanes = getElementCountFromVectorLike(semanticType); + if (!lanes) + return {}; + return VectorType::get({*lanes}, lowpElementType); +} + static Type getPayloadABIType(Type semanticType, Type convertedType, MLIRContext *context) { if (Type carrierType = getLowpPayloadCarrierType(semanticType, context)) @@ -7652,13 +7684,13 @@ class LowerVcvtOpPattern final : public OpConversionPattern { if (!resultType) return rewriter.notifyMatchFailure(op, "failed to convert vcvt result type"); - Type inputCallType = getLowpIntrinsicCarrierType( + Type inputCallType = getVcvtIntrinsicValueType( op.getInput().getType(), adaptor.getInput().getType(), rewriter.getContext()); if (!inputCallType) return rewriter.notifyMatchFailure(op, "unsupported vcvt input carrier type"); - Type resultCallType = getLowpIntrinsicCarrierType( + Type resultCallType = getVcvtIntrinsicValueType( op.getResult().getType(), resultType, rewriter.getContext()); if (!resultCallType) return rewriter.notifyMatchFailure(op, diff --git a/test/lit/vpto/low_precision_hivm_llvm_ir_abi.pto b/test/lit/vpto/low_precision_hivm_llvm_ir_abi.pto index 6463fe2903..6aeca05841 100644 --- a/test/lit/vpto/low_precision_hivm_llvm_ir_abi.pto +++ b/test/lit/vpto/low_precision_hivm_llvm_ir_abi.pto @@ -14,10 +14,13 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { func.func @lowp_vector_kernel(%f32_seed: f32, + %f16_seed: f16, %bf16_seed: bf16, %ub_f8: !pto.ptr, + %ub_f8e5: !pto.ptr, %ub_hif8: !pto.ptr, %ub_f4: !pto.ptr, + %ub_f4e2: !pto.ptr, %dst_f32: !pto.ptr, %dst_bf16: !pto.ptr) attributes {pto.kernel} { %c0 = arith.constant 0 : index @@ -33,18 +36,28 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind, !pto.mask -> !pto.vreg<64xf32> pto.vsts %f8_back, %dst_f32[%c0], %mask_b32 : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + %f8e5 = pto.vcvt %src_f32, %mask_b32 {rnd = "R", sat = "NOSAT", part = "P0"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8E5M2> + pto.vsts %f8e5, %ub_f8e5[%c0], %mask_b8 : !pto.vreg<256xf8E5M2>, !pto.ptr, !pto.mask + %hif8 = pto.vcvt %src_f32, %mask_b32 {rnd = "A", sat = "NOSAT", part = "P0"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256x!pto.hif8> pto.vsts %hif8, %ub_hif8[%c0], %mask_b8 : !pto.vreg<256x!pto.hif8>, !pto.ptr, !pto.mask %hif8_loaded = pto.vlds %ub_hif8[%c0] : !pto.ptr -> !pto.vreg<256x!pto.hif8> %hif8_back = pto.vcvt %hif8_loaded, %mask_b8 {part = "P0"} : !pto.vreg<256x!pto.hif8>, !pto.mask -> !pto.vreg<64xf32> pto.vsts %hif8_back, %dst_f32[%c0], %mask_b32 : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + %src_f16 = pto.vdup %f16_seed, %mask_b16 : f16, !pto.mask -> !pto.vreg<128xf16> + %hif8_from_f16 = pto.vcvt %src_f16, %mask_b16 {rnd = "A", sat = "NOSAT", part = "EVEN"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<256x!pto.hif8> + pto.vsts %hif8_from_f16, %ub_hif8[%c0], %mask_b8 : !pto.vreg<256x!pto.hif8>, !pto.ptr, !pto.mask + %src_bf16 = pto.vdup %bf16_seed, %mask_b16 : bf16, !pto.mask -> !pto.vreg<128xbf16> %f4 = pto.vcvt %src_bf16, %mask_b16 {rnd = "R", part = "P0"} : !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<256x!pto.f4E1M2x2> pto.vsts %f4, %ub_f4[%c0], %mask_b8 : !pto.vreg<256x!pto.f4E1M2x2>, !pto.ptr, !pto.mask %f4_loaded = pto.vlds %ub_f4[%c0] : !pto.ptr -> !pto.vreg<256x!pto.f4E1M2x2> %f4_back = pto.vcvt %f4_loaded, %mask_b8 {part = "P0"} : !pto.vreg<256x!pto.f4E1M2x2>, !pto.mask -> !pto.vreg<128xbf16> pto.vsts %f4_back, %dst_bf16[%c0], %mask_b16 : !pto.vreg<128xbf16>, !pto.ptr, !pto.mask + + %f4e2 = pto.vcvt %src_bf16, %mask_b16 {rnd = "R", part = "P0"} : !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<256x!pto.f4E2M1x2> + pto.vsts %f4e2, %ub_f4e2[%c0], %mask_b8 : !pto.vreg<256x!pto.f4E2M1x2>, !pto.ptr, !pto.mask } return } @@ -86,22 +99,31 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind @llvm.hivm.vcvtff.f322f8e4m3.x -// LLVMIR: declare void @llvm.hivm.vstsx1.v256f8e4m3(<256 x i8> -// LLVMIR: declare <256 x i8> @llvm.hivm.vldsx1.v256f8e4m3 -// LLVMIR: declare <64 x float> @llvm.hivm.vcvtff.f8e4m32f32.x(<64 x i32> -// LLVMIR: declare <64 x i32> @llvm.hivm.vcvtff.f322hif8.x -// LLVMIR: declare void @llvm.hivm.vstsx1.v256hif8(<256 x i8> -// LLVMIR: declare <64 x i32> @llvm.hivm.vcvtff2.bf162f4e1m2x2.x -// LLVMIR: declare <128 x bfloat> @llvm.hivm.vcvtff2.f4e1m2x22bf16.x(<64 x i32> -// LLVMIR: declare i16 @llvm.hivm.f32x2.to.f8e4m3x2 -// LLVMIR: declare i16 @llvm.hivm.f32x2.to.hif8x2 -// LLVMIR: declare <2 x float> @llvm.hivm.f8e4m3x2.to.f32x2(i16 -// LLVMIR: declare <2 x float> @llvm.hivm.hif8x2.to.f32x2(i16 -// LLVMIR: bitcast <64 x i32> {{%[0-9]+}} to <256 x i8> -// LLVMIR: bitcast <256 x i8> {{%[0-9]+}} to <64 x i32> +// LLVMIR-DAG: declare <256 x float8e4m3> @llvm.hivm.vcvtff.f322f8e4m3.x +// LLVMIR-DAG: declare void @llvm.hivm.vstsx1.v256f8e4m3(<256 x i8> +// LLVMIR-DAG: declare <256 x i8> @llvm.hivm.vldsx1.v256f8e4m3 +// LLVMIR-DAG: declare <64 x float> @llvm.hivm.vcvtff.f8e4m32f32.x(<256 x float8e4m3> +// LLVMIR-DAG: declare <256 x float8e5m2> @llvm.hivm.vcvtff.f322f8e5m2.x +// LLVMIR-DAG: declare <256 x hifloat8> @llvm.hivm.vcvtff.f322hif8.x +// LLVMIR-DAG: declare <256 x hifloat8> @llvm.hivm.vcvtff.f162hif8.x +// LLVMIR-DAG: declare void @llvm.hivm.vstsx1.v256hif8(<256 x i8> +// LLVMIR-DAG: declare <256 x float4e1m2x2> @llvm.hivm.vcvtff2.bf162f4e1m2x2.x +// LLVMIR-DAG: declare <256 x float4e2m1x2> @llvm.hivm.vcvtff2.bf162f4e2m1x2.x +// LLVMIR-DAG: declare <128 x bfloat> @llvm.hivm.vcvtff2.f4e1m2x22bf16.x(<256 x float4e1m2x2> +// LLVMIR-DAG: declare i16 @llvm.hivm.f32x2.to.f8e4m3x2 +// LLVMIR-DAG: declare i16 @llvm.hivm.f32x2.to.hif8x2 +// LLVMIR-DAG: declare <2 x float> @llvm.hivm.f8e4m3x2.to.f32x2(i16 +// LLVMIR-DAG: declare <2 x float> @llvm.hivm.hif8x2.to.f32x2(i16 +// LLVMIR: bitcast <256 x float8e4m3> {{%[0-9]+}} to <256 x i8> +// LLVMIR: bitcast <256 x i8> {{%[0-9]+}} to <256 x float8e4m3> // LLVMIR: bitcast i16 {{%[0-9]+}} to <2 x i8> // LLVMIR-NOT: declare <256 x f8e4m3> // LLVMIR-NOT: declare <2 x f8e4m3> +// LLVMIR-NOT: declare <64 x i32> @llvm.hivm.vcvtff.f322f8e4m3.x +// LLVMIR-NOT: declare <64 x i32> @llvm.hivm.vcvtff.f322f8e5m2.x +// LLVMIR-NOT: declare <64 x i32> @llvm.hivm.vcvtff.f322hif8.x +// LLVMIR-NOT: declare <64 x i32> @llvm.hivm.vcvtff.f162hif8.x +// LLVMIR-NOT: declare <64 x i32> @llvm.hivm.vcvtff2.bf162f4e1m2x2.x +// LLVMIR-NOT: declare <64 x i32> @llvm.hivm.vcvtff2.bf162f4e2m1x2.x // LLVMIR-NOT: declare <256 x i8> @llvm.hivm.vcvtff.f322f8e4m3.x // LLVMIR-NOT: declare <2 x i8> @llvm.hivm.f32x2.to.f8e4m3x2 From 1789679409595fc022b7ad1c704822d90f8c7e14 Mon Sep 17 00:00:00 2001 From: TaoTao-real Date: Sat, 13 Jun 2026 03:12:19 +0800 Subject: [PATCH 23/51] ci: extend vpto sim validation timeout --- .github/workflows/ci_sim.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci_sim.yml b/.github/workflows/ci_sim.yml index 2abb43821a..a857884d3d 100644 --- a/.github/workflows/ci_sim.yml +++ b/.github/workflows/ci_sim.yml @@ -98,7 +98,7 @@ jobs: vpto-sim-validation: needs: detect-vpto-sim-changes runs-on: [self-hosted, Linux, X64, label-1] - timeout-minutes: 120 + timeout-minutes: 180 concurrency: group: vpto-sim-${{ github.event.pull_request.number || github.ref }} cancel-in-progress: true From e750676d91fc22928feb7908cd48481be2ae0a8d Mon Sep 17 00:00:00 2001 From: TaoTao-real Date: Sat, 13 Jun 2026 07:10:53 +0800 Subject: [PATCH 24/51] ci: extend vpto sim timeout for llvm21 --- .github/workflows/ci_sim.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci_sim.yml b/.github/workflows/ci_sim.yml index a857884d3d..4b9e4b7ee8 100644 --- a/.github/workflows/ci_sim.yml +++ b/.github/workflows/ci_sim.yml @@ -98,7 +98,7 @@ jobs: vpto-sim-validation: needs: detect-vpto-sim-changes runs-on: [self-hosted, Linux, X64, label-1] - timeout-minutes: 180 + timeout-minutes: 300 concurrency: group: vpto-sim-${{ github.event.pull_request.number || github.ref }} cancel-in-progress: true From d4d2ee2a872307f05e722c8d0a307d38352a756e Mon Sep 17 00:00:00 2001 From: TaoTao-real Date: Tue, 16 Jun 2026 14:39:29 +0800 Subject: [PATCH 25/51] fix: restore PTO entry compatibility for LLVM 21 --- lib/PTO/Transforms/ExpandTileOp.cpp | 5 +++-- lib/PTO/Transforms/InferPTOLayout.cpp | 5 +++-- lib/PTO/Transforms/PTOMaterializeTileHandles.cpp | 3 ++- lib/PTO/Transforms/PTOToEmitC.cpp | 15 +++++++++------ 4 files changed, 17 insertions(+), 11 deletions(-) diff --git a/lib/PTO/Transforms/ExpandTileOp.cpp b/lib/PTO/Transforms/ExpandTileOp.cpp index 2dac9d936e..fabca3d29c 100644 --- a/lib/PTO/Transforms/ExpandTileOp.cpp +++ b/lib/PTO/Transforms/ExpandTileOp.cpp @@ -522,12 +522,13 @@ static bool getStaticIntFromValue(Value value, int64_t &out) { } static int64_t getStaticIntOrDynamic(OpFoldResult ofr) { - if (auto attr = ofr.dyn_cast()) { + if (isa(ofr)) { + Attribute attr = cast(ofr); if (auto intAttr = dyn_cast(attr)) return intAttr.getInt(); return ShapedType::kDynamic; } - auto value = llvm::cast(ofr); + Value value = cast(ofr); int64_t result = ShapedType::kDynamic; if (getStaticIntFromValue(value, result)) return result; diff --git a/lib/PTO/Transforms/InferPTOLayout.cpp b/lib/PTO/Transforms/InferPTOLayout.cpp index a485272a17..b25e37d7a3 100644 --- a/lib/PTO/Transforms/InferPTOLayout.cpp +++ b/lib/PTO/Transforms/InferPTOLayout.cpp @@ -64,12 +64,13 @@ static std::optional getConstInt(Value v) { } static std::optional getConstInt(OpFoldResult ofr) { - if (auto attr = ofr.dyn_cast()) { + if (isa(ofr)) { + Attribute attr = cast(ofr); if (auto ia = dyn_cast(attr)) return ia.getInt(); return std::nullopt; } - return getConstInt(ofr.dyn_cast()); + return getConstInt(cast(ofr)); } static unsigned elemByteSize(Type ty) { diff --git a/lib/PTO/Transforms/PTOMaterializeTileHandles.cpp b/lib/PTO/Transforms/PTOMaterializeTileHandles.cpp index 6c2a0e2d36..d57b3a7740 100644 --- a/lib/PTO/Transforms/PTOMaterializeTileHandles.cpp +++ b/lib/PTO/Transforms/PTOMaterializeTileHandles.cpp @@ -435,7 +435,8 @@ static Value ensureI64(Value value, OpBuilder &builder, Location loc) { static Value materializeOffset(OpFoldResult ofr, OpBuilder &builder, Location loc) { - if (auto attr = ofr.dyn_cast()) { + if (isa(ofr)) { + Attribute attr = cast(ofr); if (auto intAttr = dyn_cast(attr)) return makeI64Constant(builder, loc, getIntegerAttrSignedValue(intAttr)); return Value(); diff --git a/lib/PTO/Transforms/PTOToEmitC.cpp b/lib/PTO/Transforms/PTOToEmitC.cpp index f73de38a62..1213fabacf 100644 --- a/lib/PTO/Transforms/PTOToEmitC.cpp +++ b/lib/PTO/Transforms/PTOToEmitC.cpp @@ -3494,11 +3494,12 @@ struct SubviewToEmitCPattern : public OpConversionPattern { // 辅助函数:尝试从 OpFoldResult 中提取静态整数值 std::optional extractStaticInt(OpFoldResult ofr) const { - if (auto attr = ofr.dyn_cast()) { + if (isa(ofr)) { + Attribute attr = cast(ofr); if (auto intAttr = dyn_cast(attr)) return getIntegerAttrSignedValue(intAttr); } else { - Value v = ofr.dyn_cast(); + Value v = cast(ofr); if (auto cOp = v.getDefiningOp()) { if (auto iAttr = dyn_cast(cOp.getValue())) return getIntegerAttrSignedValue(iAttr); @@ -3569,13 +3570,15 @@ struct SubviewToEmitCPattern : public OpConversionPattern { // Helper: 将 OpFoldResult 转为 EmitC Value (用于计算) auto ofrToEmitCValue = [&](OpFoldResult ofr) -> Value { - if (auto v = ofr.dyn_cast()) { + if (isa(ofr)) { + Value v = cast(ofr); Value rv = rewriter.getRemappedValue(v); return asIndex(rv); } - if (auto attr = ofr.dyn_cast()) { - if (auto ia = dyn_cast(attr)) - return mkIndex(getIntegerAttrSignedValue(ia)); + if (isa(ofr)) { + Attribute attr = cast(ofr); + if (auto ia = dyn_cast(attr)) + return mkIndex(getIntegerAttrSignedValue(ia)); } return mkIndex(0); }; From defcf822caf0d39534dc0fd0ddc7e6328fc8519e Mon Sep 17 00:00:00 2001 From: TaoTao-real Date: Tue, 16 Jun 2026 17:40:34 +0800 Subject: [PATCH 26/51] test: restore Qwen3DecodeA5 left tile layouts --- .../Qwen3DecodeA5/down_proj_residual.pto | 36 +++++----- .../Qwen3DecodeA5/out_proj_residual.pto | 36 +++++----- .../Qwen3DecodeA5/qwen3_decode_incore_1.pto | 36 +++++----- .../Qwen3DecodeA5/qwen3_decode_incore_10.pto | 48 ++++++------- .../Qwen3DecodeA5/qwen3_decode_incore_11.pto | 48 ++++++------- .../Qwen3DecodeA5/qwen3_decode_incore_2.pto | 72 +++++++++---------- .../Qwen3DecodeA5/qwen3_decode_incore_4.pto | 24 +++---- .../Qwen3DecodeA5/qwen3_decode_incore_6.pto | 24 +++---- 8 files changed, 162 insertions(+), 162 deletions(-) diff --git a/test/samples/Qwen3DecodeA5/down_proj_residual.pto b/test/samples/Qwen3DecodeA5/down_proj_residual.pto index 9e55077eba..b81f7269ba 100644 --- a/test/samples/Qwen3DecodeA5/down_proj_residual.pto +++ b/test/samples/Qwen3DecodeA5/down_proj_residual.pto @@ -47,44 +47,44 @@ module attributes {pto.target_arch = "a5"} { %23 = arith.cmpi eq, %18, %c0_index : index scf.if %23 { %down_acc__tile_l0_init = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - %down_acc__tile_l0_a = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%down_mlp_chunk__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%down_acc__tile_l0_a : !pto.tile_buf) + %down_acc__tile_l0_a = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%down_mlp_chunk__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%down_acc__tile_l0_a : !pto.tile_buf) %down_acc__tile_l0_b = pto.alloc_tile addr = %c0_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%w_down_chunk__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%down_acc__tile_l0_b : !pto.tile_buf) - %2 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%down_mlp_chunk__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%2 : !pto.tile_buf) + %2 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%down_mlp_chunk__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%2 : !pto.tile_buf) %3 = pto.alloc_tile addr = %c32768_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%w_down_chunk__tile, %c64_index, %c0_index : !pto.tile_buf, index, index) outs(%3 : !pto.tile_buf) %down_acc__tile_l0_c_first = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul ins(%down_acc__tile_l0_a, %down_acc__tile_l0_b : !pto.tile_buf, !pto.tile_buf) outs(%down_acc__tile_l0_c_first : !pto.tile_buf) + pto.tmatmul ins(%down_acc__tile_l0_a, %down_acc__tile_l0_b : !pto.tile_buf, !pto.tile_buf) outs(%down_acc__tile_l0_c_first : !pto.tile_buf) %down_acc__tile_l0_c_acc = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%down_acc__tile_l0_c_acc, %2, %3 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%down_acc__tile_l0_c_acc : !pto.tile_buf) + pto.tmatmul.acc ins(%down_acc__tile_l0_c_acc, %2, %3 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%down_acc__tile_l0_c_acc : !pto.tile_buf) } else { - %4 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%down_mlp_chunk__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%4 : !pto.tile_buf) + %4 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%down_mlp_chunk__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%4 : !pto.tile_buf) %5 = pto.alloc_tile addr = %c0_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%w_down_chunk__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%5 : !pto.tile_buf) - %6 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%down_mlp_chunk__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%6 : !pto.tile_buf) + %6 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%down_mlp_chunk__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%6 : !pto.tile_buf) %7 = pto.alloc_tile addr = %c32768_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%w_down_chunk__tile, %c64_index, %c0_index : !pto.tile_buf, index, index) outs(%7 : !pto.tile_buf) %8 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%8, %4, %5 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%8 : !pto.tile_buf) + pto.tmatmul.acc ins(%8, %4, %5 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%8 : !pto.tile_buf) %9 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%9, %6, %7 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%9 : !pto.tile_buf) + pto.tmatmul.acc ins(%9, %6, %7 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%9 : !pto.tile_buf) } - %10 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%0, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%10 : !pto.tile_buf) + %10 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%0, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%10 : !pto.tile_buf) %11 = pto.alloc_tile addr = %c0_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%1, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%11 : !pto.tile_buf) - %12 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%0, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%12 : !pto.tile_buf) + %12 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%0, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%12 : !pto.tile_buf) %13 = pto.alloc_tile addr = %c32768_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%1, %c64_index, %c0_index : !pto.tile_buf, index, index) outs(%13 : !pto.tile_buf) %14 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%14, %10, %11 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%14 : !pto.tile_buf) + pto.tmatmul.acc ins(%14, %10, %11 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%14 : !pto.tile_buf) %15 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%15, %12, %13 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%15 : !pto.tile_buf) + pto.tmatmul.acc ins(%15, %12, %13 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%15 : !pto.tile_buf) } pto.tpush_to_aiv(%down_acc__tile : !pto.tile_buf) {split = 1} } diff --git a/test/samples/Qwen3DecodeA5/out_proj_residual.pto b/test/samples/Qwen3DecodeA5/out_proj_residual.pto index 87c90e3a7a..991a3b48c9 100644 --- a/test/samples/Qwen3DecodeA5/out_proj_residual.pto +++ b/test/samples/Qwen3DecodeA5/out_proj_residual.pto @@ -45,44 +45,44 @@ module attributes {pto.target_arch = "a5"} { %23 = arith.cmpi eq, %18, %c0_index : index scf.if %23 { %o_acc__tile_l0_init = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - %o_acc__tile_l0_a = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%a_chunk__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%o_acc__tile_l0_a : !pto.tile_buf) + %o_acc__tile_l0_a = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%a_chunk__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%o_acc__tile_l0_a : !pto.tile_buf) %o_acc__tile_l0_b = pto.alloc_tile addr = %c0_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%w_chunk__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%o_acc__tile_l0_b : !pto.tile_buf) - %2 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%a_chunk__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%2 : !pto.tile_buf) + %2 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%a_chunk__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%2 : !pto.tile_buf) %3 = pto.alloc_tile addr = %c32768_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%w_chunk__tile, %c64_index, %c0_index : !pto.tile_buf, index, index) outs(%3 : !pto.tile_buf) %o_acc__tile_l0_c_first = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul ins(%o_acc__tile_l0_a, %o_acc__tile_l0_b : !pto.tile_buf, !pto.tile_buf) outs(%o_acc__tile_l0_c_first : !pto.tile_buf) + pto.tmatmul ins(%o_acc__tile_l0_a, %o_acc__tile_l0_b : !pto.tile_buf, !pto.tile_buf) outs(%o_acc__tile_l0_c_first : !pto.tile_buf) %o_acc__tile_l0_c_acc = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%o_acc__tile_l0_c_acc, %2, %3 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%o_acc__tile_l0_c_acc : !pto.tile_buf) + pto.tmatmul.acc ins(%o_acc__tile_l0_c_acc, %2, %3 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%o_acc__tile_l0_c_acc : !pto.tile_buf) } else { - %4 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%a_chunk__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%4 : !pto.tile_buf) + %4 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%a_chunk__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%4 : !pto.tile_buf) %5 = pto.alloc_tile addr = %c0_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%w_chunk__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%5 : !pto.tile_buf) - %6 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%a_chunk__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%6 : !pto.tile_buf) + %6 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%a_chunk__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%6 : !pto.tile_buf) %7 = pto.alloc_tile addr = %c32768_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%w_chunk__tile, %c64_index, %c0_index : !pto.tile_buf, index, index) outs(%7 : !pto.tile_buf) %8 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%8, %4, %5 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%8 : !pto.tile_buf) + pto.tmatmul.acc ins(%8, %4, %5 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%8 : !pto.tile_buf) %9 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%9, %6, %7 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%9 : !pto.tile_buf) + pto.tmatmul.acc ins(%9, %6, %7 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%9 : !pto.tile_buf) } - %10 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%0, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%10 : !pto.tile_buf) + %10 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%0, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%10 : !pto.tile_buf) %11 = pto.alloc_tile addr = %c0_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%1, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%11 : !pto.tile_buf) - %12 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%0, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%12 : !pto.tile_buf) + %12 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%0, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%12 : !pto.tile_buf) %13 = pto.alloc_tile addr = %c32768_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%1, %c64_index, %c0_index : !pto.tile_buf, index, index) outs(%13 : !pto.tile_buf) %14 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%14, %10, %11 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%14 : !pto.tile_buf) + pto.tmatmul.acc ins(%14, %10, %11 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%14 : !pto.tile_buf) %15 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%15, %12, %13 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%15 : !pto.tile_buf) + pto.tmatmul.acc ins(%15, %12, %13 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%15 : !pto.tile_buf) } pto.tpush_to_aiv(%o_acc__tile : !pto.tile_buf) {split = 1} } diff --git a/test/samples/Qwen3DecodeA5/qwen3_decode_incore_1.pto b/test/samples/Qwen3DecodeA5/qwen3_decode_incore_1.pto index 231b7202ff..770c9f7404 100644 --- a/test/samples/Qwen3DecodeA5/qwen3_decode_incore_1.pto +++ b/test/samples/Qwen3DecodeA5/qwen3_decode_incore_1.pto @@ -40,44 +40,44 @@ module attributes {pto.target_arch = "a5"} { %22 = arith.cmpi eq, %17, %c0_index : index scf.if %22 { %q_acc__tile_l0_init = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - %q_acc__tile_l0_a = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%tile_a_i__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%q_acc__tile_l0_a : !pto.tile_buf) + %q_acc__tile_l0_a = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%tile_a_i__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%q_acc__tile_l0_a : !pto.tile_buf) %q_acc__tile_l0_b = pto.alloc_tile addr = %c0_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%tile_b_i__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%q_acc__tile_l0_b : !pto.tile_buf) - %2 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%tile_a_i__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%2 : !pto.tile_buf) + %2 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%tile_a_i__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%2 : !pto.tile_buf) %3 = pto.alloc_tile addr = %c32768_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%tile_b_i__tile, %c64_index, %c0_index : !pto.tile_buf, index, index) outs(%3 : !pto.tile_buf) %q_acc__tile_l0_c_first = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul ins(%q_acc__tile_l0_a, %q_acc__tile_l0_b : !pto.tile_buf, !pto.tile_buf) outs(%q_acc__tile_l0_c_first : !pto.tile_buf) + pto.tmatmul ins(%q_acc__tile_l0_a, %q_acc__tile_l0_b : !pto.tile_buf, !pto.tile_buf) outs(%q_acc__tile_l0_c_first : !pto.tile_buf) %q_acc__tile_l0_c_acc = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%q_acc__tile_l0_c_acc, %2, %3 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%q_acc__tile_l0_c_acc : !pto.tile_buf) + pto.tmatmul.acc ins(%q_acc__tile_l0_c_acc, %2, %3 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%q_acc__tile_l0_c_acc : !pto.tile_buf) } else { - %4 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%tile_a_i__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%4 : !pto.tile_buf) + %4 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%tile_a_i__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%4 : !pto.tile_buf) %5 = pto.alloc_tile addr = %c0_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%tile_b_i__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%5 : !pto.tile_buf) - %6 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%tile_a_i__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%6 : !pto.tile_buf) + %6 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%tile_a_i__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%6 : !pto.tile_buf) %7 = pto.alloc_tile addr = %c32768_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%tile_b_i__tile, %c64_index, %c0_index : !pto.tile_buf, index, index) outs(%7 : !pto.tile_buf) %8 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%8, %4, %5 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%8 : !pto.tile_buf) + pto.tmatmul.acc ins(%8, %4, %5 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%8 : !pto.tile_buf) %9 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%9, %6, %7 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%9 : !pto.tile_buf) + pto.tmatmul.acc ins(%9, %6, %7 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%9 : !pto.tile_buf) } - %10 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%0, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%10 : !pto.tile_buf) + %10 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%0, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%10 : !pto.tile_buf) %11 = pto.alloc_tile addr = %c0_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%1, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%11 : !pto.tile_buf) - %12 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%0, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%12 : !pto.tile_buf) + %12 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%0, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%12 : !pto.tile_buf) %13 = pto.alloc_tile addr = %c32768_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%1, %c64_index, %c0_index : !pto.tile_buf, index, index) outs(%13 : !pto.tile_buf) %14 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%14, %10, %11 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%14 : !pto.tile_buf) + pto.tmatmul.acc ins(%14, %10, %11 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%14 : !pto.tile_buf) %15 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%15, %12, %13 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%15 : !pto.tile_buf) + pto.tmatmul.acc ins(%15, %12, %13 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%15 : !pto.tile_buf) } %q_proj__ssa_v0_pview = pto.partition_view %q_proj__ssa_v0_view, offsets = [%c0_index, %16], sizes = [%c16_index, %c256_index] : !pto.tensor_view -> !pto.partition_tensor_view<16x256xf32> pto.tstore ins(%q_acc__tile : !pto.tile_buf) outs(%q_proj__ssa_v0_pview : !pto.partition_tensor_view<16x256xf32>) diff --git a/test/samples/Qwen3DecodeA5/qwen3_decode_incore_10.pto b/test/samples/Qwen3DecodeA5/qwen3_decode_incore_10.pto index 1b419b00a6..3b589db9e2 100644 --- a/test/samples/Qwen3DecodeA5/qwen3_decode_incore_10.pto +++ b/test/samples/Qwen3DecodeA5/qwen3_decode_incore_10.pto @@ -33,33 +33,33 @@ module attributes {pto.target_arch = "a5"} { %w_gate__ssa_v0_pview = pto.partition_view %w_gate__ssa_v0_view, offsets = [%c0_index, %23], sizes = [%c128_index, %c256_index] : !pto.tensor_view -> !pto.partition_tensor_view<128x256xbf16> pto.tload ins(%w_gate__ssa_v0_pview : !pto.partition_tensor_view<128x256xbf16>) outs(%wg_0__tile : !pto.tile_buf) %gate_acc__tile_l0_init = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - %gate_acc__tile_l0_a = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%post_chunk_0__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%gate_acc__tile_l0_a : !pto.tile_buf) + %gate_acc__tile_l0_a = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%post_chunk_0__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%gate_acc__tile_l0_a : !pto.tile_buf) %gate_acc__tile_l0_b = pto.alloc_tile addr = %c32768_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%wg_0__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%gate_acc__tile_l0_b : !pto.tile_buf) - %0 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%post_chunk_0__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%0 : !pto.tile_buf) + %0 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%post_chunk_0__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%0 : !pto.tile_buf) %1 = pto.alloc_tile addr = %c0_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%wg_0__tile, %c64_index, %c0_index : !pto.tile_buf, index, index) outs(%1 : !pto.tile_buf) %gate_acc__tile_l0_c_first = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul ins(%gate_acc__tile_l0_a, %gate_acc__tile_l0_b : !pto.tile_buf, !pto.tile_buf) outs(%gate_acc__tile_l0_c_first : !pto.tile_buf) + pto.tmatmul ins(%gate_acc__tile_l0_a, %gate_acc__tile_l0_b : !pto.tile_buf, !pto.tile_buf) outs(%gate_acc__tile_l0_c_first : !pto.tile_buf) %gate_acc__tile_l0_c_acc = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%gate_acc__tile_l0_c_acc, %0, %1 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%gate_acc__tile_l0_c_acc : !pto.tile_buf) + pto.tmatmul.acc ins(%gate_acc__tile_l0_c_acc, %0, %1 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%gate_acc__tile_l0_c_acc : !pto.tile_buf) %wg_1__tile = pto.alloc_tile addr = %c73728_i64 valid_row = %c128_index valid_col = %c256_index : !pto.tile_buf %26 = pto.partition_view %w_gate__ssa_v0_view, offsets = [%c128_index, %23], sizes = [%c128_index, %c256_index] : !pto.tensor_view -> !pto.partition_tensor_view<128x256xbf16> pto.tload ins(%26 : !pto.partition_tensor_view<128x256xbf16>) outs(%wg_1__tile : !pto.tile_buf) - %2 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%post_chunk_1__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%2 : !pto.tile_buf) + %2 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%post_chunk_1__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%2 : !pto.tile_buf) %3 = pto.alloc_tile addr = %c32768_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%wg_1__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%3 : !pto.tile_buf) - %4 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%post_chunk_1__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%4 : !pto.tile_buf) + %4 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%post_chunk_1__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%4 : !pto.tile_buf) %5 = pto.alloc_tile addr = %c0_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%wg_1__tile, %c64_index, %c0_index : !pto.tile_buf, index, index) outs(%5 : !pto.tile_buf) %6 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%6, %2, %3 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%6 : !pto.tile_buf) + pto.tmatmul.acc ins(%6, %2, %3 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%6 : !pto.tile_buf) %7 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%7, %4, %5 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%7 : !pto.tile_buf) + pto.tmatmul.acc ins(%7, %4, %5 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%7 : !pto.tile_buf) scf.for %kb__idx_v0 = %c2_index to %c64_index step %c2_index { %27 = arith.muli %kb__idx_v0, %c128_index : index %28 = arith.muli %kb__idx_v0, %c128_index : index @@ -76,30 +76,30 @@ module attributes {pto.target_arch = "a5"} { %9 = pto.alloc_tile addr = %c0_i64 valid_row = %c128_index valid_col = %c256_index : !pto.tile_buf %33 = pto.partition_view %w_gate__ssa_v0_view, offsets = [%29, %23], sizes = [%c128_index, %c256_index] : !pto.tensor_view -> !pto.partition_tensor_view<128x256xbf16> pto.tload ins(%33 : !pto.partition_tensor_view<128x256xbf16>) outs(%9 : !pto.tile_buf) - %10 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%post_chunk__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%10 : !pto.tile_buf) + %10 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%post_chunk__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%10 : !pto.tile_buf) %11 = pto.alloc_tile addr = %c32768_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%wg__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%11 : !pto.tile_buf) - %12 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%post_chunk__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%12 : !pto.tile_buf) + %12 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%post_chunk__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%12 : !pto.tile_buf) %13 = pto.alloc_tile addr = %c0_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%wg__tile, %c64_index, %c0_index : !pto.tile_buf, index, index) outs(%13 : !pto.tile_buf) %14 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%14, %10, %11 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%14 : !pto.tile_buf) + pto.tmatmul.acc ins(%14, %10, %11 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%14 : !pto.tile_buf) %15 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%15, %12, %13 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%15 : !pto.tile_buf) - %16 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%8, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%16 : !pto.tile_buf) + pto.tmatmul.acc ins(%15, %12, %13 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%15 : !pto.tile_buf) + %16 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%8, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%16 : !pto.tile_buf) %17 = pto.alloc_tile addr = %c32768_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%9, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%17 : !pto.tile_buf) - %18 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%8, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%18 : !pto.tile_buf) + %18 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%8, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%18 : !pto.tile_buf) %19 = pto.alloc_tile addr = %c0_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%9, %c64_index, %c0_index : !pto.tile_buf, index, index) outs(%19 : !pto.tile_buf) %20 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%20, %16, %17 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%20 : !pto.tile_buf) + pto.tmatmul.acc ins(%20, %16, %17 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%20 : !pto.tile_buf) %21 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%21, %18, %19 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%21 : !pto.tile_buf) + pto.tmatmul.acc ins(%21, %18, %19 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%21 : !pto.tile_buf) } %gate_group__ssa_v0_pview = pto.partition_view %gate_group__ssa_v0_view, offsets = [%c0_index, %24], sizes = [%c16_index, %c256_index] : !pto.tensor_view -> !pto.partition_tensor_view<16x256xf32> pto.tstore ins(%7 : !pto.tile_buf) outs(%gate_group__ssa_v0_pview : !pto.partition_tensor_view<16x256xf32>) diff --git a/test/samples/Qwen3DecodeA5/qwen3_decode_incore_11.pto b/test/samples/Qwen3DecodeA5/qwen3_decode_incore_11.pto index 148552f4ea..6048201d6e 100644 --- a/test/samples/Qwen3DecodeA5/qwen3_decode_incore_11.pto +++ b/test/samples/Qwen3DecodeA5/qwen3_decode_incore_11.pto @@ -33,33 +33,33 @@ module attributes {pto.target_arch = "a5"} { %w_up__ssa_v0_pview = pto.partition_view %w_up__ssa_v0_view, offsets = [%c0_index, %23], sizes = [%c128_index, %c256_index] : !pto.tensor_view -> !pto.partition_tensor_view<128x256xbf16> pto.tload ins(%w_up__ssa_v0_pview : !pto.partition_tensor_view<128x256xbf16>) outs(%wu_0__tile : !pto.tile_buf) %up_acc__tile_l0_init = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - %up_acc__tile_l0_a = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%post_chunk_0__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%up_acc__tile_l0_a : !pto.tile_buf) + %up_acc__tile_l0_a = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%post_chunk_0__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%up_acc__tile_l0_a : !pto.tile_buf) %up_acc__tile_l0_b = pto.alloc_tile addr = %c32768_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%wu_0__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%up_acc__tile_l0_b : !pto.tile_buf) - %0 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%post_chunk_0__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%0 : !pto.tile_buf) + %0 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%post_chunk_0__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%0 : !pto.tile_buf) %1 = pto.alloc_tile addr = %c0_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%wu_0__tile, %c64_index, %c0_index : !pto.tile_buf, index, index) outs(%1 : !pto.tile_buf) %up_acc__tile_l0_c_first = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul ins(%up_acc__tile_l0_a, %up_acc__tile_l0_b : !pto.tile_buf, !pto.tile_buf) outs(%up_acc__tile_l0_c_first : !pto.tile_buf) + pto.tmatmul ins(%up_acc__tile_l0_a, %up_acc__tile_l0_b : !pto.tile_buf, !pto.tile_buf) outs(%up_acc__tile_l0_c_first : !pto.tile_buf) %up_acc__tile_l0_c_acc = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%up_acc__tile_l0_c_acc, %0, %1 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%up_acc__tile_l0_c_acc : !pto.tile_buf) + pto.tmatmul.acc ins(%up_acc__tile_l0_c_acc, %0, %1 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%up_acc__tile_l0_c_acc : !pto.tile_buf) %wu_1__tile = pto.alloc_tile addr = %c73728_i64 valid_row = %c128_index valid_col = %c256_index : !pto.tile_buf %26 = pto.partition_view %w_up__ssa_v0_view, offsets = [%c128_index, %23], sizes = [%c128_index, %c256_index] : !pto.tensor_view -> !pto.partition_tensor_view<128x256xbf16> pto.tload ins(%26 : !pto.partition_tensor_view<128x256xbf16>) outs(%wu_1__tile : !pto.tile_buf) - %2 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%post_chunk_1__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%2 : !pto.tile_buf) + %2 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%post_chunk_1__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%2 : !pto.tile_buf) %3 = pto.alloc_tile addr = %c32768_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%wu_1__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%3 : !pto.tile_buf) - %4 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%post_chunk_1__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%4 : !pto.tile_buf) + %4 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%post_chunk_1__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%4 : !pto.tile_buf) %5 = pto.alloc_tile addr = %c0_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%wu_1__tile, %c64_index, %c0_index : !pto.tile_buf, index, index) outs(%5 : !pto.tile_buf) %6 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%6, %2, %3 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%6 : !pto.tile_buf) + pto.tmatmul.acc ins(%6, %2, %3 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%6 : !pto.tile_buf) %7 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%7, %4, %5 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%7 : !pto.tile_buf) + pto.tmatmul.acc ins(%7, %4, %5 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%7 : !pto.tile_buf) scf.for %kb__idx_v0 = %c2_index to %c64_index step %c2_index { %27 = arith.muli %kb__idx_v0, %c128_index : index %28 = arith.muli %kb__idx_v0, %c128_index : index @@ -76,30 +76,30 @@ module attributes {pto.target_arch = "a5"} { %9 = pto.alloc_tile addr = %c0_i64 valid_row = %c128_index valid_col = %c256_index : !pto.tile_buf %33 = pto.partition_view %w_up__ssa_v0_view, offsets = [%29, %23], sizes = [%c128_index, %c256_index] : !pto.tensor_view -> !pto.partition_tensor_view<128x256xbf16> pto.tload ins(%33 : !pto.partition_tensor_view<128x256xbf16>) outs(%9 : !pto.tile_buf) - %10 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%post_chunk__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%10 : !pto.tile_buf) + %10 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%post_chunk__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%10 : !pto.tile_buf) %11 = pto.alloc_tile addr = %c32768_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%wu__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%11 : !pto.tile_buf) - %12 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%post_chunk__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%12 : !pto.tile_buf) + %12 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%post_chunk__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%12 : !pto.tile_buf) %13 = pto.alloc_tile addr = %c0_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%wu__tile, %c64_index, %c0_index : !pto.tile_buf, index, index) outs(%13 : !pto.tile_buf) %14 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%14, %10, %11 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%14 : !pto.tile_buf) + pto.tmatmul.acc ins(%14, %10, %11 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%14 : !pto.tile_buf) %15 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%15, %12, %13 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%15 : !pto.tile_buf) - %16 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%8, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%16 : !pto.tile_buf) + pto.tmatmul.acc ins(%15, %12, %13 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%15 : !pto.tile_buf) + %16 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%8, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%16 : !pto.tile_buf) %17 = pto.alloc_tile addr = %c32768_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%9, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%17 : !pto.tile_buf) - %18 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%8, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%18 : !pto.tile_buf) + %18 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%8, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%18 : !pto.tile_buf) %19 = pto.alloc_tile addr = %c0_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%9, %c64_index, %c0_index : !pto.tile_buf, index, index) outs(%19 : !pto.tile_buf) %20 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%20, %16, %17 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%20 : !pto.tile_buf) + pto.tmatmul.acc ins(%20, %16, %17 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%20 : !pto.tile_buf) %21 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%21, %18, %19 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%21 : !pto.tile_buf) + pto.tmatmul.acc ins(%21, %18, %19 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%21 : !pto.tile_buf) } %up_group__ssa_v0_pview = pto.partition_view %up_group__ssa_v0_view, offsets = [%c0_index, %24], sizes = [%c16_index, %c256_index] : !pto.tensor_view -> !pto.partition_tensor_view<16x256xf32> pto.tstore ins(%7 : !pto.tile_buf) outs(%up_group__ssa_v0_pview : !pto.partition_tensor_view<16x256xf32>) diff --git a/test/samples/Qwen3DecodeA5/qwen3_decode_incore_2.pto b/test/samples/Qwen3DecodeA5/qwen3_decode_incore_2.pto index 642f1e5776..6b12c37e9c 100644 --- a/test/samples/Qwen3DecodeA5/qwen3_decode_incore_2.pto +++ b/test/samples/Qwen3DecodeA5/qwen3_decode_incore_2.pto @@ -54,81 +54,81 @@ module attributes {pto.target_arch = "a5"} { %38 = arith.cmpi eq, %32, %c0_index : index scf.if %38 { %k_acc__tile_l0_init = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - %k_acc__tile_l0_a = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%tile_a_i__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%k_acc__tile_l0_a : !pto.tile_buf) + %k_acc__tile_l0_a = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%tile_a_i__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%k_acc__tile_l0_a : !pto.tile_buf) %k_acc__tile_l0_b = pto.alloc_tile addr = %c0_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%tile_wk_i__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%k_acc__tile_l0_b : !pto.tile_buf) - %3 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%tile_a_i__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%3 : !pto.tile_buf) + %3 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%tile_a_i__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%3 : !pto.tile_buf) %4 = pto.alloc_tile addr = %c32768_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%tile_wk_i__tile, %c64_index, %c0_index : !pto.tile_buf, index, index) outs(%4 : !pto.tile_buf) %k_acc__tile_l0_c_first = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul ins(%k_acc__tile_l0_a, %k_acc__tile_l0_b : !pto.tile_buf, !pto.tile_buf) outs(%k_acc__tile_l0_c_first : !pto.tile_buf) + pto.tmatmul ins(%k_acc__tile_l0_a, %k_acc__tile_l0_b : !pto.tile_buf, !pto.tile_buf) outs(%k_acc__tile_l0_c_first : !pto.tile_buf) %k_acc__tile_l0_c_acc = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%k_acc__tile_l0_c_acc, %3, %4 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%k_acc__tile_l0_c_acc : !pto.tile_buf) + pto.tmatmul.acc ins(%k_acc__tile_l0_c_acc, %3, %4 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%k_acc__tile_l0_c_acc : !pto.tile_buf) %v_acc__tile_l0_init = pto.alloc_tile addr = %c16384_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - %v_acc__tile_l0_a = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%tile_a_i__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%v_acc__tile_l0_a : !pto.tile_buf) + %v_acc__tile_l0_a = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%tile_a_i__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%v_acc__tile_l0_a : !pto.tile_buf) %v_acc__tile_l0_b = pto.alloc_tile addr = %c0_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%tile_wv_i__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%v_acc__tile_l0_b : !pto.tile_buf) - %5 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%tile_a_i__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%5 : !pto.tile_buf) + %5 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%tile_a_i__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%5 : !pto.tile_buf) %6 = pto.alloc_tile addr = %c32768_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%tile_wv_i__tile, %c64_index, %c0_index : !pto.tile_buf, index, index) outs(%6 : !pto.tile_buf) %v_acc__tile_l0_c_first = pto.alloc_tile addr = %c16384_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul ins(%v_acc__tile_l0_a, %v_acc__tile_l0_b : !pto.tile_buf, !pto.tile_buf) outs(%v_acc__tile_l0_c_first : !pto.tile_buf) + pto.tmatmul ins(%v_acc__tile_l0_a, %v_acc__tile_l0_b : !pto.tile_buf, !pto.tile_buf) outs(%v_acc__tile_l0_c_first : !pto.tile_buf) %v_acc__tile_l0_c_acc = pto.alloc_tile addr = %c16384_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%v_acc__tile_l0_c_acc, %5, %6 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%v_acc__tile_l0_c_acc : !pto.tile_buf) + pto.tmatmul.acc ins(%v_acc__tile_l0_c_acc, %5, %6 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%v_acc__tile_l0_c_acc : !pto.tile_buf) } else { - %7 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%tile_a_i__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%7 : !pto.tile_buf) + %7 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%tile_a_i__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%7 : !pto.tile_buf) %8 = pto.alloc_tile addr = %c0_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%tile_wk_i__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%8 : !pto.tile_buf) - %9 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%tile_a_i__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%9 : !pto.tile_buf) + %9 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%tile_a_i__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%9 : !pto.tile_buf) %10 = pto.alloc_tile addr = %c32768_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%tile_wk_i__tile, %c64_index, %c0_index : !pto.tile_buf, index, index) outs(%10 : !pto.tile_buf) %11 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%11, %7, %8 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%11 : !pto.tile_buf) + pto.tmatmul.acc ins(%11, %7, %8 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%11 : !pto.tile_buf) %12 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%12, %9, %10 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%12 : !pto.tile_buf) - %13 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%tile_a_i__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%13 : !pto.tile_buf) + pto.tmatmul.acc ins(%12, %9, %10 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%12 : !pto.tile_buf) + %13 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%tile_a_i__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%13 : !pto.tile_buf) %14 = pto.alloc_tile addr = %c0_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%tile_wv_i__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%14 : !pto.tile_buf) - %15 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%tile_a_i__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%15 : !pto.tile_buf) + %15 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%tile_a_i__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%15 : !pto.tile_buf) %16 = pto.alloc_tile addr = %c32768_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%tile_wv_i__tile, %c64_index, %c0_index : !pto.tile_buf, index, index) outs(%16 : !pto.tile_buf) %17 = pto.alloc_tile addr = %c16384_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%17, %13, %14 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%17 : !pto.tile_buf) + pto.tmatmul.acc ins(%17, %13, %14 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%17 : !pto.tile_buf) %18 = pto.alloc_tile addr = %c16384_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%18, %15, %16 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%18 : !pto.tile_buf) + pto.tmatmul.acc ins(%18, %15, %16 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%18 : !pto.tile_buf) } - %19 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%0, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%19 : !pto.tile_buf) + %19 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%0, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%19 : !pto.tile_buf) %20 = pto.alloc_tile addr = %c0_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%1, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%20 : !pto.tile_buf) - %21 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%0, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%21 : !pto.tile_buf) + %21 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%0, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%21 : !pto.tile_buf) %22 = pto.alloc_tile addr = %c32768_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%1, %c64_index, %c0_index : !pto.tile_buf, index, index) outs(%22 : !pto.tile_buf) %23 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%23, %19, %20 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%23 : !pto.tile_buf) + pto.tmatmul.acc ins(%23, %19, %20 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%23 : !pto.tile_buf) %24 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%24, %21, %22 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%24 : !pto.tile_buf) - %25 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%0, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%25 : !pto.tile_buf) + pto.tmatmul.acc ins(%24, %21, %22 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%24 : !pto.tile_buf) + %25 = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%0, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%25 : !pto.tile_buf) %26 = pto.alloc_tile addr = %c0_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%2, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%26 : !pto.tile_buf) - %27 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%0, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%27 : !pto.tile_buf) + %27 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%0, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%27 : !pto.tile_buf) %28 = pto.alloc_tile addr = %c32768_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%2, %c64_index, %c0_index : !pto.tile_buf, index, index) outs(%28 : !pto.tile_buf) %29 = pto.alloc_tile addr = %c16384_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%29, %25, %26 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%29 : !pto.tile_buf) + pto.tmatmul.acc ins(%29, %25, %26 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%29 : !pto.tile_buf) %30 = pto.alloc_tile addr = %c16384_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%30, %27, %28 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%30 : !pto.tile_buf) + pto.tmatmul.acc ins(%30, %27, %28 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%30 : !pto.tile_buf) } %k_proj__ssa_v0_pview = pto.partition_view %k_proj__ssa_v0_view, offsets = [%c0_index, %31], sizes = [%c16_index, %c256_index] : !pto.tensor_view -> !pto.partition_tensor_view<16x256xf32> pto.tstore ins(%k_acc__tile : !pto.tile_buf) outs(%k_proj__ssa_v0_pview : !pto.partition_tensor_view<16x256xf32>) diff --git a/test/samples/Qwen3DecodeA5/qwen3_decode_incore_4.pto b/test/samples/Qwen3DecodeA5/qwen3_decode_incore_4.pto index 1ce3839319..97ee54d918 100644 --- a/test/samples/Qwen3DecodeA5/qwen3_decode_incore_4.pto +++ b/test/samples/Qwen3DecodeA5/qwen3_decode_incore_4.pto @@ -46,18 +46,18 @@ module attributes {pto.target_arch = "a5"} { %k_cache__rv_v4_dn_view_pview = pto.partition_view %k_cache__rv_v4_dn_view, offsets = [%c0_index, %18], sizes = [%c128_index, %c256_index] : !pto.tensor_view -> !pto.partition_tensor_view<128x256xbf16> pto.tload ins(%k_cache__rv_v4_dn_view_pview : !pto.partition_tensor_view<128x256xbf16>) outs(%k_tile_0__tile : !pto.tile_buf) %raw_scores_0__tile_l0_init = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - %raw_scores_0__tile_l0_a = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%q_padded0__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%raw_scores_0__tile_l0_a : !pto.tile_buf) + %raw_scores_0__tile_l0_a = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%q_padded0__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%raw_scores_0__tile_l0_a : !pto.tile_buf) %raw_scores_0__tile_l0_b = pto.alloc_tile addr = %c32768_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%k_tile_0__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%raw_scores_0__tile_l0_b : !pto.tile_buf) - %0 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%q_padded0__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%0 : !pto.tile_buf) + %0 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%q_padded0__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%0 : !pto.tile_buf) %1 = pto.alloc_tile addr = %c0_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%k_tile_0__tile, %c64_index, %c0_index : !pto.tile_buf, index, index) outs(%1 : !pto.tile_buf) %raw_scores_0__tile_l0_c_first = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul ins(%raw_scores_0__tile_l0_a, %raw_scores_0__tile_l0_b : !pto.tile_buf, !pto.tile_buf) outs(%raw_scores_0__tile_l0_c_first : !pto.tile_buf) + pto.tmatmul ins(%raw_scores_0__tile_l0_a, %raw_scores_0__tile_l0_b : !pto.tile_buf, !pto.tile_buf) outs(%raw_scores_0__tile_l0_c_first : !pto.tile_buf) %raw_scores_0__tile_l0_c_acc = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%raw_scores_0__tile_l0_c_acc, %0, %1 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%raw_scores_0__tile_l0_c_acc : !pto.tile_buf) + pto.tmatmul.acc ins(%raw_scores_0__tile_l0_c_acc, %0, %1 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%raw_scores_0__tile_l0_c_acc : !pto.tile_buf) %19 = arith.muli %4, %c256_index : index %20 = arith.muli %sb__idx_v0, %c16_index : index %21 = arith.addi %19, %20 : index @@ -71,18 +71,18 @@ module attributes {pto.target_arch = "a5"} { %26 = pto.partition_view %k_cache__rv_v4_dn_view, offsets = [%c0_index, %25], sizes = [%c128_index, %c256_index] : !pto.tensor_view -> !pto.partition_tensor_view<128x256xbf16> pto.tload ins(%26 : !pto.partition_tensor_view<128x256xbf16>) outs(%k_tile_1__tile : !pto.tile_buf) %raw_scores_1__tile_l0_init = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - %raw_scores_1__tile_l0_a = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%q_padded1__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%raw_scores_1__tile_l0_a : !pto.tile_buf) + %raw_scores_1__tile_l0_a = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%q_padded1__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%raw_scores_1__tile_l0_a : !pto.tile_buf) %raw_scores_1__tile_l0_b = pto.alloc_tile addr = %c32768_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%k_tile_1__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%raw_scores_1__tile_l0_b : !pto.tile_buf) - %2 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf - pto.textract ins(%q_padded1__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%2 : !pto.tile_buf) + %2 = pto.alloc_tile addr = %c2048_i64 valid_row = %c16_index valid_col = %c64_index : !pto.tile_buf + pto.textract ins(%q_padded1__tile, %c0_index, %c64_index : !pto.tile_buf, index, index) outs(%2 : !pto.tile_buf) %3 = pto.alloc_tile addr = %c0_i64 valid_row = %c64_index valid_col = %c256_index : !pto.tile_buf pto.textract ins(%k_tile_1__tile, %c64_index, %c0_index : !pto.tile_buf, index, index) outs(%3 : !pto.tile_buf) %raw_scores_1__tile_l0_c_first = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul ins(%raw_scores_1__tile_l0_a, %raw_scores_1__tile_l0_b : !pto.tile_buf, !pto.tile_buf) outs(%raw_scores_1__tile_l0_c_first : !pto.tile_buf) + pto.tmatmul ins(%raw_scores_1__tile_l0_a, %raw_scores_1__tile_l0_b : !pto.tile_buf, !pto.tile_buf) outs(%raw_scores_1__tile_l0_c_first : !pto.tile_buf) %raw_scores_1__tile_l0_c_acc = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c256_index : !pto.tile_buf - pto.tmatmul.acc ins(%raw_scores_1__tile_l0_c_acc, %2, %3 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%raw_scores_1__tile_l0_c_acc : !pto.tile_buf) + pto.tmatmul.acc ins(%raw_scores_1__tile_l0_c_acc, %2, %3 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%raw_scores_1__tile_l0_c_acc : !pto.tile_buf) %28 = arith.muli %6, %c256_index : index %29 = arith.muli %sb__idx_v0, %c16_index : index %30 = arith.addi %28, %29 : index diff --git a/test/samples/Qwen3DecodeA5/qwen3_decode_incore_6.pto b/test/samples/Qwen3DecodeA5/qwen3_decode_incore_6.pto index 53617d3d46..c8ffed3cde 100644 --- a/test/samples/Qwen3DecodeA5/qwen3_decode_incore_6.pto +++ b/test/samples/Qwen3DecodeA5/qwen3_decode_incore_6.pto @@ -37,18 +37,18 @@ module attributes {pto.target_arch = "a5"} { %v_cache__rv_v4_pview = pto.partition_view %v_cache__rv_v4_view, offsets = [%11, %c0_index], sizes = [%c256_index, %c128_index] : !pto.tensor_view -> !pto.partition_tensor_view<256x128xbf16> pto.tload ins(%v_cache__rv_v4_pview : !pto.partition_tensor_view<256x128xbf16>) outs(%v_tile_0__tile : !pto.tile_buf) %oi_tmp_0__tile_l0_init = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c128_index : !pto.tile_buf - %oi_tmp_0__tile_l0_a = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c128_index : !pto.tile_buf - pto.textract ins(%exp_tile_0__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%oi_tmp_0__tile_l0_a : !pto.tile_buf) + %oi_tmp_0__tile_l0_a = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c128_index : !pto.tile_buf + pto.textract ins(%exp_tile_0__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%oi_tmp_0__tile_l0_a : !pto.tile_buf) %oi_tmp_0__tile_l0_b = pto.alloc_tile addr = %c0_i64 valid_row = %c128_index valid_col = %c128_index : !pto.tile_buf pto.textract ins(%v_tile_0__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%oi_tmp_0__tile_l0_b : !pto.tile_buf) - %0 = pto.alloc_tile addr = %c4096_i64 valid_row = %c16_index valid_col = %c128_index : !pto.tile_buf - pto.textract ins(%exp_tile_0__tile, %c0_index, %c128_index : !pto.tile_buf, index, index) outs(%0 : !pto.tile_buf) + %0 = pto.alloc_tile addr = %c4096_i64 valid_row = %c16_index valid_col = %c128_index : !pto.tile_buf + pto.textract ins(%exp_tile_0__tile, %c0_index, %c128_index : !pto.tile_buf, index, index) outs(%0 : !pto.tile_buf) %1 = pto.alloc_tile addr = %c32768_i64 valid_row = %c128_index valid_col = %c128_index : !pto.tile_buf pto.textract ins(%v_tile_0__tile, %c128_index, %c0_index : !pto.tile_buf, index, index) outs(%1 : !pto.tile_buf) %oi_tmp_0__tile_l0_c_first = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c128_index : !pto.tile_buf - pto.tmatmul ins(%oi_tmp_0__tile_l0_a, %oi_tmp_0__tile_l0_b : !pto.tile_buf, !pto.tile_buf) outs(%oi_tmp_0__tile_l0_c_first : !pto.tile_buf) + pto.tmatmul ins(%oi_tmp_0__tile_l0_a, %oi_tmp_0__tile_l0_b : !pto.tile_buf, !pto.tile_buf) outs(%oi_tmp_0__tile_l0_c_first : !pto.tile_buf) %oi_tmp_0__tile_l0_c_acc = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c128_index : !pto.tile_buf - pto.tmatmul.acc ins(%oi_tmp_0__tile_l0_c_acc, %0, %1 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%oi_tmp_0__tile_l0_c_acc : !pto.tile_buf) + pto.tmatmul.acc ins(%oi_tmp_0__tile_l0_c_acc, %0, %1 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%oi_tmp_0__tile_l0_c_acc : !pto.tile_buf) %15 = arith.muli %4, %c256_index : index %16 = arith.muli %sb__idx_v0, %c16_index : index %17 = arith.addi %15, %16 : index @@ -68,18 +68,18 @@ module attributes {pto.target_arch = "a5"} { %26 = pto.partition_view %v_cache__rv_v4_view, offsets = [%21, %c0_index], sizes = [%c256_index, %c128_index] : !pto.tensor_view -> !pto.partition_tensor_view<256x128xbf16> pto.tload ins(%26 : !pto.partition_tensor_view<256x128xbf16>) outs(%v_tile_1__tile : !pto.tile_buf) %oi_tmp_1__tile_l0_init = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c128_index : !pto.tile_buf - %oi_tmp_1__tile_l0_a = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c128_index : !pto.tile_buf - pto.textract ins(%exp_tile_1__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%oi_tmp_1__tile_l0_a : !pto.tile_buf) + %oi_tmp_1__tile_l0_a = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c128_index : !pto.tile_buf + pto.textract ins(%exp_tile_1__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%oi_tmp_1__tile_l0_a : !pto.tile_buf) %oi_tmp_1__tile_l0_b = pto.alloc_tile addr = %c0_i64 valid_row = %c128_index valid_col = %c128_index : !pto.tile_buf pto.textract ins(%v_tile_1__tile, %c0_index, %c0_index : !pto.tile_buf, index, index) outs(%oi_tmp_1__tile_l0_b : !pto.tile_buf) - %2 = pto.alloc_tile addr = %c4096_i64 valid_row = %c16_index valid_col = %c128_index : !pto.tile_buf - pto.textract ins(%exp_tile_1__tile, %c0_index, %c128_index : !pto.tile_buf, index, index) outs(%2 : !pto.tile_buf) + %2 = pto.alloc_tile addr = %c4096_i64 valid_row = %c16_index valid_col = %c128_index : !pto.tile_buf + pto.textract ins(%exp_tile_1__tile, %c0_index, %c128_index : !pto.tile_buf, index, index) outs(%2 : !pto.tile_buf) %3 = pto.alloc_tile addr = %c32768_i64 valid_row = %c128_index valid_col = %c128_index : !pto.tile_buf pto.textract ins(%v_tile_1__tile, %c128_index, %c0_index : !pto.tile_buf, index, index) outs(%3 : !pto.tile_buf) %oi_tmp_1__tile_l0_c_first = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c128_index : !pto.tile_buf - pto.tmatmul ins(%oi_tmp_1__tile_l0_a, %oi_tmp_1__tile_l0_b : !pto.tile_buf, !pto.tile_buf) outs(%oi_tmp_1__tile_l0_c_first : !pto.tile_buf) + pto.tmatmul ins(%oi_tmp_1__tile_l0_a, %oi_tmp_1__tile_l0_b : !pto.tile_buf, !pto.tile_buf) outs(%oi_tmp_1__tile_l0_c_first : !pto.tile_buf) %oi_tmp_1__tile_l0_c_acc = pto.alloc_tile addr = %c0_i64 valid_row = %c16_index valid_col = %c128_index : !pto.tile_buf - pto.tmatmul.acc ins(%oi_tmp_1__tile_l0_c_acc, %2, %3 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%oi_tmp_1__tile_l0_c_acc : !pto.tile_buf) + pto.tmatmul.acc ins(%oi_tmp_1__tile_l0_c_acc, %2, %3 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%oi_tmp_1__tile_l0_c_acc : !pto.tile_buf) %28 = arith.muli %6, %c256_index : index %29 = arith.muli %sb__idx_v0, %c16_index : index %30 = arith.addi %28, %29 : index From 30f2ecad9666303e1a796f4a93494e5788d2fe46 Mon Sep 17 00:00:00 2001 From: TaoTao-real Date: Tue, 16 Jun 2026 18:00:56 +0800 Subject: [PATCH 27/51] fix: preserve explicit left tile layout --- lib/PTO/IR/PTO.cpp | 12 ++++++---- lib/PTO/IR/PTOTypeDefs.cpp | 23 +------------------ .../pto/compact_left_blayout_parser_a5.pto | 2 +- test/lit/pto/left_blayout_parser_a3.pto | 4 ++-- test/lit/pto/left_blayout_parser_a5.pto | 2 +- 5 files changed, 13 insertions(+), 30 deletions(-) diff --git a/lib/PTO/IR/PTO.cpp b/lib/PTO/IR/PTO.cpp index d66746d7ae..f314a9edad 100644 --- a/lib/PTO/IR/PTO.cpp +++ b/lib/PTO/IR/PTO.cpp @@ -4579,8 +4579,10 @@ static LogicalResult verifyMatTileOperandsA5(Operation *op, Type lhsTy, if (!lhsTb || !rhsTb || !dstTb) return success(); - if (lhsTb.getBLayoutValueI32() != static_cast(pto::BLayout::ColMajor)) - return op->emitOpError("expects lhs to use the col_major blayout on A5"); + if (lhsTb.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor) && + lhsTb.getBLayoutValueI32() != static_cast(pto::BLayout::ColMajor)) + return op->emitOpError( + "expects lhs to use the row_major or col_major blayout on A5"); if (rhsTb.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor)) return op->emitOpError("expects rhs to use the row_major blayout on A5"); if (dstTb.getBLayoutValueI32() != static_cast(pto::BLayout::ColMajor)) @@ -6314,9 +6316,11 @@ mlir::LogicalResult mlir::pto::TExtractOp::verify() { if (!hasMatExtractSourceLayoutA5(srcTb, *dstSpace)) return emitOpError("expects A5 textract src to use a supported mat blayout/slayout combination"); if (*dstSpace == pto::AddressSpace::LEFT) { - if (dstTb.getBLayoutValueI32() != static_cast(pto::BLayout::ColMajor) || + if ((dstTb.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor) && + dstTb.getBLayoutValueI32() != static_cast(pto::BLayout::ColMajor)) || dstTb.getSLayoutValueI32() != static_cast(pto::SLayout::RowMajor)) - return emitOpError("expects A5 left dst to use col_major blayout and row_major slayout"); + return emitOpError( + "expects A5 left dst to use row_major or col_major blayout and row_major slayout"); } else if (*dstSpace == pto::AddressSpace::RIGHT) { if (dstTb.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor) || dstTb.getSLayoutValueI32() != static_cast(pto::SLayout::ColMajor)) diff --git a/lib/PTO/IR/PTOTypeDefs.cpp b/lib/PTO/IR/PTOTypeDefs.cpp index 1e1fe272c9..ccaefb7107 100644 --- a/lib/PTO/IR/PTOTypeDefs.cpp +++ b/lib/PTO/IR/PTOTypeDefs.cpp @@ -170,23 +170,6 @@ static std::optional resolveTileBufMemorySpace(StringRef locStr) { .Default(::std::nullopt); } -static BLayout resolveTileBufBLayout(MLIRContext *context, - AddressSpace memorySpace, - BLayout parsedLayout) { - if (memorySpace != AddressSpace::LEFT) - return parsedLayout; - - switch (getPTOParserTargetArch(context)) { - case PTOParserTargetArch::A3: - return BLayout::RowMajor; - case PTOParserTargetArch::A5: - return BLayout::ColMajor; - case PTOParserTargetArch::Unspecified: - return parsedLayout; - } - return parsedLayout; -} - TileBufConfigAttr TileBufType::getConfigAttr() const { // 情况 A:getConfig() 已经是 TileBufConfigAttr if constexpr (std::is_same_v) { @@ -501,14 +484,10 @@ static Type buildTileBufType(AsmParser &parser, return Type(); } - BLayout effectiveBLayout = - resolveTileBufBLayout(parser.getContext(), memorySpace.value(), - bl.value()); - // (32-byte alignment and boxed layout divisibility checks removed // - not general hardware requirements; validation handled elsewhere) - auto blAttr = BLayoutAttr::get(ctx, effectiveBLayout); + auto blAttr = BLayoutAttr::get(ctx, bl.value()); auto slAttr = SLayoutAttr::get(ctx, sl.value()); auto fractalAttr = IntegerAttr::get(IntegerType::get(ctx, kI32BitWidth), fields.fractal); diff --git a/test/lit/pto/compact_left_blayout_parser_a5.pto b/test/lit/pto/compact_left_blayout_parser_a5.pto index 4dad24114e..554c9fa6a9 100644 --- a/test/lit/pto/compact_left_blayout_parser_a5.pto +++ b/test/lit/pto/compact_left_blayout_parser_a5.pto @@ -8,4 +8,4 @@ module attributes {"pto.target_arch" = "a5"} { } // CHECK-LABEL: func.func @compact_left_blayout_parser_a5() { -// CHECK: #pto.tile_buf_config, slayout=#pto.slayout, s_fractal_size=512, pad=#pto.pad_value, compact=#pto.compact_mode> +// CHECK: #pto.tile_buf_config, slayout=#pto.slayout, s_fractal_size=512, pad=#pto.pad_value, compact=#pto.compact_mode> diff --git a/test/lit/pto/left_blayout_parser_a3.pto b/test/lit/pto/left_blayout_parser_a3.pto index 21ee0af7fd..41744472a5 100644 --- a/test/lit/pto/left_blayout_parser_a3.pto +++ b/test/lit/pto/left_blayout_parser_a3.pto @@ -4,8 +4,8 @@ module attributes {"pto.target_arch" = "a3"} { func.func @left_blayout_parser_a3() { %c0 = arith.constant 0 : index %src = pto.alloc_tile : !pto.tile_buf - %dst = pto.alloc_tile : !pto.tile_buf - pto.textract ins(%src, %c0, %c0 : !pto.tile_buf, index, index) outs(%dst : !pto.tile_buf) + %dst = pto.alloc_tile : !pto.tile_buf + pto.textract ins(%src, %c0, %c0 : !pto.tile_buf, index, index) outs(%dst : !pto.tile_buf) return } } diff --git a/test/lit/pto/left_blayout_parser_a5.pto b/test/lit/pto/left_blayout_parser_a5.pto index a8c7db18c2..a689597add 100644 --- a/test/lit/pto/left_blayout_parser_a5.pto +++ b/test/lit/pto/left_blayout_parser_a5.pto @@ -11,4 +11,4 @@ module attributes {"pto.target_arch" = "a5"} { } // CHECK-LABEL: AICORE void left_blayout_parser_a5() { -// CHECK: Tile +// CHECK: Tile From 20366781b11dd3cc883300894f448c59e039136b Mon Sep 17 00:00:00 2001 From: TaoTao-real Date: Tue, 16 Jun 2026 18:12:22 +0800 Subject: [PATCH 28/51] Revert "fix: gate known CANN beta VPTO SIM gaps" This reverts commit 6318ad7e23097c535a5bd651d56bdedaf94fcb23. --- .github/workflows/ci_sim.yml | 2 - .../cann-9.0.0-beta.1-sim.txt | 28 ----- .../run_host_vpto_validation_parallel.sh | 100 +++--------------- 3 files changed, 16 insertions(+), 114 deletions(-) delete mode 100644 test/vpto/known_unsupported/cann-9.0.0-beta.1-sim.txt diff --git a/.github/workflows/ci_sim.yml b/.github/workflows/ci_sim.yml index 4b9e4b7ee8..4168a73208 100644 --- a/.github/workflows/ci_sim.yml +++ b/.github/workflows/ci_sim.yml @@ -508,8 +508,6 @@ jobs: PTOAS_BIN="${PTOAS_BIN}" \ DEVICE=SIM \ JOBS="${JOBS:-32}" \ - VPTO_SIM_ENABLE_KNOWN_UNSUPPORTED_SKIP=1 \ - VPTO_SIM_KNOWN_UNSUPPORTED_CASES_FILE="${GITHUB_WORKSPACE}/test/vpto/known_unsupported/cann-9.0.0-beta.1-sim.txt" \ bash test/vpto/scripts/run_host_vpto_validation_parallel.sh - name: Run TileLang DSL CI diff --git a/test/vpto/known_unsupported/cann-9.0.0-beta.1-sim.txt b/test/vpto/known_unsupported/cann-9.0.0-beta.1-sim.txt deleted file mode 100644 index 580d79c74e..0000000000 --- a/test/vpto/known_unsupported/cann-9.0.0-beta.1-sim.txt +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright (c) 2026 Huawei Technologies Co., Ltd. -# This program is free software, you can redistribute it and/or modify it under the terms and conditions of -# CANN Open Software License Agreement Version 2.0 (the "License"). -# Please refer to the License for details. You may not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. - -# CANN 9.0.0 beta.1 Bisheng SIM cannot currently consume these LLVM 21 -# textual-IR cases. Low-precision vcvt cases require PTOAS to emit LLVM 21 -# carrier types, but beta.1 Bisheng verifies the legacy private intrinsic ABI -# while rejecting the private low-precision type spellings in textual IR. -# The SIMT cases fail in Bisheng's HiIPU backend after valid LLVM IR is handed -# to the compiler. Keep this list exact and remove entries as the CANN compiler -# catches up. -micro-op/conversion/vcvt-low-precision-roundtrip -micro-op/conversion/vcvt-low-precision-special -micro-op/simt/simt-atomic-packed-core -micro-op/simt/simt-atomic-s-core -micro-op/simt/simt-control-typed-core -micro-op/simt/simt-float-convert-core -micro-op/simt/simt-gm-memory-core -micro-op/simt/simt-ldst-policy-core -micro-op/simt/simt-memory-core -micro-op/simt/simt-packed-math-core -micro-op/simt/simt-scalar-core -micro-op/simt/simt-scalar-math-core -micro-op/simt/simt-store-tid diff --git a/test/vpto/scripts/run_host_vpto_validation_parallel.sh b/test/vpto/scripts/run_host_vpto_validation_parallel.sh index 9d4d67c3d5..98be7d669d 100755 --- a/test/vpto/scripts/run_host_vpto_validation_parallel.sh +++ b/test/vpto/scripts/run_host_vpto_validation_parallel.sh @@ -19,8 +19,6 @@ WORK_SPACE="${WORK_SPACE:-}" CASE_NAME="${CASE_NAME:-}" CASE_PREFIX="${CASE_PREFIX:-}" JOBS="${JOBS:-}" -VPTO_SIM_ENABLE_KNOWN_UNSUPPORTED_SKIP="${VPTO_SIM_ENABLE_KNOWN_UNSUPPORTED_SKIP:-0}" -VPTO_SIM_KNOWN_UNSUPPORTED_CASES_FILE="${VPTO_SIM_KNOWN_UNSUPPORTED_CASES_FILE:-}" log() { echo "[$(date +'%F %T')] $*" @@ -71,41 +69,6 @@ fi [[ "${JOBS}" =~ ^[0-9]+$ ]] || die "JOBS must be a positive integer, got: ${JOBS}" [[ "${JOBS}" -ge 1 ]] || die "JOBS must be >= 1" -KNOWN_UNSUPPORTED_CASES=() - -trim() { - local text="$1" - text="${text#"${text%%[![:space:]]*}"}" - text="${text%"${text##*[![:space:]]}"}" - printf '%s' "${text}" -} - -load_known_unsupported_cases() { - [[ "${VPTO_SIM_ENABLE_KNOWN_UNSUPPORTED_SKIP}" == "1" ]] || return 0 - [[ -n "${VPTO_SIM_KNOWN_UNSUPPORTED_CASES_FILE}" ]] || - die "VPTO_SIM_KNOWN_UNSUPPORTED_CASES_FILE is required when known-unsupported skip is enabled" - [[ -f "${VPTO_SIM_KNOWN_UNSUPPORTED_CASES_FILE}" ]] || - die "known-unsupported cases file not found: ${VPTO_SIM_KNOWN_UNSUPPORTED_CASES_FILE}" - - local raw line - while IFS= read -r raw || [[ -n "${raw}" ]]; do - line="$(trim "${raw%%#*}")" - [[ -n "${line}" ]] || continue - KNOWN_UNSUPPORTED_CASES+=("${line}") - done < "${VPTO_SIM_KNOWN_UNSUPPORTED_CASES_FILE}" -} - -is_known_unsupported_case() { - local case_name="$1" - local known_case - for known_case in "${KNOWN_UNSUPPORTED_CASES[@]}"; do - if [[ "${known_case}" == "${case_name}" ]]; then - return 0 - fi - done - return 1 -} - mkdir -p "${WORK_SPACE}" WORK_SPACE="$(cd "${WORK_SPACE}" && pwd)" SUMMARY_FILE="${WORK_SPACE}/parallel-summary.tsv" @@ -163,32 +126,13 @@ if [[ "${DEVICE:-SIM}" == "SIM" && "${COMPILE_ONLY:-0}" != "1" && die "case ${CASE_NAME} is onboard-only and cannot run with DEVICE=SIM" fi -load_known_unsupported_cases - -DISCOVERED_CASES=() -while IFS= read -r case_name; do - DISCOVERED_CASES+=("${case_name}") -done < <(discover_cases) -[[ "${#DISCOVERED_CASES[@]}" -gt 0 ]] || die "no cases found under ${CASES_ROOT}" - -CASES=() -SKIPPED_CASES=() -for case_name in "${DISCOVERED_CASES[@]}"; do - if is_known_unsupported_case "${case_name}"; then - SKIPPED_CASES+=("${case_name}") - else - CASES+=("${case_name}") - fi -done - -[[ "${#CASES[@]}" -gt 0 || "${#SKIPPED_CASES[@]}" -gt 0 ]] || - die "no runnable or skipped cases found under ${CASES_ROOT}" +readarray -t CASES < <(discover_cases) +[[ "${#CASES[@]}" -gt 0 ]] || die "no cases found under ${CASES_ROOT}" : > "${SUMMARY_FILE}" : > "${RUNNER_LOG}" -RUNNING_PIDS=() -RUNNING_CASES=() +declare -A PID_TO_CASE=() launch_case() { local case_name="$1" @@ -199,14 +143,12 @@ launch_case() { ) & local pid=$! - RUNNING_PIDS+=("${pid}") - RUNNING_CASES+=("${case_name}") + PID_TO_CASE["${pid}"]="${case_name}" } reap_one() { - local index="$1" - local pid="${RUNNING_PIDS[${index}]}" - local case_name="${RUNNING_CASES[${index}]}" + local pid="$1" + local case_name="${PID_TO_CASE[${pid}]}" local result="FAIL" local detail="1" @@ -217,8 +159,7 @@ reap_one() { printf '%s\t%s\t%s\n' "${case_name}" "${result}" "${detail}" >> "${SUMMARY_FILE}" log "[${case_name}] ${result} (${detail})" | tee -a "${RUNNER_LOG}" - unset 'RUNNING_PIDS['"${index}"']' - unset 'RUNNING_CASES['"${index}"']' + unset 'PID_TO_CASE['"${pid}"']' } log "=== VPTO Host Validation Parallel ===" | tee -a "${RUNNER_LOG}" @@ -226,34 +167,26 @@ log "WORK_SPACE=${WORK_SPACE}" | tee -a "${RUNNER_LOG}" log "CASE_NAME=${CASE_NAME:-}" | tee -a "${RUNNER_LOG}" log "CASE_PREFIX=${CASE_PREFIX:-}" | tee -a "${RUNNER_LOG}" log "JOBS=${JOBS}" | tee -a "${RUNNER_LOG}" -log "TOTAL_CASES=${#DISCOVERED_CASES[@]}" | tee -a "${RUNNER_LOG}" -log "RUNNABLE_CASES=${#CASES[@]}" | tee -a "${RUNNER_LOG}" -log "SKIPPED_CASES=${#SKIPPED_CASES[@]}" | tee -a "${RUNNER_LOG}" +log "TOTAL_CASES=${#CASES[@]}" | tee -a "${RUNNER_LOG}" if [[ -n "${SIM_LIB_DIR:-}" ]]; then log "SIM_LIB_DIR=${SIM_LIB_DIR}" | tee -a "${RUNNER_LOG}" fi -for case_name in "${SKIPPED_CASES[@]}"; do - printf '%s\tSKIP\tknown-unsupported\n' "${case_name}" >> "${SUMMARY_FILE}" - log "[${case_name}] SKIP (known-unsupported)" | tee -a "${RUNNER_LOG}" -done - next_index=0 -while [[ "${next_index}" -lt "${#CASES[@]}" || "${#RUNNING_PIDS[@]}" -gt 0 ]]; do - while [[ "${next_index}" -lt "${#CASES[@]}" && "${#RUNNING_PIDS[@]}" -lt "${JOBS}" ]]; do +while [[ "${next_index}" -lt "${#CASES[@]}" || "${#PID_TO_CASE[@]}" -gt 0 ]]; do + while [[ "${next_index}" -lt "${#CASES[@]}" && "${#PID_TO_CASE[@]}" -lt "${JOBS}" ]]; do launch_case "${CASES[${next_index}]}" next_index="$((next_index + 1))" done - if [[ "${#RUNNING_PIDS[@]}" -eq 0 ]]; then + if [[ "${#PID_TO_CASE[@]}" -eq 0 ]]; then continue fi while true; do - for index in "${!RUNNING_PIDS[@]}"; do - pid="${RUNNING_PIDS[${index}]}" + for pid in "${!PID_TO_CASE[@]}"; do if ! kill -0 "${pid}" 2>/dev/null; then - reap_one "${index}" + reap_one "${pid}" break 2 fi done @@ -262,14 +195,13 @@ while [[ "${next_index}" -lt "${#CASES[@]}" || "${#RUNNING_PIDS[@]}" -gt 0 ]]; d done pass_count="$(awk -F '\t' '$2 == "PASS" {count++} END {print count + 0}' "${SUMMARY_FILE}")" -skip_count="$(awk -F '\t' '$2 == "SKIP" {count++} END {print count + 0}' "${SUMMARY_FILE}")" -fail_count="$(awk -F '\t' '$2 == "FAIL" {count++} END {print count + 0}' "${SUMMARY_FILE}")" +fail_count="$(awk -F '\t' '$2 != "PASS" {count++} END {print count + 0}' "${SUMMARY_FILE}")" -log "PASS=${pass_count} SKIP=${skip_count} FAIL=${fail_count}" | tee -a "${RUNNER_LOG}" +log "PASS=${pass_count} FAIL=${fail_count}" | tee -a "${RUNNER_LOG}" log "summary: ${SUMMARY_FILE}" | tee -a "${RUNNER_LOG}" if [[ "${fail_count}" -ne 0 ]]; then die "parallel validation finished with ${fail_count} failing case(s)" fi -log "All ${pass_count} runnable case(s) passed; ${skip_count} known-unsupported case(s) skipped" | tee -a "${RUNNER_LOG}" +log "All ${pass_count} case(s) passed" | tee -a "${RUNNER_LOG}" From cd707e967d5272e76053a7bbc1e07149b41060c9 Mon Sep 17 00:00:00 2001 From: TaoTao-real Date: Tue, 16 Jun 2026 18:14:08 +0800 Subject: [PATCH 29/51] Revert "fix: lower VPTO low-precision carriers for LLVM 21" This reverts commit 8d02cbace1247cfe4443d0d575a7e3680f354486. --- lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp | 133 +++--------------- lib/PTO/Transforms/VPTOLLVMEmitter.cpp | 133 +++--------------- .../vpto/low_precision_hivm_llvm_ir_abi.pto | 129 ----------------- .../simt_lowlevel_ldst_policy_vpto_llvm.pto | 8 +- 4 files changed, 47 insertions(+), 356 deletions(-) delete mode 100644 test/lit/vpto/low_precision_hivm_llvm_ir_abi.pto diff --git a/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp b/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp index e21345ba85..1ed2263f99 100644 --- a/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp +++ b/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp @@ -67,11 +67,16 @@ static Type getElementTypeFromVectorLike(Type type); static std::optional getElementCountFromVectorLike(Type type); static Type getLowPrecisionLLVMType(Type type, MLIRContext *context) { - if (pto::isPTOHiFloat8Type(type) || isa(type) || - isa(type) || - pto::isPTOFloat8E4M3LikeType(type) || - pto::isPTOFloat8E5M2LikeType(type)) + if (pto::isPTOHiFloat8Type(type)) + return Float8E4M3FNType::get(context); + if (isa(type)) return IntegerType::get(context, 8); + if (isa(type)) + return IntegerType::get(context, 8); + if (pto::isPTOFloat8E4M3LikeType(type)) + return Float8E4M3Type::get(context); + if (pto::isPTOFloat8E5M2LikeType(type)) + return Float8E5M2Type::get(context); return {}; } @@ -96,7 +101,7 @@ static Type getLowpPayloadABIElementType(Type elementType, static Type normalizePayloadTypeForLLVMLowering(Type type, Builder &builder) { if (pto::isPTOHiFloat8x2Type(type)) return getLLVMCompatibleVectorType( - {2}, builder.getI8Type()); + {2}, Float8E4M3FNType::get(builder.getContext())); if (Type lowpType = getLowPrecisionLLVMType(type, builder.getContext())) return lowpType; @@ -578,38 +583,6 @@ static Type getLowpPayloadCarrierType(Type vectorLikeType, return VectorType::get({*lanes}, abi->llvmElementType); } -static Type getLowpVcvtIntrinsicElementType(Type elementType, - MLIRContext *context) { - if (pto::isPTOHiFloat8Type(elementType)) - return LLVM::LLVMHiFloat8Type::get(context); - if (isa(elementType)) - return LLVM::LLVMFloat4E1M2x2Type::get(context); - if (isa(elementType)) - return LLVM::LLVMFloat4E2M1x2Type::get(context); - if (pto::isPTOFloat8E4M3LikeType(elementType)) - return LLVM::LLVMFloat8E4M3Type::get(context); - if (pto::isPTOFloat8E5M2LikeType(elementType)) - return LLVM::LLVMFloat8E5M2Type::get(context); - return {}; -} - -static Type getLowpIntrinsicCarrierType(Type semanticType, Type convertedType, - MLIRContext *context); - -static Type getVcvtIntrinsicValueType(Type semanticType, Type convertedType, - MLIRContext *context) { - Type elementType = getElementTypeFromVectorLike(semanticType); - Type lowpElementType = - getLowpVcvtIntrinsicElementType(elementType, context); - if (!lowpElementType) - return getLowpIntrinsicCarrierType(semanticType, convertedType, context); - - auto lanes = getElementCountFromVectorLike(semanticType); - if (!lanes) - return {}; - return VectorType::get({*lanes}, lowpElementType); -} - static Type getPayloadABIType(Type semanticType, Type convertedType, MLIRContext *context) { if (Type carrierType = getLowpPayloadCarrierType(semanticType, context)) @@ -651,30 +624,6 @@ static Type getPackedLowpScalarMemoryType(Type semanticType, return IntegerType::get(context, 16); } -static Type getLowpIntrinsicCarrierType(Type semanticType, Type convertedType, - MLIRContext *context) { - if (auto vregType = dyn_cast(semanticType)) { - if (!isLowpPayloadABIElementType(vregType.getElementType())) - return convertedType; - int64_t elementCount = vregType.getElementCount(); - if (elementCount <= 0 || elementCount % 4 != 0) - return {}; - return VectorType::get({elementCount / 4}, IntegerType::get(context, 32)); - } - - if (Type packedScalarType = - getPackedLowpScalarMemoryType(semanticType, context)) - return packedScalarType; - return convertedType; -} - -static Value bitcastToType(Location loc, Value value, Type targetType, - ConversionPatternRewriter &rewriter) { - if (!targetType || targetType == value.getType()) - return value; - return rewriter.create(loc, targetType, value); -} - static Type getScalarAccessGEPElementType(Type semanticType, Builder &builder) { if (Type memoryType = @@ -7742,24 +7691,10 @@ class LowerVcvtOpPattern final : public OpConversionPattern { if (!resultType) return rewriter.notifyMatchFailure(op, "failed to convert vcvt result type"); - Type inputCallType = getVcvtIntrinsicValueType( - op.getInput().getType(), adaptor.getInput().getType(), - rewriter.getContext()); - if (!inputCallType) - return rewriter.notifyMatchFailure(op, - "unsupported vcvt input carrier type"); - Type resultCallType = getVcvtIntrinsicValueType( - op.getResult().getType(), resultType, rewriter.getContext()); - if (!resultCallType) - return rewriter.notifyMatchFailure(op, - "unsupported vcvt result carrier type"); - Value input = bitcastToType(op.getLoc(), adaptor.getInput(), inputCallType, - rewriter); - SmallVector callArgs; SmallVector argTypes; - callArgs.push_back(input); - argTypes.push_back(input.getType()); + callArgs.push_back(adaptor.getInput()); + argTypes.push_back(adaptor.getInput().getType()); callArgs.push_back(adaptor.getMask()); argTypes.push_back(adaptor.getMask().getType()); @@ -7807,15 +7742,12 @@ class LowerVcvtOpPattern final : public OpConversionPattern { argTypes.push_back(partValue.getType()); } - auto funcType = rewriter.getFunctionType(argTypes, TypeRange{resultCallType}); + auto funcType = rewriter.getFunctionType(argTypes, TypeRange{resultType}); auto call = rewriter.create( - op.getLoc(), StringRef((*contract).intrinsic), - TypeRange{resultCallType}, callArgs); + op.getLoc(), StringRef((*contract).intrinsic), TypeRange{resultType}, callArgs); state.plannedDecls.push_back( PlannedDecl{std::string((*contract).intrinsic), funcType}); - Value result = - bitcastToType(op.getLoc(), call.getResult(0), resultType, rewriter); - rewriter.replaceOp(op, result); + rewriter.replaceOp(op, call.getResults()); return success(); } @@ -9340,30 +9272,15 @@ class LowerConvertOpPattern final : public OpConversionPattern { Value saturation = getI32Constant( rewriter, op.getLoc(), static_cast(op.getSaturation())); - Type srcCallType = getLowpIntrinsicCarrierType( - op.getSrc().getType(), adaptor.getSrc().getType(), - rewriter.getContext()); - if (!srcCallType) - return rewriter.notifyMatchFailure( - op, "unsupported convert input carrier type"); - Type resultCallType = getLowpIntrinsicCarrierType( - op.getDst().getType(), resultType, rewriter.getContext()); - if (!resultCallType) - return rewriter.notifyMatchFailure( - op, "unsupported convert result carrier type"); - Value src = - bitcastToType(op.getLoc(), adaptor.getSrc(), srcCallType, rewriter); - auto funcType = rewriter.getFunctionType( - TypeRange{src.getType(), rewriter.getI32Type(), rewriter.getI32Type()}, - TypeRange{resultCallType}); + TypeRange{adaptor.getSrc().getType(), rewriter.getI32Type(), + rewriter.getI32Type()}, + TypeRange{resultType}); auto call = rewriter.create( - op.getLoc(), *calleeName, TypeRange{resultCallType}, - ValueRange{src, rounding, saturation}); + op.getLoc(), *calleeName, TypeRange{resultType}, + ValueRange{adaptor.getSrc(), rounding, saturation}); state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); - Value result = - bitcastToType(op.getLoc(), call.getResult(0), resultType, rewriter); - rewriter.replaceOp(op, result); + rewriter.replaceOp(op, call.getResults()); return success(); } @@ -9745,8 +9662,6 @@ static Value convertLdgCallResult(Location loc, Type valueType, if (pto::isPTOFloat8Type(valueType) || pto::isPTOHiFloat8Type(valueType)) { Value payload = rewriter.create(loc, rewriter.getI8Type(), callResult); - if (payload.getType() == convertedValueType) - return payload; return rewriter.create(loc, convertedValueType, payload); } return callResult; @@ -9880,10 +9795,8 @@ static Value convertStgValue(Location loc, Type valueType, Value value, } if (pto::isPTOFloat8Type(valueType) || pto::isPTOHiFloat8Type(valueType)) { - Value payload = value; - if (payload.getType() != rewriter.getI8Type()) - payload = - rewriter.create(loc, rewriter.getI8Type(), value); + Value payload = + rewriter.create(loc, rewriter.getI8Type(), value); return rewriter.create(loc, rewriter.getI32Type(), payload); } if (valueType.isBF16()) diff --git a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp index 04d37d466f..fbd0298c6a 100644 --- a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp +++ b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp @@ -68,11 +68,16 @@ static Type getElementTypeFromVectorLike(Type type); static std::optional getElementCountFromVectorLike(Type type); static Type getLowPrecisionLLVMType(Type type, MLIRContext *context) { - if (pto::isPTOHiFloat8Type(type) || isa(type) || - isa(type) || - pto::isPTOFloat8E4M3LikeType(type) || - pto::isPTOFloat8E5M2LikeType(type)) + if (pto::isPTOHiFloat8Type(type)) + return Float8E4M3FNType::get(context); + if (isa(type)) return IntegerType::get(context, 8); + if (isa(type)) + return IntegerType::get(context, 8); + if (pto::isPTOFloat8E4M3LikeType(type)) + return Float8E4M3Type::get(context); + if (pto::isPTOFloat8E5M2LikeType(type)) + return Float8E5M2Type::get(context); return {}; } @@ -97,7 +102,7 @@ static Type getLowpPayloadABIElementType(Type elementType, static Type normalizePayloadTypeForLLVMLowering(Type type, Builder &builder) { if (pto::isPTOHiFloat8x2Type(type)) return getLLVMCompatibleVectorType( - {2}, builder.getI8Type()); + {2}, Float8E4M3FNType::get(builder.getContext())); if (Type lowpType = getLowPrecisionLLVMType(type, builder.getContext())) return lowpType; @@ -533,38 +538,6 @@ static Type getLowpPayloadCarrierType(Type vectorLikeType, return VectorType::get({*lanes}, abi->llvmElementType); } -static Type getLowpVcvtIntrinsicElementType(Type elementType, - MLIRContext *context) { - if (pto::isPTOHiFloat8Type(elementType)) - return LLVM::LLVMHiFloat8Type::get(context); - if (isa(elementType)) - return LLVM::LLVMFloat4E1M2x2Type::get(context); - if (isa(elementType)) - return LLVM::LLVMFloat4E2M1x2Type::get(context); - if (pto::isPTOFloat8E4M3LikeType(elementType)) - return LLVM::LLVMFloat8E4M3Type::get(context); - if (pto::isPTOFloat8E5M2LikeType(elementType)) - return LLVM::LLVMFloat8E5M2Type::get(context); - return {}; -} - -static Type getLowpIntrinsicCarrierType(Type semanticType, Type convertedType, - MLIRContext *context); - -static Type getVcvtIntrinsicValueType(Type semanticType, Type convertedType, - MLIRContext *context) { - Type elementType = getElementTypeFromVectorLike(semanticType); - Type lowpElementType = - getLowpVcvtIntrinsicElementType(elementType, context); - if (!lowpElementType) - return getLowpIntrinsicCarrierType(semanticType, convertedType, context); - - auto lanes = getElementCountFromVectorLike(semanticType); - if (!lanes) - return {}; - return VectorType::get({*lanes}, lowpElementType); -} - static Type getPayloadABIType(Type semanticType, Type convertedType, MLIRContext *context) { if (Type carrierType = getLowpPayloadCarrierType(semanticType, context)) @@ -606,30 +579,6 @@ static Type getPackedLowpScalarMemoryType(Type semanticType, return IntegerType::get(context, 16); } -static Type getLowpIntrinsicCarrierType(Type semanticType, Type convertedType, - MLIRContext *context) { - if (auto vregType = dyn_cast(semanticType)) { - if (!isLowpPayloadABIElementType(vregType.getElementType())) - return convertedType; - int64_t elementCount = vregType.getElementCount(); - if (elementCount <= 0 || elementCount % 4 != 0) - return {}; - return VectorType::get({elementCount / 4}, IntegerType::get(context, 32)); - } - - if (Type packedScalarType = - getPackedLowpScalarMemoryType(semanticType, context)) - return packedScalarType; - return convertedType; -} - -static Value bitcastToType(Location loc, Value value, Type targetType, - ConversionPatternRewriter &rewriter) { - if (!targetType || targetType == value.getType()) - return value; - return rewriter.create(loc, targetType, value); -} - static Type getScalarAccessGEPElementType(Type semanticType, Builder &builder) { if (Type memoryType = @@ -7684,24 +7633,10 @@ class LowerVcvtOpPattern final : public OpConversionPattern { if (!resultType) return rewriter.notifyMatchFailure(op, "failed to convert vcvt result type"); - Type inputCallType = getVcvtIntrinsicValueType( - op.getInput().getType(), adaptor.getInput().getType(), - rewriter.getContext()); - if (!inputCallType) - return rewriter.notifyMatchFailure(op, - "unsupported vcvt input carrier type"); - Type resultCallType = getVcvtIntrinsicValueType( - op.getResult().getType(), resultType, rewriter.getContext()); - if (!resultCallType) - return rewriter.notifyMatchFailure(op, - "unsupported vcvt result carrier type"); - Value input = bitcastToType(op.getLoc(), adaptor.getInput(), inputCallType, - rewriter); - SmallVector callArgs; SmallVector argTypes; - callArgs.push_back(input); - argTypes.push_back(input.getType()); + callArgs.push_back(adaptor.getInput()); + argTypes.push_back(adaptor.getInput().getType()); callArgs.push_back(adaptor.getMask()); argTypes.push_back(adaptor.getMask().getType()); @@ -7749,15 +7684,12 @@ class LowerVcvtOpPattern final : public OpConversionPattern { argTypes.push_back(partValue.getType()); } - auto funcType = rewriter.getFunctionType(argTypes, TypeRange{resultCallType}); + auto funcType = rewriter.getFunctionType(argTypes, TypeRange{resultType}); auto call = rewriter.create( - op.getLoc(), StringRef((*contract).intrinsic), - TypeRange{resultCallType}, callArgs); + op.getLoc(), StringRef((*contract).intrinsic), TypeRange{resultType}, callArgs); state.plannedDecls.push_back( PlannedDecl{std::string((*contract).intrinsic), funcType}); - Value result = - bitcastToType(op.getLoc(), call.getResult(0), resultType, rewriter); - rewriter.replaceOp(op, result); + rewriter.replaceOp(op, call.getResults()); return success(); } @@ -9284,30 +9216,15 @@ class LowerConvertOpPattern final : public OpConversionPattern { Value saturation = getI32Constant( rewriter, op.getLoc(), static_cast(op.getSaturation())); - Type srcCallType = getLowpIntrinsicCarrierType( - op.getSrc().getType(), adaptor.getSrc().getType(), - rewriter.getContext()); - if (!srcCallType) - return rewriter.notifyMatchFailure( - op, "unsupported convert input carrier type"); - Type resultCallType = getLowpIntrinsicCarrierType( - op.getDst().getType(), resultType, rewriter.getContext()); - if (!resultCallType) - return rewriter.notifyMatchFailure( - op, "unsupported convert result carrier type"); - Value src = - bitcastToType(op.getLoc(), adaptor.getSrc(), srcCallType, rewriter); - auto funcType = rewriter.getFunctionType( - TypeRange{src.getType(), rewriter.getI32Type(), rewriter.getI32Type()}, - TypeRange{resultCallType}); + TypeRange{adaptor.getSrc().getType(), rewriter.getI32Type(), + rewriter.getI32Type()}, + TypeRange{resultType}); auto call = rewriter.create( - op.getLoc(), *calleeName, TypeRange{resultCallType}, - ValueRange{src, rounding, saturation}); + op.getLoc(), *calleeName, TypeRange{resultType}, + ValueRange{adaptor.getSrc(), rounding, saturation}); state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); - Value result = - bitcastToType(op.getLoc(), call.getResult(0), resultType, rewriter); - rewriter.replaceOp(op, result); + rewriter.replaceOp(op, call.getResults()); return success(); } @@ -9690,8 +9607,6 @@ static Value convertLdgCallResult(Location loc, Type valueType, if (pto::isPTOFloat8Type(valueType) || pto::isPTOHiFloat8Type(valueType)) { Value payload = rewriter.create(loc, rewriter.getI8Type(), callResult); - if (payload.getType() == convertedValueType) - return payload; return rewriter.create(loc, convertedValueType, payload); } return callResult; @@ -9826,10 +9741,8 @@ static Value convertStgValue(Location loc, Type valueType, Value value, } if (pto::isPTOFloat8Type(valueType) || pto::isPTOHiFloat8Type(valueType)) { - Value payload = value; - if (payload.getType() != rewriter.getI8Type()) - payload = - rewriter.create(loc, rewriter.getI8Type(), value); + Value payload = + rewriter.create(loc, rewriter.getI8Type(), value); return rewriter.create(loc, rewriter.getI32Type(), payload); } if (valueType.isBF16()) diff --git a/test/lit/vpto/low_precision_hivm_llvm_ir_abi.pto b/test/lit/vpto/low_precision_hivm_llvm_ir_abi.pto deleted file mode 100644 index 6aeca05841..0000000000 --- a/test/lit/vpto/low_precision_hivm_llvm_ir_abi.pto +++ /dev/null @@ -1,129 +0,0 @@ -// Copyright (c) 2026 Huawei Technologies Co., Ltd. -// This program is free software, you can redistribute it and/or modify it under the terms and conditions of -// CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You may not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. - -// RUN: mkdir -p %T/fake-cann/bin %T/fake-cann/tools/bisheng_compiler/bin -// RUN: touch %T/fake-cann/bin/bisheng %T/fake-cann/bin/cce-ld %T/fake-cann/bin/ld.lld %T/fake-cann/tools/bisheng_compiler/bin/bisheng -// RUN: chmod +x %T/fake-cann/bin/bisheng %T/fake-cann/bin/cce-ld %T/fake-cann/bin/ld.lld %T/fake-cann/tools/bisheng_compiler/bin/bisheng -// RUN: env ASCEND_HOME_PATH=%T/fake-cann PATH=%T/fake-cann/bin:%T/fake-cann/tools/bisheng_compiler/bin:%PATH% ptoas --pto-arch=a5 --pto-backend=vpto --vpto-emit-hivm-llvm %s -o %t.ll -// RUN: FileCheck %s --check-prefix=LLVMIR < %t.ll - -module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { - func.func @lowp_vector_kernel(%f32_seed: f32, - %f16_seed: f16, - %bf16_seed: bf16, - %ub_f8: !pto.ptr, - %ub_f8e5: !pto.ptr, - %ub_hif8: !pto.ptr, - %ub_f4: !pto.ptr, - %ub_f4e2: !pto.ptr, - %dst_f32: !pto.ptr, - %dst_bf16: !pto.ptr) attributes {pto.kernel} { - %c0 = arith.constant 0 : index - pto.vecscope { - %mask_b8 = pto.pset_b8 "PAT_ALL" : !pto.mask - %mask_b16 = pto.pset_b16 "PAT_ALL" : !pto.mask - %mask_b32 = pto.pset_b32 "PAT_ALL" : !pto.mask - - %src_f32 = pto.vdup %f32_seed, %mask_b32 : f32, !pto.mask -> !pto.vreg<64xf32> - %f8 = pto.vcvt %src_f32, %mask_b32 {rnd = "R", sat = "NOSAT", part = "P0"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8E4M3FN> - pto.vsts %f8, %ub_f8[%c0], %mask_b8 : !pto.vreg<256xf8E4M3FN>, !pto.ptr, !pto.mask - %f8_loaded = pto.vlds %ub_f8[%c0] : !pto.ptr -> !pto.vreg<256xf8E4M3FN> - %f8_back = pto.vcvt %f8_loaded, %mask_b8 {part = "P0"} : !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<64xf32> - pto.vsts %f8_back, %dst_f32[%c0], %mask_b32 : !pto.vreg<64xf32>, !pto.ptr, !pto.mask - - %f8e5 = pto.vcvt %src_f32, %mask_b32 {rnd = "R", sat = "NOSAT", part = "P0"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8E5M2> - pto.vsts %f8e5, %ub_f8e5[%c0], %mask_b8 : !pto.vreg<256xf8E5M2>, !pto.ptr, !pto.mask - - %hif8 = pto.vcvt %src_f32, %mask_b32 {rnd = "A", sat = "NOSAT", part = "P0"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256x!pto.hif8> - pto.vsts %hif8, %ub_hif8[%c0], %mask_b8 : !pto.vreg<256x!pto.hif8>, !pto.ptr, !pto.mask - %hif8_loaded = pto.vlds %ub_hif8[%c0] : !pto.ptr -> !pto.vreg<256x!pto.hif8> - %hif8_back = pto.vcvt %hif8_loaded, %mask_b8 {part = "P0"} : !pto.vreg<256x!pto.hif8>, !pto.mask -> !pto.vreg<64xf32> - pto.vsts %hif8_back, %dst_f32[%c0], %mask_b32 : !pto.vreg<64xf32>, !pto.ptr, !pto.mask - - %src_f16 = pto.vdup %f16_seed, %mask_b16 : f16, !pto.mask -> !pto.vreg<128xf16> - %hif8_from_f16 = pto.vcvt %src_f16, %mask_b16 {rnd = "A", sat = "NOSAT", part = "EVEN"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<256x!pto.hif8> - pto.vsts %hif8_from_f16, %ub_hif8[%c0], %mask_b8 : !pto.vreg<256x!pto.hif8>, !pto.ptr, !pto.mask - - %src_bf16 = pto.vdup %bf16_seed, %mask_b16 : bf16, !pto.mask -> !pto.vreg<128xbf16> - %f4 = pto.vcvt %src_bf16, %mask_b16 {rnd = "R", part = "P0"} : !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<256x!pto.f4E1M2x2> - pto.vsts %f4, %ub_f4[%c0], %mask_b8 : !pto.vreg<256x!pto.f4E1M2x2>, !pto.ptr, !pto.mask - %f4_loaded = pto.vlds %ub_f4[%c0] : !pto.ptr -> !pto.vreg<256x!pto.f4E1M2x2> - %f4_back = pto.vcvt %f4_loaded, %mask_b8 {part = "P0"} : !pto.vreg<256x!pto.f4E1M2x2>, !pto.mask -> !pto.vreg<128xbf16> - pto.vsts %f4_back, %dst_bf16[%c0], %mask_b16 : !pto.vreg<128xbf16>, !pto.ptr, !pto.mask - - %f4e2 = pto.vcvt %src_bf16, %mask_b16 {rnd = "R", part = "P0"} : !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<256x!pto.f4E2M1x2> - pto.vsts %f4e2, %ub_f4e2[%c0], %mask_b8 : !pto.vreg<256x!pto.f4E2M1x2>, !pto.ptr, !pto.mask - } - return - } - - func.func @lowp_simt_kernel(%gm: !pto.ptr) attributes {pto.aicore} { - %c0_i64 = arith.constant 0 : i64 - %c128_i64 = arith.constant 128 : i64 - %c1_i64 = arith.constant 1 : i64 - %dim = arith.constant 1 : i32 - %threads = arith.constant 32 : i32 - %ub_i32 = pto.castptr %c0_i64 : i64 -> !pto.ptr - %ub_i64 = pto.castptr %c0_i64 : i64 -> !pto.ptr - %ub_hif8 = pto.castptr %c0_i64 : i64 -> !pto.ptr - pto.store_vfsimt_info %dim, %dim, %threads : i32, i32, i32 - func.call @lowp_simt_body(%ub_i32, %ub_i64, %ub_hif8) : (!pto.ptr, !pto.ptr, !pto.ptr) -> () - pto.mte_ub_gm %ub_i32, %gm, %c128_i64 - nburst(%c1_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64 - return - } - - func.func @lowp_simt_body(%dst: !pto.ptr, %dst64: !pto.ptr, %hif8_dst: !pto.ptr) attributes {pto.simt_entry} { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %f2 = arith.constant dense<[1.000000e+00, 2.000000e+00]> : vector<2xf32> - %f8 = pto.convert %f2 round(r) nosat : vector<2xf32> -> vector<2xf8E4M3FN> - %hif8 = pto.convert %f2 round(a) nosat : vector<2xf32> -> !pto.hif8x2 - %f8_back = pto.convert %f8 round(r) nosat : vector<2xf8E4M3FN> -> vector<2xf32> - %hif8_back = pto.convert %hif8 round(a) nosat : !pto.hif8x2 -> vector<2xf32> - %f8_bits = llvm.bitcast %f8 : vector<2xf8E4M3FN> to i16 - %f8_back_bits = llvm.bitcast %f8_back : vector<2xf32> to i64 - %hif8_back_bits = llvm.bitcast %hif8_back : vector<2xf32> to i64 - %f8_bits_i32 = arith.extui %f8_bits : i16 to i32 - pto.store %f8_bits_i32, %dst[%c0] : !pto.ptr, i32 - pto.store %hif8, %hif8_dst[%c1] : !pto.ptr, !pto.hif8x2 - pto.store %f8_back_bits, %dst64[%c0] : !pto.ptr, i64 - pto.store %hif8_back_bits, %dst64[%c1] : !pto.ptr, i64 - return - } -} - -// LLVMIR-DAG: declare <256 x float8e4m3> @llvm.hivm.vcvtff.f322f8e4m3.x -// LLVMIR-DAG: declare void @llvm.hivm.vstsx1.v256f8e4m3(<256 x i8> -// LLVMIR-DAG: declare <256 x i8> @llvm.hivm.vldsx1.v256f8e4m3 -// LLVMIR-DAG: declare <64 x float> @llvm.hivm.vcvtff.f8e4m32f32.x(<256 x float8e4m3> -// LLVMIR-DAG: declare <256 x float8e5m2> @llvm.hivm.vcvtff.f322f8e5m2.x -// LLVMIR-DAG: declare <256 x hifloat8> @llvm.hivm.vcvtff.f322hif8.x -// LLVMIR-DAG: declare <256 x hifloat8> @llvm.hivm.vcvtff.f162hif8.x -// LLVMIR-DAG: declare void @llvm.hivm.vstsx1.v256hif8(<256 x i8> -// LLVMIR-DAG: declare <256 x float4e1m2x2> @llvm.hivm.vcvtff2.bf162f4e1m2x2.x -// LLVMIR-DAG: declare <256 x float4e2m1x2> @llvm.hivm.vcvtff2.bf162f4e2m1x2.x -// LLVMIR-DAG: declare <128 x bfloat> @llvm.hivm.vcvtff2.f4e1m2x22bf16.x(<256 x float4e1m2x2> -// LLVMIR-DAG: declare i16 @llvm.hivm.f32x2.to.f8e4m3x2 -// LLVMIR-DAG: declare i16 @llvm.hivm.f32x2.to.hif8x2 -// LLVMIR-DAG: declare <2 x float> @llvm.hivm.f8e4m3x2.to.f32x2(i16 -// LLVMIR-DAG: declare <2 x float> @llvm.hivm.hif8x2.to.f32x2(i16 -// LLVMIR: bitcast <256 x float8e4m3> {{%[0-9]+}} to <256 x i8> -// LLVMIR: bitcast <256 x i8> {{%[0-9]+}} to <256 x float8e4m3> -// LLVMIR: bitcast i16 {{%[0-9]+}} to <2 x i8> -// LLVMIR-NOT: declare <256 x f8e4m3> -// LLVMIR-NOT: declare <2 x f8e4m3> -// LLVMIR-NOT: declare <64 x i32> @llvm.hivm.vcvtff.f322f8e4m3.x -// LLVMIR-NOT: declare <64 x i32> @llvm.hivm.vcvtff.f322f8e5m2.x -// LLVMIR-NOT: declare <64 x i32> @llvm.hivm.vcvtff.f322hif8.x -// LLVMIR-NOT: declare <64 x i32> @llvm.hivm.vcvtff.f162hif8.x -// LLVMIR-NOT: declare <64 x i32> @llvm.hivm.vcvtff2.bf162f4e1m2x2.x -// LLVMIR-NOT: declare <64 x i32> @llvm.hivm.vcvtff2.bf162f4e2m1x2.x -// LLVMIR-NOT: declare <256 x i8> @llvm.hivm.vcvtff.f322f8e4m3.x -// LLVMIR-NOT: declare <2 x i8> @llvm.hivm.f32x2.to.f8e4m3x2 diff --git a/test/lit/vpto/simt_lowlevel_ldst_policy_vpto_llvm.pto b/test/lit/vpto/simt_lowlevel_ldst_policy_vpto_llvm.pto index 5e2c07e054..b2b01574fa 100644 --- a/test/lit/vpto/simt_lowlevel_ldst_policy_vpto_llvm.pto +++ b/test/lit/vpto/simt_lowlevel_ldst_policy_vpto_llvm.pto @@ -9,7 +9,7 @@ // RUN: ( mkdir -p %T && ptoas --pto-arch=a5 --pto-backend=vpto %s -o %t --mlir-print-ir-after=convert-func-to-llvm 2>&1 || true ) | FileCheck %s module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { - func.func @ldst_policy_kernel(%gm_i8: !pto.ptr, %gm_i16: !pto.ptr, %gm_i32: !pto.ptr, %gm_i64: !pto.ptr, %gm_f16: !pto.ptr, %gm_bf16: !pto.ptr, %gm_f32: !pto.ptr, %gm_f64: !pto.ptr, %gm_f8: !pto.ptr, %gm_hif8: !pto.ptr, %dst_i32: !pto.ptr, %dst_i64: !pto.ptr) attributes {pto.aicore} { + func.func @ldst_policy_kernel(%gm_i8: !pto.ptr, %gm_i16: !pto.ptr, %gm_i32: !pto.ptr, %gm_i64: !pto.ptr, %gm_f16: !pto.ptr, %gm_bf16: !pto.ptr, %gm_f32: !pto.ptr, %gm_f64: !pto.ptr, %dst_i32: !pto.ptr, %dst_i64: !pto.ptr) attributes {pto.aicore} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c2 = arith.constant 2 : index @@ -28,8 +28,6 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind -> bf16 %load_f32 = pto.ldg %gm_f32[%c0] : !pto.ptr -> f32 %load_f64 = pto.ldg %gm_f64[%c0] : !pto.ptr -> f64 - %load_f8 = pto.ldg %gm_f8[%c0] l1cache(cache) l2cache(nmfv) : !pto.ptr -> f8E4M3FN - %load_hif8 = pto.ldg %gm_hif8[%c0] l1cache(uncache) l2cache(nmfv) : !pto.ptr -> !pto.hif8 pto.stg %i8, %gm_i8[%c1] l1cache(cache) l2cache(nmfv) : !pto.ptr, i8 pto.stg %i16, %gm_i16[%c1] l1cache(cache) l2cache(nmfv) : !pto.ptr, i16 pto.stg %load_i8, %gm_i8[%c2] l1cache(uncache) l2cache(nmfv) : !pto.ptr, i8 @@ -44,8 +42,6 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind, f64 pto.stg %load_f16, %gm_f16[%c1] : !pto.ptr, f16 pto.stg %load_bf16, %gm_bf16[%c2] : !pto.ptr, bf16 - pto.stg %load_f8, %gm_f8[%c1] l1cache(cache) l2cache(nmfv) : !pto.ptr, f8E4M3FN - pto.stg %load_hif8, %gm_hif8[%c2] l1cache(uncache) l2cache(nmfv) : !pto.ptr, !pto.hif8 pto.store %load_cache, %dst_i32[%c0] : !pto.ptr, i32 pto.store %load_uncache, %dst_i64[%c0] : !pto.ptr, i64 @@ -62,12 +58,10 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind Date: Tue, 16 Jun 2026 20:40:29 +0800 Subject: [PATCH 30/51] Narrow VPTO LLVM21 upgrade changes --- .github/workflows/build_wheel.yml | 2 +- .github/workflows/build_wheel_mac.yml | 2 +- .github/workflows/ci.yml | 2 +- .github/workflows/ci_sim.yml | 23 +- lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp | 247 +++-------------- lib/PTO/Transforms/VPTOLLVMEmitter.cpp | 249 +++--------------- lib/PTO/Transforms/VPTOLLVMEmitterHelper.cpp | 2 +- .../st/smoke/testcase/tmrgsort/tmrgsort.pto | 2 +- .../a5/src/st/testcase/tmrgsort/tmrgsort.pto | 2 +- tools/ptoas/ObjectEmission.cpp | 117 +------- 10 files changed, 100 insertions(+), 548 deletions(-) diff --git a/.github/workflows/build_wheel.yml b/.github/workflows/build_wheel.yml index c8f572e749..c38015e781 100644 --- a/.github/workflows/build_wheel.yml +++ b/.github/workflows/build_wheel.yml @@ -24,7 +24,7 @@ permissions: contents: write env: - LLVM_REPO: https://github.com/TaoTao-real/llvm-project.git + LLVM_REPO: https://github.com/vpto-dev/llvm-project.git LLVM_REF: feature-vpto-llvm21 LLVM_CACHE_FLAVOR: llvm21-vpto-release-hardening-v1 diff --git a/.github/workflows/build_wheel_mac.yml b/.github/workflows/build_wheel_mac.yml index d54df8dfc3..37d9ad5468 100644 --- a/.github/workflows/build_wheel_mac.yml +++ b/.github/workflows/build_wheel_mac.yml @@ -23,7 +23,7 @@ permissions: contents: write env: - LLVM_REPO: https://github.com/TaoTao-real/llvm-project.git + LLVM_REPO: https://github.com/vpto-dev/llvm-project.git LLVM_REF: feature-vpto-llvm21 LLVM_CACHE_FLAVOR: llvm21-vpto-release-v1 diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 521ed88c29..12012cdb46 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -97,7 +97,7 @@ jobs: runs-on: ubuntu-22.04 env: PTOAS_CLANG_MAJOR: "15" - LLVM_REPO: https://github.com/TaoTao-real/llvm-project.git + LLVM_REPO: https://github.com/vpto-dev/llvm-project.git LLVM_REF: feature-vpto-llvm21 LLVM_BUILD_DIR: ${{ github.workspace }}/llvm-project/llvm/build-assert LLVM_DIR: ${{ github.workspace }}/llvm-project/llvm/build-assert diff --git a/.github/workflows/ci_sim.yml b/.github/workflows/ci_sim.yml index 4168a73208..7396c9a12d 100644 --- a/.github/workflows/ci_sim.yml +++ b/.github/workflows/ci_sim.yml @@ -107,7 +107,7 @@ jobs: needs.detect-vpto-sim-changes.outputs.should_run == 'true' }} env: - LLVM_REPO: https://github.com/TaoTao-real/llvm-project.git + LLVM_REPO: https://github.com/vpto-dev/llvm-project.git LLVM_REF: feature-vpto-llvm21 PTO_INSTALL_DIR: ${{ github.workspace }}/install VPTO_SIM_WORKSPACE: ${{ github.workspace }}/.work/vpto-sim-ci @@ -201,25 +201,6 @@ jobs: python3 -m pip install setuptools wheel 'pybind11<3' nanobind numpy ml-dtypes fi - if [[ -x /usr/bin/cc ]]; then - c_compiler=/usr/bin/cc - elif [[ -n "${CC:-}" && -x "${CC}" ]]; then - c_compiler="${CC}" - else - c_compiler="$(command -v cc)" - fi - if [[ -x /usr/bin/c++ ]]; then - cxx_compiler=/usr/bin/c++ - elif [[ -n "${CXX:-}" && -x "${CXX}" ]]; then - cxx_compiler="${CXX}" - else - cxx_compiler="$(command -v c++)" - fi - echo "PTOAS_CMAKE_C_COMPILER=${c_compiler}" >> "${GITHUB_ENV}" - echo "PTOAS_CMAKE_CXX_COMPILER=${cxx_compiler}" >> "${GITHUB_ENV}" - "${c_compiler}" --version | head -n 1 - "${cxx_compiler}" --version | head -n 1 - - name: Clean CI work dirs shell: bash run: | @@ -297,8 +278,6 @@ jobs: export CXX=g++ # LLVM_BUILD_DIR is the env var read by the build backend (_ptoas_build_backend.py). LLVM_BUILD_DIR="${LLVM_DIR}" \ - CMAKE_C_COMPILER="${PTOAS_CMAKE_C_COMPILER}" \ - CMAKE_CXX_COMPILER="${PTOAS_CMAKE_CXX_COMPILER}" \ PTO_INSTALL_DIR="${PTO_INSTALL_DIR}" \ python3 -m pip install . --no-build-isolation --no-deps --ignore-installed --prefix "${PTO_INSTALL_DIR}" diff --git a/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp b/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp index 1ed2263f99..8de5c59b84 100644 --- a/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp +++ b/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp @@ -68,15 +68,15 @@ static std::optional getElementCountFromVectorLike(Type type); static Type getLowPrecisionLLVMType(Type type, MLIRContext *context) { if (pto::isPTOHiFloat8Type(type)) - return Float8E4M3FNType::get(context); + return LLVM::LLVMHiFloat8Type::get(context); if (isa(type)) - return IntegerType::get(context, 8); + return LLVM::LLVMFloat4E1M2x2Type::get(context); if (isa(type)) - return IntegerType::get(context, 8); + return LLVM::LLVMFloat4E2M1x2Type::get(context); if (pto::isPTOFloat8E4M3LikeType(type)) - return Float8E4M3Type::get(context); + return LLVM::LLVMFloat8E4M3Type::get(context); if (pto::isPTOFloat8E5M2LikeType(type)) - return Float8E5M2Type::get(context); + return LLVM::LLVMFloat8E5M2Type::get(context); return {}; } @@ -86,22 +86,10 @@ static Type getLLVMCompatibleVectorType(ArrayRef shape, return VectorType::get(shape, elementType, scalableDims); } -static bool isLowpPayloadABIElementType(Type type) { - return pto::isPTOFloat8Type(type) || pto::isPTOHiFloat8Type(type) || - pto::isPTOFloat4PackedType(type); -} - -static Type getLowpPayloadABIElementType(Type elementType, - MLIRContext *context) { - if (!isLowpPayloadABIElementType(elementType)) - return {}; - return IntegerType::get(context, 8); -} - static Type normalizePayloadTypeForLLVMLowering(Type type, Builder &builder) { if (pto::isPTOHiFloat8x2Type(type)) return getLLVMCompatibleVectorType( - {2}, Float8E4M3FNType::get(builder.getContext())); + {2}, LLVM::LLVMHiFloat8Type::get(builder.getContext())); if (Type lowpType = getLowPrecisionLLVMType(type, builder.getContext())) return lowpType; @@ -129,6 +117,10 @@ static Type normalizeGEPElementTypeForLLVMLowering(Type type, return builder.getI16Type(); if (pto::isPTOLowPrecisionType(type)) return builder.getI8Type(); + if (isa(type)) + return builder.getI8Type(); if (auto vecType = dyn_cast(type)) { Type normalizedElement = @@ -145,12 +137,8 @@ static Type normalizeGEPElementTypeForLLVMLowering(Type type, static Type convertVPTOType(Type type, Builder &builder) { if (auto vecType = dyn_cast(type)) { - Type sourceElementType = vecType.getElementType(); - Type elementType = getLowpPayloadABIElementType(sourceElementType, - builder.getContext()); - if (!elementType) - elementType = normalizePayloadTypeForLLVMLowering(sourceElementType, - builder); + Type elementType = + normalizePayloadTypeForLLVMLowering(vecType.getElementType(), builder); return getLLVMCompatibleVectorType({vecType.getElementCount()}, elementType); } @@ -192,13 +180,7 @@ static unsigned getNaturalByteAlignment(Type type) { } static bool hasVPTOConvertibleType(Type type) { - if (isa(type)) - return true; - if (pto::isPTOLowPrecisionType(type)) - return true; - if (Type elementType = getElementTypeFromVectorLike(type)) - return pto::isPTOLowPrecisionType(elementType); - return false; + return isa(type); } static bool hasVPTOConvertibleType(TypeRange types) { @@ -554,7 +536,8 @@ static std::string getMemoryElementTypeFragment(Type type) { } static bool isLowpPayloadElementType(Type type) { - return isLowpPayloadABIElementType(type); + return pto::isPTOFloat8Type(type) || pto::isPTOHiFloat8Type(type) || + pto::isPTOFloat4PackedType(type); } struct LowpPayloadABI { @@ -564,10 +547,9 @@ struct LowpPayloadABI { static std::optional getLowpPayloadABI(Type elementType, MLIRContext *context) { - Type carrierElementType = getLowpPayloadABIElementType(elementType, context); - if (!carrierElementType) + if (!isLowpPayloadElementType(elementType)) return std::nullopt; - return LowpPayloadABI{carrierElementType, "u8"}; + return LowpPayloadABI{IntegerType::get(context, 8), "u8"}; } static Type getLowpPayloadCarrierType(Type vectorLikeType, @@ -610,55 +592,6 @@ static Value castFromPayloadABI( return rewriter.create(loc, convertedType, value); } -static Type getPackedLowpScalarMemoryType(Type semanticType, - MLIRContext *context) { - if (pto::isPTOHiFloat8x2Type(semanticType)) - return IntegerType::get(context, 16); - - auto vecType = dyn_cast(semanticType); - if (!vecType || vecType.getRank() != 1 || vecType.getDimSize(0) != 2 || - llvm::is_contained(vecType.getScalableDims(), true)) - return {}; - if (!isLowpPayloadABIElementType(vecType.getElementType())) - return {}; - return IntegerType::get(context, 16); -} - -static Type getScalarAccessGEPElementType(Type semanticType, - Builder &builder) { - if (Type memoryType = - getPackedLowpScalarMemoryType(semanticType, builder.getContext())) - return memoryType; - return normalizeGEPElementTypeForLLVMLowering(semanticType, builder); -} - -static Type getScalarAccessLoadStoreType(Type semanticType, - Type convertedType, - MLIRContext *context) { - if (Type memoryType = getPackedLowpScalarMemoryType(semanticType, context)) - return memoryType; - return convertedType; -} - -static Value castToScalarAccessMemoryType(Location loc, Value value, - Type semanticType, - ConversionPatternRewriter &rewriter) { - Type memoryType = - getPackedLowpScalarMemoryType(semanticType, rewriter.getContext()); - if (!memoryType || memoryType == value.getType()) - return value; - return rewriter.create(loc, memoryType, value); -} - -static Value castFromScalarAccessMemoryType( - Location loc, Value value, Type semanticType, Type convertedType, - ConversionPatternRewriter &rewriter) { - if (!getPackedLowpScalarMemoryType(semanticType, rewriter.getContext()) || - value.getType() == convertedType) - return value; - return rewriter.create(loc, convertedType, value); -} - static std::string getAtomicElementTypeFragment(Type type, Attribute signednessAttr) { if (auto vecType = dyn_cast(type)) { @@ -9489,8 +9422,6 @@ class ConvertPtoLoadScalarOp final if (!convertedValueType) return rewriter.notifyMatchFailure(op, "could not convert load_scalar result type"); - Type loadValueType = getScalarAccessLoadStoreType( - op.getValue().getType(), convertedValueType, rewriter.getContext()); Value offset = adaptor.getOffset(); if (offset.getType().isIndex()) @@ -9500,42 +9431,19 @@ class ConvertPtoLoadScalarOp final Value elemPtr = adaptor.getPtr(); if (!matchPattern(offset, m_Zero())) { elemPtr = rewriter.create(op.getLoc(), llvmPtrType, - getScalarAccessGEPElementType( - op.getValue().getType(), - rewriter), + normalizeGEPElementTypeForLLVMLowering( + convertedValueType, rewriter), adaptor.getPtr(), ValueRange{offset}); } - auto loaded = rewriter.create( - op.getLoc(), loadValueType, elemPtr, - getNaturalByteAlignment(loadValueType)); - Value result = castFromScalarAccessMemoryType( - op.getLoc(), loaded.getResult(), op.getValue().getType(), - convertedValueType, rewriter); - rewriter.replaceOp(op, result); + rewriter.replaceOpWithNewOp( + op, convertedValueType, elemPtr, + getNaturalByteAlignment(convertedValueType)); return success(); } }; -static FailureOr recoverConvertedValue(Value value, Type sourceType, - const TypeConverter &converter) { - Type convertedType = converter.convertType(sourceType); - if (!convertedType) - return failure(); - - for (unsigned depth = 0; depth < 4; ++depth) { - if (value.getType() == convertedType) - return value; - auto castOp = value.getDefiningOp(); - if (!castOp || castOp->getNumOperands() != 1 || - castOp->getNumResults() != 1) - break; - value = castOp.getOperand(0); - } - return failure(); -} - class ConvertPtoStoreScalarOp final : public OpConversionPattern { public: @@ -9556,22 +9464,14 @@ class ConvertPtoStoreScalarOp final Value elemPtr = adaptor.getPtr(); if (!matchPattern(offset, m_Zero())) { elemPtr = rewriter.create(op.getLoc(), llvmPtrType, - getScalarAccessGEPElementType( - op.getValue().getType(), + normalizeGEPElementTypeForLLVMLowering( + adaptor.getValue().getType(), rewriter), adaptor.getPtr(), ValueRange{offset}); } - FailureOr value = recoverConvertedValue( - adaptor.getValue(), op.getValue().getType(), *getTypeConverter()); - if (failed(value)) - return rewriter.notifyMatchFailure(op, "could not convert store value"); - - Value storedValue = castToScalarAccessMemoryType( - op.getLoc(), *value, op.getValue().getType(), rewriter); - rewriter.create( - op.getLoc(), storedValue, elemPtr, - getNaturalByteAlignment(storedValue.getType())); + rewriter.create(op.getLoc(), adaptor.getValue(), elemPtr, + getNaturalByteAlignment(adaptor.getValue().getType())); rewriter.eraseOp(op); return success(); } @@ -9594,8 +9494,6 @@ class ConvertPtoLoadOp final : public OpConversionPattern { getTypeConverter()->convertType(op.getValue().getType()); if (!convertedValueType) return rewriter.notifyMatchFailure(op, "could not convert load result type"); - Type loadValueType = getScalarAccessLoadStoreType( - op.getValue().getType(), convertedValueType, rewriter.getContext()); Value offset = adaptor.getOffset(); if (offset.getType().isIndex()) @@ -9605,20 +9503,14 @@ class ConvertPtoLoadOp final : public OpConversionPattern { Value elemPtr = adaptor.getPtr(); if (!matchPattern(offset, m_Zero())) { elemPtr = rewriter.create(op.getLoc(), llvmPtrType, - getScalarAccessGEPElementType( - op.getValue().getType(), - rewriter), + convertedValueType, adaptor.getPtr(), ValueRange{offset}); } - auto loaded = rewriter.create( - op.getLoc(), loadValueType, elemPtr, - getNaturalByteAlignment(loadValueType)); - Value result = castFromScalarAccessMemoryType( - op.getLoc(), loaded.getResult(), op.getValue().getType(), - convertedValueType, rewriter); - rewriter.replaceOp(op, result); + rewriter.replaceOpWithNewOp( + op, convertedValueType, elemPtr, + getNaturalByteAlignment(convertedValueType)); return success(); } }; @@ -9700,9 +9592,7 @@ class ConvertPtoLdgOp final : public OpConversionPattern { ValueRange{offset}); } - auto ptrTy = dyn_cast(op.getPtr().getType()); - if (!ptrTy) - return rewriter.notifyMatchFailure(op, "expected PTO pointer source type"); + auto ptrTy = cast(op.getPtr().getType()); FailureOr ptr = reinterpretPointerToAddrSpace( op, elemPtr, static_cast(ptrTy.getMemorySpace().getAddressSpace())); @@ -9763,22 +9653,13 @@ class ConvertPtoStoreOp final : public OpConversionPattern { Value elemPtr = adaptor.getPtr(); if (!matchPattern(offset, m_Zero())) { elemPtr = rewriter.create(op.getLoc(), llvmPtrType, - getScalarAccessGEPElementType( - op.getValue().getType(), - rewriter), + adaptor.getValue().getType(), adaptor.getPtr(), ValueRange{offset}); } - FailureOr value = recoverConvertedValue( - adaptor.getValue(), op.getValue().getType(), *getTypeConverter()); - if (failed(value)) - return rewriter.notifyMatchFailure(op, "could not convert store value"); - - Value storedValue = castToScalarAccessMemoryType( - op.getLoc(), *value, op.getValue().getType(), rewriter); rewriter.replaceOpWithNewOp( - op, storedValue, elemPtr, - getNaturalByteAlignment(storedValue.getType())); + op, adaptor.getValue(), elemPtr, + getNaturalByteAlignment(adaptor.getValue().getType())); return success(); } }; @@ -9836,9 +9717,7 @@ class ConvertPtoStgOp final : public OpConversionPattern { adaptor.getPtr(), ValueRange{offset}); } - auto ptrTy = dyn_cast(op.getPtr().getType()); - if (!ptrTy) - return rewriter.notifyMatchFailure(op, "expected PTO pointer source type"); + auto ptrTy = cast(op.getPtr().getType()); FailureOr ptr = reinterpretPointerToAddrSpace( op, elemPtr, static_cast(ptrTy.getMemorySpace().getAddressSpace())); @@ -9858,12 +9737,8 @@ class ConvertPtoStgOp final : public OpConversionPattern { : pto::StL2Cache::NMFV; Value modeValue = getI32Constant(rewriter, op.getLoc(), static_cast(mode)); - FailureOr convertedValue = recoverConvertedValue( - adaptor.getValue(), op.getValue().getType(), *getTypeConverter()); - if (failed(convertedValue)) - return rewriter.notifyMatchFailure(op, "could not convert stg value"); Value storedValue = convertStgValue(op.getLoc(), op.getValue().getType(), - *convertedValue, rewriter); + adaptor.getValue(), rewriter); auto funcType = rewriter.getFunctionType(TypeRange{ptr->getType(), storedValue.getType(), rewriter.getI32Type()}, @@ -9888,9 +9763,7 @@ class ConvertVPTOTypedCarrierOp final : public ConversionPattern { LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - if (isa(op)) + if (isa(op)) return failure(); if (!hasVPTOConvertibleType(op->getOperandTypes()) && !hasVPTOConvertibleType(op->getResultTypes())) @@ -10349,49 +10222,6 @@ static Type normalizeTypeForOfficialLLVMLowering(Type type, Builder &builder) { return type; } -static bool isI8VectorToLowpVectorMaterialization( - UnrealizedConversionCastOp castOp) { - if (castOp->getNumOperands() != 1 || castOp->getNumResults() != 1) - return false; - - auto sourceVec = dyn_cast(castOp.getOperand(0).getType()); - auto resultVec = dyn_cast(castOp.getResult(0).getType()); - if (!sourceVec || !resultVec || sourceVec.getShape() != resultVec.getShape() || - sourceVec.getScalableDims() != resultVec.getScalableDims()) - return false; - - auto sourceElement = dyn_cast(sourceVec.getElementType()); - return sourceElement && sourceElement.getWidth() == 8 && - pto::isPTOLowPrecisionType(resultVec.getElementType()); -} - -static void foldLowpVectorMaterializationCastsForLLVMExport(ModuleOp module) { - SmallVector casts; - module.walk([&](UnrealizedConversionCastOp castOp) { - if (isI8VectorToLowpVectorMaterialization(castOp)) - casts.push_back(castOp); - }); - - for (UnrealizedConversionCastOp castOp : casts) { - if (!castOp) - continue; - SmallVector users(castOp->getUsers()); - for (Operation *user : users) { - auto bitcastOp = dyn_cast(user); - if (!bitcastOp) - continue; - OpBuilder builder(bitcastOp); - Value replacement = builder.create( - bitcastOp.getLoc(), bitcastOp.getResult().getType(), - castOp.getOperand(0)); - bitcastOp.getResult().replaceAllUsesWith(replacement); - bitcastOp.erase(); - } - if (castOp->use_empty()) - castOp.erase(); - } -} - static void normalizeFuncSignaturesForOfficialLLVMLowering(ModuleOp module) { Builder builder(module.getContext()); @@ -10615,6 +10445,7 @@ static void applySimtEntryCallingConvention( const llvm::StringSet &simtEntryNames) { for (llvm::Function &function : llvmModule) { if (simtEntryNames.contains(function.getName())) { + function.setCallingConv(llvm::CallingConv::SimtEntry); function.addFnAttr(llvm::Attribute::NoInline); // Match Bisheng's C++ frontend shape for SIMT outlined bodies. The // exported wrapper owns the real kernel metadata, while the SIMT body is @@ -10636,6 +10467,7 @@ static void applySimtEntryCallingConvention( auto *callee = call->getCalledFunction(); if (!callee || !simtEntryNames.contains(callee->getName())) continue; + call->setCallingConv(llvm::CallingConv::SimtEntry); } } } @@ -10709,7 +10541,6 @@ static LogicalResult runPipeline(ModuleOp module, llvm::raw_ostream &diagOS, diagOS << "VPTO LLVM emission failed: official lowering pipeline failed\n"; return failure(); } - foldLowpVectorMaterializationCastsForLLVMExport(clonedModule); return emit(clonedModule); } diff --git a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp index fbd0298c6a..75ed0cd48e 100644 --- a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp +++ b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp @@ -69,15 +69,15 @@ static std::optional getElementCountFromVectorLike(Type type); static Type getLowPrecisionLLVMType(Type type, MLIRContext *context) { if (pto::isPTOHiFloat8Type(type)) - return Float8E4M3FNType::get(context); + return LLVM::LLVMHiFloat8Type::get(context); if (isa(type)) - return IntegerType::get(context, 8); + return LLVM::LLVMFloat4E1M2x2Type::get(context); if (isa(type)) - return IntegerType::get(context, 8); + return LLVM::LLVMFloat4E2M1x2Type::get(context); if (pto::isPTOFloat8E4M3LikeType(type)) - return Float8E4M3Type::get(context); + return LLVM::LLVMFloat8E4M3Type::get(context); if (pto::isPTOFloat8E5M2LikeType(type)) - return Float8E5M2Type::get(context); + return LLVM::LLVMFloat8E5M2Type::get(context); return {}; } @@ -87,22 +87,10 @@ static Type getLLVMCompatibleVectorType(ArrayRef shape, return VectorType::get(shape, elementType, scalableDims); } -static bool isLowpPayloadABIElementType(Type type) { - return pto::isPTOFloat8Type(type) || pto::isPTOHiFloat8Type(type) || - pto::isPTOFloat4PackedType(type); -} - -static Type getLowpPayloadABIElementType(Type elementType, - MLIRContext *context) { - if (!isLowpPayloadABIElementType(elementType)) - return {}; - return IntegerType::get(context, 8); -} - static Type normalizePayloadTypeForLLVMLowering(Type type, Builder &builder) { if (pto::isPTOHiFloat8x2Type(type)) return getLLVMCompatibleVectorType( - {2}, Float8E4M3FNType::get(builder.getContext())); + {2}, LLVM::LLVMHiFloat8Type::get(builder.getContext())); if (Type lowpType = getLowPrecisionLLVMType(type, builder.getContext())) return lowpType; @@ -130,6 +118,10 @@ static Type normalizeGEPElementTypeForLLVMLowering(Type type, return builder.getI16Type(); if (pto::isPTOLowPrecisionType(type)) return builder.getI8Type(); + if (isa(type)) + return builder.getI8Type(); if (auto vecType = dyn_cast(type)) { Type normalizedElement = @@ -146,12 +138,8 @@ static Type normalizeGEPElementTypeForLLVMLowering(Type type, static Type convertVPTOType(Type type, Builder &builder) { if (auto vecType = dyn_cast(type)) { - Type sourceElementType = vecType.getElementType(); - Type elementType = getLowpPayloadABIElementType(sourceElementType, - builder.getContext()); - if (!elementType) - elementType = normalizePayloadTypeForLLVMLowering(sourceElementType, - builder); + Type elementType = + normalizePayloadTypeForLLVMLowering(vecType.getElementType(), builder); return getLLVMCompatibleVectorType({vecType.getElementCount()}, elementType); } @@ -193,13 +181,7 @@ static unsigned getNaturalByteAlignment(Type type) { } static bool hasVPTOConvertibleType(Type type) { - if (isa(type)) - return true; - if (pto::isPTOLowPrecisionType(type)) - return true; - if (Type elementType = getElementTypeFromVectorLike(type)) - return pto::isPTOLowPrecisionType(elementType); - return false; + return isa(type); } static bool hasVPTOConvertibleType(TypeRange types) { @@ -509,7 +491,8 @@ static std::string getMemoryElementTypeFragment(Type type) { } static bool isLowpPayloadElementType(Type type) { - return isLowpPayloadABIElementType(type); + return pto::isPTOFloat8Type(type) || pto::isPTOHiFloat8Type(type) || + pto::isPTOFloat4PackedType(type); } struct LowpPayloadABI { @@ -519,10 +502,9 @@ struct LowpPayloadABI { static std::optional getLowpPayloadABI(Type elementType, MLIRContext *context) { - Type carrierElementType = getLowpPayloadABIElementType(elementType, context); - if (!carrierElementType) + if (!isLowpPayloadElementType(elementType)) return std::nullopt; - return LowpPayloadABI{carrierElementType, "u8"}; + return LowpPayloadABI{IntegerType::get(context, 8), "u8"}; } static Type getLowpPayloadCarrierType(Type vectorLikeType, @@ -565,55 +547,6 @@ static Value castFromPayloadABI( return rewriter.create(loc, convertedType, value); } -static Type getPackedLowpScalarMemoryType(Type semanticType, - MLIRContext *context) { - if (pto::isPTOHiFloat8x2Type(semanticType)) - return IntegerType::get(context, 16); - - auto vecType = dyn_cast(semanticType); - if (!vecType || vecType.getRank() != 1 || vecType.getDimSize(0) != 2 || - llvm::is_contained(vecType.getScalableDims(), true)) - return {}; - if (!isLowpPayloadABIElementType(vecType.getElementType())) - return {}; - return IntegerType::get(context, 16); -} - -static Type getScalarAccessGEPElementType(Type semanticType, - Builder &builder) { - if (Type memoryType = - getPackedLowpScalarMemoryType(semanticType, builder.getContext())) - return memoryType; - return normalizeGEPElementTypeForLLVMLowering(semanticType, builder); -} - -static Type getScalarAccessLoadStoreType(Type semanticType, - Type convertedType, - MLIRContext *context) { - if (Type memoryType = getPackedLowpScalarMemoryType(semanticType, context)) - return memoryType; - return convertedType; -} - -static Value castToScalarAccessMemoryType(Location loc, Value value, - Type semanticType, - ConversionPatternRewriter &rewriter) { - Type memoryType = - getPackedLowpScalarMemoryType(semanticType, rewriter.getContext()); - if (!memoryType || memoryType == value.getType()) - return value; - return rewriter.create(loc, memoryType, value); -} - -static Value castFromScalarAccessMemoryType( - Location loc, Value value, Type semanticType, Type convertedType, - ConversionPatternRewriter &rewriter) { - if (!getPackedLowpScalarMemoryType(semanticType, rewriter.getContext()) || - value.getType() == convertedType) - return value; - return rewriter.create(loc, convertedType, value); -} - static std::string getAtomicElementTypeFragment(Type type, Attribute signednessAttr) { if (auto vecType = dyn_cast(type)) { @@ -9433,8 +9366,6 @@ class ConvertPtoLoadScalarOp final if (!convertedValueType) return rewriter.notifyMatchFailure(op, "could not convert load_scalar result type"); - Type loadValueType = getScalarAccessLoadStoreType( - op.getValue().getType(), convertedValueType, rewriter.getContext()); Value offset = adaptor.getOffset(); if (offset.getType().isIndex()) @@ -9444,42 +9375,19 @@ class ConvertPtoLoadScalarOp final Value elemPtr = adaptor.getPtr(); if (!matchPattern(offset, m_Zero())) { elemPtr = rewriter.create(op.getLoc(), llvmPtrType, - getScalarAccessGEPElementType( - op.getValue().getType(), - rewriter), + normalizeGEPElementTypeForLLVMLowering( + convertedValueType, rewriter), adaptor.getPtr(), ValueRange{offset}); } - auto loaded = rewriter.create( - op.getLoc(), loadValueType, elemPtr, - getNaturalByteAlignment(loadValueType)); - Value result = castFromScalarAccessMemoryType( - op.getLoc(), loaded.getResult(), op.getValue().getType(), - convertedValueType, rewriter); - rewriter.replaceOp(op, result); + rewriter.replaceOpWithNewOp( + op, convertedValueType, elemPtr, + getNaturalByteAlignment(convertedValueType)); return success(); } }; -static FailureOr recoverConvertedValue(Value value, Type sourceType, - const TypeConverter &converter) { - Type convertedType = converter.convertType(sourceType); - if (!convertedType) - return failure(); - - for (unsigned depth = 0; depth < 4; ++depth) { - if (value.getType() == convertedType) - return value; - auto castOp = value.getDefiningOp(); - if (!castOp || castOp->getNumOperands() != 1 || - castOp->getNumResults() != 1) - break; - value = castOp.getOperand(0); - } - return failure(); -} - class ConvertPtoStoreScalarOp final : public OpConversionPattern { public: @@ -9500,22 +9408,14 @@ class ConvertPtoStoreScalarOp final Value elemPtr = adaptor.getPtr(); if (!matchPattern(offset, m_Zero())) { elemPtr = rewriter.create(op.getLoc(), llvmPtrType, - getScalarAccessGEPElementType( - op.getValue().getType(), + normalizeGEPElementTypeForLLVMLowering( + adaptor.getValue().getType(), rewriter), adaptor.getPtr(), ValueRange{offset}); } - FailureOr value = recoverConvertedValue( - adaptor.getValue(), op.getValue().getType(), *getTypeConverter()); - if (failed(value)) - return rewriter.notifyMatchFailure(op, "could not convert store value"); - - Value storedValue = castToScalarAccessMemoryType( - op.getLoc(), *value, op.getValue().getType(), rewriter); - rewriter.create( - op.getLoc(), storedValue, elemPtr, - getNaturalByteAlignment(storedValue.getType())); + rewriter.create(op.getLoc(), adaptor.getValue(), elemPtr, + getNaturalByteAlignment(adaptor.getValue().getType())); rewriter.eraseOp(op); return success(); } @@ -9538,8 +9438,6 @@ class ConvertPtoLoadOp final : public OpConversionPattern { getTypeConverter()->convertType(op.getValue().getType()); if (!convertedValueType) return rewriter.notifyMatchFailure(op, "could not convert load result type"); - Type loadValueType = getScalarAccessLoadStoreType( - op.getValue().getType(), convertedValueType, rewriter.getContext()); Value offset = adaptor.getOffset(); if (offset.getType().isIndex()) @@ -9549,20 +9447,14 @@ class ConvertPtoLoadOp final : public OpConversionPattern { Value elemPtr = adaptor.getPtr(); if (!matchPattern(offset, m_Zero())) { elemPtr = rewriter.create(op.getLoc(), llvmPtrType, - getScalarAccessGEPElementType( - op.getValue().getType(), - rewriter), + convertedValueType, adaptor.getPtr(), ValueRange{offset}); } - auto loaded = rewriter.create( - op.getLoc(), loadValueType, elemPtr, - getNaturalByteAlignment(loadValueType)); - Value result = castFromScalarAccessMemoryType( - op.getLoc(), loaded.getResult(), op.getValue().getType(), - convertedValueType, rewriter); - rewriter.replaceOp(op, result); + rewriter.replaceOpWithNewOp( + op, convertedValueType, elemPtr, + getNaturalByteAlignment(convertedValueType)); return success(); } @@ -9645,9 +9537,7 @@ class ConvertPtoLdgOp final : public OpConversionPattern { ValueRange{offset}); } - auto ptrTy = dyn_cast(op.getPtr().getType()); - if (!ptrTy) - return rewriter.notifyMatchFailure(op, "expected PTO pointer source type"); + auto ptrTy = cast(op.getPtr().getType()); FailureOr ptr = reinterpretPointerToAddrSpace( op, elemPtr, static_cast(ptrTy.getMemorySpace().getAddressSpace())); @@ -9708,22 +9598,13 @@ class ConvertPtoStoreOp final : public OpConversionPattern { Value elemPtr = adaptor.getPtr(); if (!matchPattern(offset, m_Zero())) { elemPtr = rewriter.create(op.getLoc(), llvmPtrType, - getScalarAccessGEPElementType( - op.getValue().getType(), - rewriter), + adaptor.getValue().getType(), adaptor.getPtr(), ValueRange{offset}); } - FailureOr value = recoverConvertedValue( - adaptor.getValue(), op.getValue().getType(), *getTypeConverter()); - if (failed(value)) - return rewriter.notifyMatchFailure(op, "could not convert store value"); - - Value storedValue = castToScalarAccessMemoryType( - op.getLoc(), *value, op.getValue().getType(), rewriter); rewriter.replaceOpWithNewOp( - op, storedValue, elemPtr, - getNaturalByteAlignment(storedValue.getType())); + op, adaptor.getValue(), elemPtr, + getNaturalByteAlignment(adaptor.getValue().getType())); return success(); } @@ -9777,14 +9658,12 @@ class ConvertPtoStgOp final : public OpConversionPattern { if (!matchPattern(offset, m_Zero())) { elemPtr = rewriter.create(op.getLoc(), llvmPtrType, normalizeGEPElementTypeForLLVMLowering( - op.getValue().getType(), + adaptor.getValue().getType(), rewriter), adaptor.getPtr(), ValueRange{offset}); } - auto ptrTy = dyn_cast(op.getPtr().getType()); - if (!ptrTy) - return rewriter.notifyMatchFailure(op, "expected PTO pointer source type"); + auto ptrTy = cast(op.getPtr().getType()); FailureOr ptr = reinterpretPointerToAddrSpace( op, elemPtr, static_cast(ptrTy.getMemorySpace().getAddressSpace())); @@ -9804,12 +9683,8 @@ class ConvertPtoStgOp final : public OpConversionPattern { : pto::StL2Cache::NMFV; Value modeValue = getI32Constant(rewriter, op.getLoc(), static_cast(mode)); - FailureOr convertedValue = recoverConvertedValue( - adaptor.getValue(), op.getValue().getType(), *getTypeConverter()); - if (failed(convertedValue)) - return rewriter.notifyMatchFailure(op, "could not convert stg value"); Value storedValue = convertStgValue(op.getLoc(), op.getValue().getType(), - *convertedValue, rewriter); + adaptor.getValue(), rewriter); auto funcType = rewriter.getFunctionType(TypeRange{ptr->getType(), storedValue.getType(), rewriter.getI32Type()}, @@ -9834,9 +9709,7 @@ class ConvertVPTOTypedCarrierOp final : public ConversionPattern { LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - if (isa(op)) + if (isa(op)) return failure(); if (!hasVPTOConvertibleType(op->getOperandTypes()) && !hasVPTOConvertibleType(op->getResultTypes())) @@ -10295,49 +10168,6 @@ static Type normalizeTypeForOfficialLLVMLowering(Type type, Builder &builder) { return type; } -static bool isI8VectorToLowpVectorMaterialization( - UnrealizedConversionCastOp castOp) { - if (castOp->getNumOperands() != 1 || castOp->getNumResults() != 1) - return false; - - auto sourceVec = dyn_cast(castOp.getOperand(0).getType()); - auto resultVec = dyn_cast(castOp.getResult(0).getType()); - if (!sourceVec || !resultVec || sourceVec.getShape() != resultVec.getShape() || - sourceVec.getScalableDims() != resultVec.getScalableDims()) - return false; - - auto sourceElement = dyn_cast(sourceVec.getElementType()); - return sourceElement && sourceElement.getWidth() == 8 && - pto::isPTOLowPrecisionType(resultVec.getElementType()); -} - -static void foldLowpVectorMaterializationCastsForLLVMExport(ModuleOp module) { - SmallVector casts; - module.walk([&](UnrealizedConversionCastOp castOp) { - if (isI8VectorToLowpVectorMaterialization(castOp)) - casts.push_back(castOp); - }); - - for (UnrealizedConversionCastOp castOp : casts) { - if (!castOp) - continue; - SmallVector users(castOp->getUsers()); - for (Operation *user : users) { - auto bitcastOp = dyn_cast(user); - if (!bitcastOp) - continue; - OpBuilder builder(bitcastOp); - Value replacement = builder.create( - bitcastOp.getLoc(), bitcastOp.getResult().getType(), - castOp.getOperand(0)); - bitcastOp.getResult().replaceAllUsesWith(replacement); - bitcastOp.erase(); - } - if (castOp->use_empty()) - castOp.erase(); - } -} - static void normalizeFuncSignaturesForOfficialLLVMLowering(ModuleOp module) { Builder builder(module.getContext()); @@ -10576,6 +10406,7 @@ static void applySimtEntryCallingConvention( const llvm::StringSet &simtEntryNames) { for (llvm::Function &function : llvmModule) { if (simtEntryNames.contains(function.getName())) { + function.setCallingConv(llvm::CallingConv::SimtEntry); function.addFnAttr(llvm::Attribute::NoInline); // Match Bisheng's C++ frontend shape for SIMT outlined bodies. The // exported wrapper owns the real kernel metadata, while the SIMT body is @@ -10597,6 +10428,7 @@ static void applySimtEntryCallingConvention( auto *callee = call->getCalledFunction(); if (!callee || !simtEntryNames.contains(callee->getName())) continue; + call->setCallingConv(llvm::CallingConv::SimtEntry); } } } @@ -10670,7 +10502,6 @@ static LogicalResult runPipeline(ModuleOp module, llvm::raw_ostream &diagOS, diagOS << "VPTO LLVM emission failed: official lowering pipeline failed\n"; return failure(); } - foldLowpVectorMaterializationCastsForLLVMExport(clonedModule); return emit(clonedModule); } diff --git a/lib/PTO/Transforms/VPTOLLVMEmitterHelper.cpp b/lib/PTO/Transforms/VPTOLLVMEmitterHelper.cpp index 402c4272eb..73919b9ae8 100644 --- a/lib/PTO/Transforms/VPTOLLVMEmitterHelper.cpp +++ b/lib/PTO/Transforms/VPTOLLVMEmitterHelper.cpp @@ -610,7 +610,7 @@ void attachHIVMKernelAnnotations(llvm::Module &llvmModule, for (llvm::Function &function : llvmModule) { if (function.isDeclaration()) continue; - if (simtConfigByName.contains(function.getName())) { + if (function.getCallingConv() == llvm::CallingConv::SimtEntry) { uint32_t maxThreads = kDefaultSimtMaxThreads; uint32_t maxRegisters = kDefaultSimtMaxRegisters; if (auto it = simtConfigByName.find(function.getName()); diff --git a/test/tilelang_st/npu/a5/src/st/smoke/testcase/tmrgsort/tmrgsort.pto b/test/tilelang_st/npu/a5/src/st/smoke/testcase/tmrgsort/tmrgsort.pto index cf27d05033..ed90162b9c 100644 --- a/test/tilelang_st/npu/a5/src/st/smoke/testcase/tmrgsort/tmrgsort.pto +++ b/test/tilelang_st/npu/a5/src/st/smoke/testcase/tmrgsort/tmrgsort.pto @@ -9,7 +9,7 @@ // TileLang ST kernels for pto.tmrgsort Format1: single list internal block sorting. // Input is divided into 4 blocks, each block sorted, then merged. // Output: interleaved (sorted_value, original_index) pairs. -// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --vpto-emit-hivm-llvm +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand // to produce LLVM IR. module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmrgsort/tmrgsort.pto b/test/tilelang_st/npu/a5/src/st/testcase/tmrgsort/tmrgsort.pto index 76eeaaf182..b9bb701bff 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tmrgsort/tmrgsort.pto +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmrgsort/tmrgsort.pto @@ -9,7 +9,7 @@ // TileLang ST kernels for pto.tmrgsort Format1: single list internal block sorting. // Input is divided into 4 blocks, each block sorted, then merged. // Output: interleaved (sorted_value, original_index) pairs. -// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --vpto-emit-hivm-llvm +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand // to produce LLVM IR. module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { diff --git a/tools/ptoas/ObjectEmission.cpp b/tools/ptoas/ObjectEmission.cpp index 41230335e5..09f7604aa5 100644 --- a/tools/ptoas/ObjectEmission.cpp +++ b/tools/ptoas/ObjectEmission.cpp @@ -63,7 +63,7 @@ static bool writeTextFile(StringRef path, StringRef content, static void stripUnsupportedBishengAttrs(llvm::Module &module) { for (llvm::Function &function : module) { - // LLVM prints memory effect attributes in textual form like + // LLVM 19 prints memory effect attributes in textual form like // `memory(none)`. beta.1 Bisheng cannot parse that syntax, so remove only // the unsupported memory-effect attribute before serializing the module. function.setAttributes( @@ -72,113 +72,24 @@ static void stripUnsupportedBishengAttrs(llvm::Module &module) { } } -static std::optional findVectorTypeStart(StringRef text, - size_t typeEnd) { - unsigned depth = 0; - for (size_t index = typeEnd; index > 0; --index) { - char c = text[index - 1]; - if (c == '\n' || c == '\r') - return std::nullopt; - if (c == '>') - ++depth; - else if (c == '<') { - if (depth == 0) - return std::nullopt; - --depth; - if (depth == 0) - return index - 1; - } - } - return std::nullopt; -} - -static std::optional getFixedVectorElementCount(StringRef vectorType) { - vectorType = vectorType.trim(); - if (!vectorType.consume_front("<") || !vectorType.consume_back(">")) - return std::nullopt; - vectorType = vectorType.trim(); - if (vectorType.starts_with("vscale")) - return std::nullopt; - - auto [countText, elementType] = vectorType.split(" x "); - unsigned count = 0; - if (countText.empty() || elementType.empty() || - countText.getAsInteger(10, count) || count == 0) - return std::nullopt; - return count; -} - -static std::string expandFixedVectorSplatConstants(StringRef input) { - constexpr StringRef marker = "splat ("; - constexpr unsigned maxExpandedElements = 4096; - - std::string output; - size_t cursor = 0; - size_t searchFrom = 0; - - while (true) { - size_t splatPos = input.find(marker, searchFrom); - if (splatPos == StringRef::npos) - break; - - size_t typeEnd = splatPos; - while (typeEnd > 0 && - std::isspace(static_cast(input[typeEnd - 1]))) - --typeEnd; - if (typeEnd == 0 || input[typeEnd - 1] != '>') { - searchFrom = splatPos + marker.size(); - continue; - } - - std::optional typeStart = findVectorTypeStart(input, typeEnd); - if (!typeStart) { - searchFrom = splatPos + marker.size(); - continue; - } - - std::optional elementCount = - getFixedVectorElementCount(input.slice(*typeStart, typeEnd)); - if (!elementCount || *elementCount > maxExpandedElements) { - searchFrom = splatPos + marker.size(); - continue; - } - - size_t valueStart = splatPos + marker.size(); - size_t valueEnd = input.find(')', valueStart); - size_t lineEnd = input.find('\n', valueStart); - if (valueEnd == StringRef::npos || - (lineEnd != StringRef::npos && valueEnd > lineEnd)) { - searchFrom = splatPos + marker.size(); - continue; - } - - StringRef element = input.slice(valueStart, valueEnd); - output.append(input.data() + cursor, typeEnd - cursor); - output.append(" <"); - for (unsigned index = 0; index < *elementCount; ++index) { - if (index != 0) - output.append(", "); - output.append(element.data(), element.size()); - } - output.append(">"); - - cursor = valueEnd + 1; - searchFrom = cursor; - } - - output.append(input.data() + cursor, input.size() - cursor); - return output; -} - static bool writeLLVMModuleFile(llvm::Module &module, StringRef path, llvm::raw_ostream &diagOS) { + std::error_code ec; + llvm::raw_fd_ostream os(path, ec, llvm::sys::fs::OF_Text); + if (ec) { + diagOS << "Error: failed to open " << path << " for write: " + << ec.message() << "\n"; + return false; + } stripUnsupportedBishengAttrs(module); - std::string llvmIR; - llvm::raw_string_ostream os(llvmIR); module.print(os, nullptr); os.flush(); - llvmIR = expandFixedVectorSplatConstants(llvmIR); - return writeTextFile(path, llvmIR, diagOS); + if (os.has_error()) { + diagOS << "Error: failed to write LLVM module to " << path << "\n"; + os.clear_error(); + return false; + } + return true; } static std::string sanitizeModuleId(llvm::StringRef raw) { From 74ba6d20ecb5fce1f8f1e1b36313f9a879265961 Mon Sep 17 00:00:00 2001 From: TaoTao-real Date: Thu, 18 Jun 2026 12:52:58 +0800 Subject: [PATCH 31/51] chore: narrow LLVM21 dependency references --- .gitignore | 2 -- README.md | 6 +++--- README_en.md | 6 +++--- ReleaseNotes.md | 2 +- docker/Dockerfile | 2 +- docs/build_with_installed_llvm.md | 4 ++-- docs/designs/ci-board-validation-guide.md | 2 +- 7 files changed, 11 insertions(+), 13 deletions(-) diff --git a/.gitignore b/.gitignore index 4b46f5aa43..a2cfb9d562 100644 --- a/.gitignore +++ b/.gitignore @@ -1,10 +1,8 @@ # Build artifacts build/ -build-*/ build_plain/ build_plan/ install/ -install-*/ # TileLang ST standalone build outputs (see temp_docs/standalone_st.md) test/tilelang_st/npu/a5/src/st/build/ diff --git a/README.md b/README.md index 41d6e47386..3457d7eda6 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ ## 1. 项目简介 (Introduction) -**ptoas** (`ptoas`) 是一个基于 **LLVM/MLIR LLVM21 VPTO 分支 (`TaoTao-real/llvm-project:feature-vpto-llvm21`)** 框架构建的专用编译器工具链,专为 **PTO Bytecode** (Programming Tiling Operator Bytecode) 设计。 +**ptoas** (`ptoas`) 是一个基于 **LLVM/MLIR LLVM21 VPTO 分支 (`vpto-dev/llvm-project:feature-vpto-llvm21`)** 框架构建的专用编译器工具链,专为 **PTO Bytecode** (Programming Tiling Operator Bytecode) 设计。 作为连接上层 AI 框架与底层各类NPU/GPGPU/CPU硬件,`ptoas` 采用 **Out-of-Tree** 架构构建,提供了完整的 C++ 与 Python 接口,主要职责包括: @@ -37,7 +37,7 @@ PTOAS/ ## 3. 构建指南 (Build Instructions) -⚠️ **重要提示**:本项目严格依赖 **LLVM21 VPTO 分支 `TaoTao-real/llvm-project:feature-vpto-llvm21`**。 +⚠️ **重要提示**:本项目严格依赖 **LLVM21 VPTO 分支 `vpto-dev/llvm-project:feature-vpto-llvm21`**。 ### 3.0 环境变量配置 (Configuration) @@ -89,7 +89,7 @@ python3 -m pip install 'pybind11<3' nanobind numpy ```bash # 1. 下载 LLVM 源码 cd $WORKSPACE_DIR -git clone https://github.com/TaoTao-real/llvm-project.git +git clone https://github.com/vpto-dev/llvm-project.git cd $LLVM_SOURCE_DIR # 2. [关键] 切换到 VPTO 适配分支 diff --git a/README_en.md b/README_en.md index c1a75c6a77..2832349715 100644 --- a/README_en.md +++ b/README_en.md @@ -2,7 +2,7 @@ ## 1. Introduction -**ptoas** is a specialized compiler toolchain built on top of the **LLVM21 VPTO branch (`TaoTao-real/llvm-project:feature-vpto-llvm21`)**, designed specifically for **PTO Bytecode** (Programming Tiling Operator Bytecode). +**ptoas** is a specialized compiler toolchain built on top of the **LLVM21 VPTO branch (`vpto-dev/llvm-project:feature-vpto-llvm21`)**, designed specifically for **PTO Bytecode** (Programming Tiling Operator Bytecode). Acting as the bridge between upper-level AI frameworks and underlying NPU/GPGPU/CPU hardware, `ptoas` is built in an **Out-of-Tree** architecture and provides complete C++ and Python interfaces. Its primary responsibilities include: @@ -36,7 +36,7 @@ PTOAS/ ## 3. Build Instructions -⚠️ **Important**: This project strictly requires the **LLVM21 VPTO branch `TaoTao-real/llvm-project:feature-vpto-llvm21`**. +⚠️ **Important**: This project strictly requires the **LLVM21 VPTO branch `vpto-dev/llvm-project:feature-vpto-llvm21`**. ### 3.0 Environment Variable Configuration @@ -84,7 +84,7 @@ Download the VPTO-adapted LLVM source, check out the `feature-vpto-llvm21` branc ```bash # 1. Clone LLVM cd $WORKSPACE_DIR -git clone https://github.com/TaoTao-real/llvm-project.git +git clone https://github.com/vpto-dev/llvm-project.git cd $LLVM_SOURCE_DIR # 2. [Critical] Check out the VPTO adaptation branch diff --git a/ReleaseNotes.md b/ReleaseNotes.md index e584f1ff5e..3beaf43603 100644 --- a/ReleaseNotes.md +++ b/ReleaseNotes.md @@ -8,7 +8,7 @@ - PTOAS 首次发布 ## 概述 -PTOAS(PTO Assembler & Optimizer)是面向 PTO Bytecode 的编译器工具链,基于 LLVM/MLIR LLVM21 VPTO 分支 `TaoTao-real/llvm-project:feature-vpto-llvm21` 构建。它提供 PTO Dialect 的定义、解析、验证、优化与代码生成能力,并输出可调用 `pto-isa` 的 C++ 代码。 +PTOAS(PTO Assembler & Optimizer)是面向 PTO Bytecode 的编译器工具链,基于 LLVM/MLIR LLVM21 VPTO 分支 `vpto-dev/llvm-project:feature-vpto-llvm21` 构建。它提供 PTO Dialect 的定义、解析、验证、优化与代码生成能力,并输出可调用 `pto-isa` 的 C++ 代码。 PTOAS很快将集成到以下框架中,敬请期待 - PyPTO diff --git a/docker/Dockerfile b/docker/Dockerfile index 9789c57edc..4929817947 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -11,7 +11,7 @@ ARG ARCH # NOTE: change $PY_VER for different Python versions (3.8 - 3.14 available) ARG PY_VER=cp311-cp311 ARG LLVM_REF=feature-vpto-llvm21 -ARG LLVM_REPO=https://github.com/TaoTao-real/llvm-project.git +ARG LLVM_REPO=https://github.com/vpto-dev/llvm-project.git ## -- usually no need to change below -- diff --git a/docs/build_with_installed_llvm.md b/docs/build_with_installed_llvm.md index 8e3f14575d..7b09173b27 100644 --- a/docs/build_with_installed_llvm.md +++ b/docs/build_with_installed_llvm.md @@ -2,7 +2,7 @@ 本文档按 [README.md](../README.md) 第 3 章的逻辑整理,适用于: -- LLVM/MLIR LLVM21 VPTO 分支 `TaoTao-real/llvm-project:feature-vpto-llvm21` 已经构建并安装完成。 +- LLVM/MLIR LLVM21 VPTO 分支 `vpto-dev/llvm-project:feature-vpto-llvm21` 已经构建并安装完成。 - LLVM 安装路径固定为 `/opt/llvm`。 - `/opt/llvm` 是共享目录,不希望 `ptoas` 的安装步骤写入其中。 @@ -66,7 +66,7 @@ README 第 3.2 节是 LLVM/MLIR 的下载和编译步骤。当前场景下 LLVM 21.1.8 ``` -实际源码基线应来自 `https://github.com/TaoTao-real/llvm-project.git` 的 `feature-vpto-llvm21` 分支。 +实际源码基线应来自 `https://github.com/vpto-dev/llvm-project.git` 的 `feature-vpto-llvm21` 分支。 ## 3.3 第二步:构建 ptoas diff --git a/docs/designs/ci-board-validation-guide.md b/docs/designs/ci-board-validation-guide.md index 0678229827..bb963841d9 100644 --- a/docs/designs/ci-board-validation-guide.md +++ b/docs/designs/ci-board-validation-guide.md @@ -81,7 +81,7 @@ ### 3.1 构建 LLVM/MLIR ```bash -git clone https://github.com/TaoTao-real/llvm-project.git +git clone https://github.com/vpto-dev/llvm-project.git cd llvm-project git checkout feature-vpto-llvm21 From 98592be09aa6ba51edf00e531353d7773dee86d4 Mon Sep 17 00:00:00 2001 From: TaoTao-real Date: Thu, 18 Jun 2026 15:10:35 +0800 Subject: [PATCH 32/51] fix: restore entry contract and seam IR emission --- .../pto/materialize_tile_handles_control_flow_result.pto | 8 +++++++- tools/ptoas/ptoas.cpp | 4 ---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/test/lit/pto/materialize_tile_handles_control_flow_result.pto b/test/lit/pto/materialize_tile_handles_control_flow_result.pto index ee015bdba1..f6e965a5ba 100644 --- a/test/lit/pto/materialize_tile_handles_control_flow_result.pto +++ b/test/lit/pto/materialize_tile_handles_control_flow_result.pto @@ -1,4 +1,7 @@ -// RUN: ptoas --pto-arch=a3 --pto-print-seam-ir %s -o /dev/null 2>&1 | FileCheck %s --check-prefix=IR +// RUN: rm -f %t.seam +// RUN: ptoas --pto-arch=a3 --pto-seam-ir-file=%t.seam %s -o %t.cpp +// RUN: FileCheck %s --check-prefix=IR < %t.seam +// RUN: FileCheck %s --check-prefix=SEAM-EMITC < %t.cpp // RUN: ptoas --pto-arch=a3 %s -o - 2>&1 | FileCheck %s --check-prefix=EMITC module { @@ -63,6 +66,9 @@ module { // IR: } // IR: pto.tmul ins(%[[RES]], %[[RES]] +// SEAM-EMITC-LABEL: AICORE void control_flow_result_tile_materialization +// SEAM-EMITC: TADD( + // EMITC-LABEL: AICORE void control_flow_result_tile_materialization // EMITC: if ( // EMITC: TADD( diff --git a/tools/ptoas/ptoas.cpp b/tools/ptoas/ptoas.cpp index cd812d3d8c..369988d8aa 100644 --- a/tools/ptoas/ptoas.cpp +++ b/tools/ptoas/ptoas.cpp @@ -2157,10 +2157,6 @@ int mlir::pto::compilePTOASModule( printSharedPreBackendSeamIR(*module); if (failed(emitSharedPreBackendSeamIR(*module, ptoSeamIRFile))) return 1; - if (ptoPrintSeamIR || !ptoSeamIRFile.empty()) { - result.kind = PTOASCompileResultKind::Text; - return 0; - } PassManager emitcPM(module->getContext()); emitcPM.enableVerifier(); From 67a0462d448bacf9a4890fa05549898bb481758f Mon Sep 17 00:00:00 2001 From: TaoTao-real Date: Mon, 22 Jun 2026 15:29:25 +0800 Subject: [PATCH 33/51] fix: adapt latest main A5 checks for LLVM 21 --- lib/PTO/IR/PTO.cpp | 6 +----- test/lit/pto/tpush_tpop_globaltensor_frontend_a5.pto | 12 ++++++++---- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/lib/PTO/IR/PTO.cpp b/lib/PTO/IR/PTO.cpp index f314a9edad..50a05c16ee 100644 --- a/lib/PTO/IR/PTO.cpp +++ b/lib/PTO/IR/PTO.cpp @@ -4964,11 +4964,7 @@ static LogicalResult verifyMatmulTypeTriple(Operation *op, Type lhsElemTy, return success(); auto isA5TMatmulFp8Type = [](Type ty) { - if (auto ft = mlir::dyn_cast(ty)) - return ft.isFloat8E4M3() || ft.isFloat8E4M3FN() || - ft.isFloat8E4M3FNUZ() || ft.isFloat8E4M3B11FNUZ() || - ft.isFloat8E5M2() || ft.isFloat8E5M2FNUZ(); - return false; + return isPTOFloat8Type(ty); }; if (isA5 && dstElemTy.isF32()) { if (isA5TMatmulFp8Type(lhsElemTy) && isA5TMatmulFp8Type(rhsElemTy)) diff --git a/test/lit/pto/tpush_tpop_globaltensor_frontend_a5.pto b/test/lit/pto/tpush_tpop_globaltensor_frontend_a5.pto index 8562c12070..b4ec073938 100644 --- a/test/lit/pto/tpush_tpop_globaltensor_frontend_a5.pto +++ b/test/lit/pto/tpush_tpop_globaltensor_frontend_a5.pto @@ -84,23 +84,27 @@ module { // CHECK-LABEL: AICORE void cube_c2v_kernel(__gm__ float* // CHECK: TPipe<0, Direction::DIR_C2V_GM, 1024, 8, 8, true>({{.*}}, {{.*}}, {{.*}}); // CHECK: GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND> [[CUBE_C2V_ENTRY:v[0-9]+]](nullptr); -// CHECK: TALLOC, GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>({{.*}}, [[CUBE_C2V_ENTRY]]); -// CHECK: TPUSH, GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>({{.*}}, [[CUBE_C2V_ENTRY]]); +// CHECK: GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND> [[CUBE_C2V_ENTRY_VAL:v[0-9]+]] = [[CUBE_C2V_ENTRY]]; +// CHECK: TALLOC, GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>({{.*}}, [[CUBE_C2V_ENTRY_VAL]]); +// CHECK: TPUSH, GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>({{.*}}, [[CUBE_C2V_ENTRY_VAL]]); // CHECK-LABEL: AICORE void vector_c2v_kernel(__gm__ float* // CHECK: TPipe<0, Direction::DIR_C2V_GM, 1024, 8, 8, true>({{.*}}, {{.*}}, {{.*}}); // CHECK: GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND> [[VEC_C2V_ENTRY:v[0-9]+]](nullptr); +// CHECK: GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND> [[VEC_C2V_ENTRY_VAL:v[0-9]+]] = [[VEC_C2V_ENTRY]]; // CHECK: TPOP, GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>( // CHECK: TFREE, GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>( // CHECK-LABEL: AICORE void vector_v2c_kernel(__gm__ float* // CHECK: TPipe<0, Direction::DIR_V2C_GM, 1024, 8, 8, true>({{.*}}, {{.*}}, {{.*}}); // CHECK: GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND> [[VEC_V2C_ENTRY:v[0-9]+]](nullptr); -// CHECK: TALLOC, GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>({{.*}}, [[VEC_V2C_ENTRY]]); -// CHECK: TPUSH, GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>({{.*}}, [[VEC_V2C_ENTRY]]); +// CHECK: GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND> [[VEC_V2C_ENTRY_VAL:v[0-9]+]] = [[VEC_V2C_ENTRY]]; +// CHECK: TALLOC, GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>({{.*}}, [[VEC_V2C_ENTRY_VAL]]); +// CHECK: TPUSH, GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>({{.*}}, [[VEC_V2C_ENTRY_VAL]]); // CHECK-LABEL: AICORE void cube_v2c_kernel(__gm__ float* // CHECK: TPipe<0, Direction::DIR_V2C_GM, 1024, 8, 8, true>({{.*}}, {{.*}}, {{.*}}); // CHECK: GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND> [[CUBE_V2C_ENTRY:v[0-9]+]](nullptr); +// CHECK: GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND> [[CUBE_V2C_ENTRY_VAL:v[0-9]+]] = [[CUBE_V2C_ENTRY]]; // CHECK: TPOP, GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>( // CHECK: TFREE, GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>( From 9fe8b24318a19159ffbfc5af769756cf77623558 Mon Sep 17 00:00:00 2001 From: TaoTao-real Date: Mon, 22 Jun 2026 15:50:15 +0800 Subject: [PATCH 34/51] chore: keep python staging cmake unchanged --- lib/Bindings/Python/CMakeLists.txt | 30 +++++------------------------- 1 file changed, 5 insertions(+), 25 deletions(-) diff --git a/lib/Bindings/Python/CMakeLists.txt b/lib/Bindings/Python/CMakeLists.txt index 16ebd8d710..4e484668bd 100644 --- a/lib/Bindings/Python/CMakeLists.txt +++ b/lib/Bindings/Python/CMakeLists.txt @@ -87,38 +87,18 @@ add_dependencies(_pto PTOPythonGen) # ---- 3) Copy generated python + handwritten pto.py into build/python ---- set(PTO_PY_SRC "${CMAKE_SOURCE_DIR}/python/pto/dialects/pto.py") -set(PTO_PY_BUILD_DIR "${CMAKE_BINARY_DIR}/python/mlir/dialects") -set(PTO_PY_BUILD "${PTO_PY_BUILD_DIR}/pto.py") -set(PTO_OPS_PY_BUILD "${PTO_PY_BUILD_DIR}/_pto_ops_gen.py") - -add_custom_command( - OUTPUT - "${PTO_PY_BUILD}" - "${PTO_OPS_PY_BUILD}" - COMMAND ${CMAKE_COMMAND} -E make_directory "${PTO_PY_BUILD_DIR}" + +add_custom_command(TARGET _pto POST_BUILD + COMMAND ${CMAKE_COMMAND} -E make_directory "${CMAKE_BINARY_DIR}/python/mlir/dialects" COMMAND ${CMAKE_COMMAND} -E copy_if_different "${PTO_PY_SRC}" - "${PTO_PY_BUILD}" + "${CMAKE_BINARY_DIR}/python/mlir/dialects/pto.py" COMMAND ${CMAKE_COMMAND} -E copy_if_different "${CMAKE_CURRENT_BINARY_DIR}/_pto_ops_gen.py" - "${PTO_OPS_PY_BUILD}" - DEPENDS - "${PTO_PY_SRC}" - "${CMAKE_CURRENT_BINARY_DIR}/_pto_ops_gen.py" + "${CMAKE_BINARY_DIR}/python/mlir/dialects/_pto_ops_gen.py" VERBATIM ) -add_custom_target(PTOStagePythonModules ALL - DEPENDS - "${PTO_PY_BUILD}" - "${PTO_OPS_PY_BUILD}" -) -add_dependencies(PTOStagePythonModules PTOPythonGen) -add_dependencies(_pto PTOStagePythonModules) -if(TARGET PTOPythonModules) - add_dependencies(PTOPythonModules PTOStagePythonModules) -endif() - install(FILES "${PTO_PY_SRC}" "${CMAKE_CURRENT_BINARY_DIR}/_pto_ops_gen.py" From 70ae964dee201adffe51abe3d7b3cbe34a49875c Mon Sep 17 00:00:00 2001 From: TaoTao-real Date: Mon, 22 Jun 2026 17:04:03 +0800 Subject: [PATCH 35/51] fix: keep ptobc LLVM21 changes minimal --- tools/ptobc/src/mlir_encode.cpp | 30 ------------------- .../testdata/recent_mx_ops_v0_roundtrip.pto | 10 +++---- 2 files changed, 5 insertions(+), 35 deletions(-) diff --git a/tools/ptobc/src/mlir_encode.cpp b/tools/ptobc/src/mlir_encode.cpp index 87bcd5f0f0..89c19fe84e 100644 --- a/tools/ptobc/src/mlir_encode.cpp +++ b/tools/ptobc/src/mlir_encode.cpp @@ -121,35 +121,6 @@ static mlir::DictionaryAttr stripKnownImmediateAttrs( } } -static bool hasDefaultZeroPipeId(llvm::StringRef opName) { - return llvm::StringSwitch(opName) - .Case("pto.aic_initialize_pipe", true) - .Case("pto.aiv_initialize_pipe", true) - .Case("pto.talloc_to_aiv", true) - .Case("pto.talloc_to_aic", true) - .Case("pto.tpush_to_aiv", true) - .Case("pto.tpush_to_aic", true) - .Case("pto.tpop_from_aiv", true) - .Case("pto.tpop_from_aic", true) - .Case("pto.tfree_from_aiv", true) - .Case("pto.tfree_from_aic", true) - .Default(false); -} - -static mlir::DictionaryAttr stripDefaultZeroPipeId(mlir::MLIRContext *ctx, - mlir::DictionaryAttr dict, - mlir::Operation &op) { - if (!dict || dict.empty() || - !hasDefaultZeroPipeId(op.getName().getStringRef())) - return dict; - - auto idAttr = op.getAttrOfType("id"); - if (!idAttr || idAttr.getInt() != 0) - return dict; - - return stripAttrs(ctx, dict, {"id"}); -} - static uint64_t internAttr(PTOBCFile& f, mlir::DictionaryAttr dict) { if (!dict || dict.empty()) return 0; std::string s = printAttrDict(dict); @@ -589,7 +560,6 @@ void Encoder::encodeKnownOp(mlir::Operation &op, Buffer &out, out.appendU16LE(variantInfo.opcode); mlir::DictionaryAttr dict = op.getAttrDictionary(); dict = stripKnownImmediateAttrs(op.getContext(), dict, info); - dict = stripDefaultZeroPipeId(op.getContext(), dict, op); writeULEB128(internAttr(file, dict), out.bytes); if (info.has_variant_u8) diff --git a/tools/ptobc/testdata/recent_mx_ops_v0_roundtrip.pto b/tools/ptobc/testdata/recent_mx_ops_v0_roundtrip.pto index 51b4ec2653..b54687d0f4 100644 --- a/tools/ptobc/testdata/recent_mx_ops_v0_roundtrip.pto +++ b/tools/ptobc/testdata/recent_mx_ops_v0_roundtrip.pto @@ -9,18 +9,18 @@ module attributes {pto.target_arch = "a5"} { func.func @recent_mx_ops_v0() { %a = pto.alloc_tile : !pto.tile_buf - %a_scale = pto.alloc_tile : !pto.tile_buf + %a_scale = pto.alloc_tile : !pto.tile_buf %b = pto.alloc_tile : !pto.tile_buf - %b_scale = pto.alloc_tile : !pto.tile_buf + %b_scale = pto.alloc_tile : !pto.tile_buf %c_in = pto.alloc_tile : !pto.tile_buf %bias = pto.alloc_tile : !pto.tile_buf %dst = pto.alloc_tile : !pto.tile_buf %dst_acc = pto.alloc_tile : !pto.tile_buf %dst_bias = pto.alloc_tile : !pto.tile_buf - pto.tgemv.mx ins(%a, %a_scale, %b, %b_scale : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) - pto.tgemv.mx.acc ins(%c_in, %a, %a_scale, %b, %b_scale : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst_acc : !pto.tile_buf) - pto.tgemv.mx.bias ins(%a, %a_scale, %b, %b_scale, %bias : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst_bias : !pto.tile_buf) + pto.tgemv.mx ins(%a, %a_scale, %b, %b_scale : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) + pto.tgemv.mx.acc ins(%c_in, %a, %a_scale, %b, %b_scale : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst_acc : !pto.tile_buf) + pto.tgemv.mx.bias ins(%a, %a_scale, %b, %b_scale, %bias : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst_bias : !pto.tile_buf) return } } From 4276dd4ae56273a97514ecbc243fba62e6e62bd1 Mon Sep 17 00:00:00 2001 From: TaoTao-real Date: Wed, 24 Jun 2026 09:44:13 +0800 Subject: [PATCH 36/51] fix: adapt VPTO mad float8 checks for LLVM 21 --- lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp | 2 +- lib/PTO/Transforms/VPTOLLVMEmitter.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp b/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp index 8de5c59b84..0451203982 100644 --- a/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp +++ b/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp @@ -348,7 +348,7 @@ static bool isMadE4M3ElementType(Type type) { } static bool isMadE5M2ElementType(Type type) { - return type.isFloat8E5M2() || type.isFloat8E5M2FNUZ(); + return pto::isPTOFloat8E5M2LikeType(type); } static std::string getMadDstFragment(Type type) { diff --git a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp index 75ed0cd48e..3c2ff6267c 100644 --- a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp +++ b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp @@ -349,7 +349,7 @@ static bool isMadE4M3ElementType(Type type) { } static bool isMadE5M2ElementType(Type type) { - return type.isFloat8E5M2() || type.isFloat8E5M2FNUZ(); + return pto::isPTOFloat8E5M2LikeType(type); } static std::string getMadDstFragment(Type type) { From 1edbae57fccb8cecbb9355c143c8fdc1e85a38b0 Mon Sep 17 00:00:00 2001 From: TaoTao-real Date: Thu, 25 Jun 2026 14:51:54 +0800 Subject: [PATCH 37/51] fix: materialize tquant mx emitc pointer operands --- lib/PTO/Transforms/PTOToEmitC.cpp | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/lib/PTO/Transforms/PTOToEmitC.cpp b/lib/PTO/Transforms/PTOToEmitC.cpp index 1213fabacf..f73ae8733c 100644 --- a/lib/PTO/Transforms/PTOToEmitC.cpp +++ b/lib/PTO/Transforms/PTOToEmitC.cpp @@ -9968,8 +9968,20 @@ struct PTOQuantMxToEmitC : public OpConversionPattern { op, "expected all operands to be emitc::OpaqueType"); auto makePtr = [&](Value v, emitc::OpaqueType ot) -> Value { + if (auto lvalueTy = dyn_cast(v.getType())) + return rewriter.create( + loc, emitc::PointerType::get(lvalueTy.getValueType()), + "&", v) + .getResult(); + + Value tmp = rewriter + .create( + loc, getEmitCVariableResultType(ot), + emitc::OpaqueAttr::get(ctx, "")) + .getResult(); + rewriter.create(loc, tmp, v); return rewriter.create( - loc, emitc::PointerType::get(ot), "&", v) + loc, emitc::PointerType::get(ot), "&", tmp) .getResult(); }; From 8f844821e5d08bc3a0e28a34a50e6117e0916eb8 Mon Sep 17 00:00:00 2001 From: TaoTao-real Date: Fri, 26 Jun 2026 10:24:17 +0800 Subject: [PATCH 38/51] fix: address llvm21 rebase validation failures --- lib/PTO/IR/PTO.cpp | 9 +- lib/PTO/Transforms/PTOUnrollSIMTForPass.cpp | 6 +- python/pto/dialects/pto.py | 127 ++++++++++++++++++ .../emitc_tile_data_sink_after_tassign.pto | 18 +-- ...alize_tile_handles_control_flow_result.pto | 6 +- tools/ptoas/ptoas.cpp | 3 +- tools/ptobc/src/mlir_encode.cpp | 33 +++++ 7 files changed, 183 insertions(+), 19 deletions(-) diff --git a/lib/PTO/IR/PTO.cpp b/lib/PTO/IR/PTO.cpp index 50a05c16ee..a5f2cee98a 100644 --- a/lib/PTO/IR/PTO.cpp +++ b/lib/PTO/IR/PTO.cpp @@ -6615,7 +6615,11 @@ static bool isA5Fp8LikeType(Type ty) { } static bool isA5MxFp8InputType(Type ty) { - return isa(ty); + std::string text; + llvm::raw_string_ostream os(text); + ty.print(os); + os.flush(); + return text == "f8E4M3FN" || text == "f8E5M2"; } static bool isA5MxInputTypePair(Type lhsTy, Type rhsTy) { @@ -14025,8 +14029,7 @@ static void printFrontendInitializePipeOp(InitOpT op, OpAsmPrinter &p) { needsComma = true; }; - if (op.getId() != 0) - printClause("id", op.getId()); + printClause("id", op.getId()); printClause("dir_mask", static_cast(op.getDirMask())); printClause("slot_size", op.getSlotSize()); if (auto slotNumAttr = op.getSlotNumAttr()) diff --git a/lib/PTO/Transforms/PTOUnrollSIMTForPass.cpp b/lib/PTO/Transforms/PTOUnrollSIMTForPass.cpp index 3592b1dded..ee2d7f5c67 100644 --- a/lib/PTO/Transforms/PTOUnrollSIMTForPass.cpp +++ b/lib/PTO/Transforms/PTOUnrollSIMTForPass.cpp @@ -162,10 +162,10 @@ struct PTOUnrollSIMTFor : public pto::impl::PTOUnrollSIMTForBase(ctx, maxTripCount); GreedyRewriteConfig config; - config.maxIterations = 10; // loops may nest - config.strictMode = GreedyRewriteStrictness::ExistingOps; + config.setMaxIterations(10); // loops may nest + config.setStrictness(GreedyRewriteStrictness::ExistingOps); - if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns), config))) + if (failed(applyPatternsGreedily(func, std::move(patterns), config))) signalPassFailure(); } }; diff --git a/python/pto/dialects/pto.py b/python/pto/dialects/pto.py index 908c089d8a..0774adb426 100644 --- a/python/pto/dialects/pto.py +++ b/python/pto/dialects/pto.py @@ -986,6 +986,133 @@ def _matches_src_type(value): __all__.append("TScatter") +_GeneratedMteL1L0aMxOp = MteL1L0aMxOp +_GeneratedMteL1L0bMxOp = MteL1L0bMxOp +_MX_OFFSET_UNSET = object() + + +def _require_mx_offset(name, value): + if value is _MX_OFFSET_UNSET: + raise TypeError(f"missing required argument: {name}") + return value + + +class MteL1L0aMxOp(_GeneratedMteL1L0aMxOp): + """Compatibility wrapper for LLVM21 Python bindings missing MX offsets.""" + + def __init__( + self, + source, + destination, + m, + k, + start_row=_MX_OFFSET_UNSET, + start_col=_MX_OFFSET_UNSET, + *, + loc=None, + ip=None, + ): + operands = [ + _get_op_result_or_value(source), + _get_op_result_or_value(destination), + _get_op_result_or_value(m), + _get_op_result_or_value(k), + _get_op_result_or_value(_require_mx_offset("start_row", start_row)), + _get_op_result_or_value(_require_mx_offset("start_col", start_col)), + ] + op = self.build_generic( + attributes={}, + results=[], + operands=operands, + successors=None, + regions=None, + loc=loc, + ip=ip, + ) + _ods_ir.OpView.__init__(self, op) + + @property + def start_row(self): + return self.operation.operands[4] + + @property + def start_col(self): + return self.operation.operands[5] + + +def mte_l1_l0a_mx( + source, destination, m, k, start_row, start_col, *, loc=None, ip=None +): + return MteL1L0aMxOp( + source=source, + destination=destination, + m=m, + k=k, + start_row=start_row, + start_col=start_col, + loc=loc, + ip=ip, + ) + + +class MteL1L0bMxOp(_GeneratedMteL1L0bMxOp): + """Compatibility wrapper for LLVM21 Python bindings missing MX offsets.""" + + def __init__( + self, + source, + destination, + k, + n, + start_row=_MX_OFFSET_UNSET, + start_col=_MX_OFFSET_UNSET, + *, + loc=None, + ip=None, + ): + operands = [ + _get_op_result_or_value(source), + _get_op_result_or_value(destination), + _get_op_result_or_value(k), + _get_op_result_or_value(n), + _get_op_result_or_value(_require_mx_offset("start_row", start_row)), + _get_op_result_or_value(_require_mx_offset("start_col", start_col)), + ] + op = self.build_generic( + attributes={}, + results=[], + operands=operands, + successors=None, + regions=None, + loc=loc, + ip=ip, + ) + _ods_ir.OpView.__init__(self, op) + + @property + def start_row(self): + return self.operation.operands[4] + + @property + def start_col(self): + return self.operation.operands[5] + + +def mte_l1_l0b_mx( + source, destination, k, n, start_row, start_col, *, loc=None, ip=None +): + return MteL1L0bMxOp( + source=source, + destination=destination, + k=k, + n=n, + start_row=start_row, + start_col=start_col, + loc=loc, + ip=ip, + ) + + # ----------------------------------------------------------------------------- # Op aliases without "Op" suffix (user-facing) # ----------------------------------------------------------------------------- diff --git a/test/lit/pto/emitc_tile_data_sink_after_tassign.pto b/test/lit/pto/emitc_tile_data_sink_after_tassign.pto index 2c946df121..72c58c9ed9 100644 --- a/test/lit/pto/emitc_tile_data_sink_after_tassign.pto +++ b/test/lit/pto/emitc_tile_data_sink_after_tassign.pto @@ -25,11 +25,13 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind(v4); -// CHECK: TASSIGN(v3, v5); -// CHECK: __ubuf__ float* v6 = v3.data(); -// CHECK: sink_ptr(v6); +// CHECK: Tile [[ADDR_BASE:v[0-9]+]]; +// CHECK-NEXT: Tile [[ADDR_TILE:v[0-9]+]] = [[ADDR_BASE]]; +// CHECK-NEXT: TASSIGN([[ADDR_TILE]], v1); +// CHECK: Tile [[SINK_BASE:v[0-9]+]]; +// CHECK-NEXT: Tile [[SINK_TILE:v[0-9]+]] = [[SINK_BASE]]; +// CHECK: __ubuf__ float* [[SINK_PTR:v[0-9]+]] = [[SINK_TILE]].data(); +// CHECK: __ubuf__ float* [[ADDR_PTR:v[0-9]+]] = [[ADDR_TILE]].data(); +// CHECK: uint64_t [[ADDR_BITS:v[0-9]+]] = reinterpret_cast([[ADDR_PTR]]); +// CHECK: TASSIGN([[SINK_TILE]], [[ADDR_BITS]]); +// CHECK: sink_ptr([[SINK_PTR]]); diff --git a/test/lit/pto/materialize_tile_handles_control_flow_result.pto b/test/lit/pto/materialize_tile_handles_control_flow_result.pto index f6e965a5ba..f927d47f86 100644 --- a/test/lit/pto/materialize_tile_handles_control_flow_result.pto +++ b/test/lit/pto/materialize_tile_handles_control_flow_result.pto @@ -51,12 +51,12 @@ module { } } -// IR: %[[SEL:.*]] = scf.if {{.*}} -> (!pto.tile_buf) { +// IR-LABEL: func.func @control_flow_result_tile_materialization +// IR: %[[SEL:.*]] = arith.select %{{.*}}, %{{.*}}, %{{.*}} : !pto.tile_buf +// IR: scf.if {{.*}} { // IR: pto.tadd -// IR: scf.yield %{{.*}} : !pto.tile_buf // IR: } else { // IR: pto.tmul -// IR: scf.yield %{{.*}} : !pto.tile_buf // IR: } // IR: pto.tadd ins(%[[SEL]], %[[SEL]] // IR-LABEL: func.func @loop_carried_tile_materialization diff --git a/tools/ptoas/ptoas.cpp b/tools/ptoas/ptoas.cpp index 369988d8aa..7f54853916 100644 --- a/tools/ptoas/ptoas.cpp +++ b/tools/ptoas/ptoas.cpp @@ -1876,8 +1876,7 @@ int mlir::pto::compilePTOASModule( char **argv = context.getArgv(); if (effectiveBackend != PTOBackend::VPTO && - (emitVPTO || emitVPTOLLVMDialect || ptoPrintSeamIR || - !ptoSeamIRFile.empty())) { + (emitVPTO || emitVPTOLLVMDialect)) { llvm::errs() << "Error: VPTO-specific flags require " "--pto-backend=vpto or pto.backend = \"vpto\".\n"; return 1; diff --git a/tools/ptobc/src/mlir_encode.cpp b/tools/ptobc/src/mlir_encode.cpp index 89c19fe84e..a4cca9e356 100644 --- a/tools/ptobc/src/mlir_encode.cpp +++ b/tools/ptobc/src/mlir_encode.cpp @@ -121,6 +121,36 @@ static mlir::DictionaryAttr stripKnownImmediateAttrs( } } +static bool isFrontendPipeOpWithDefaultId(mlir::Operation &op) { + llvm::StringRef name = op.getName().getStringRef(); + return name == "pto.aic_initialize_pipe" || + name == "pto.aiv_initialize_pipe" || + name == "pto.talloc_to_aiv" || name == "pto.talloc_to_aic" || + name == "pto.tpush_to_aiv" || name == "pto.tpush_to_aic" || + name == "pto.tpop_from_aic" || name == "pto.tpop_from_aiv" || + name == "pto.tfree_from_aic" || name == "pto.tfree_from_aiv"; +} + +static mlir::DictionaryAttr +dropDefaultZeroPipeIdForV0Encoding(mlir::Operation &op, + mlir::DictionaryAttr dict) { + if (!dict || dict.empty() || !isFrontendPipeOpWithDefaultId(op)) + return dict; + + auto idAttr = dict.getAs("id"); + if (!idAttr || idAttr.getInt() != 0) + return dict; + + llvm::SmallVector keep; + keep.reserve(dict.size()); + for (auto attr : dict) { + if (attr.getName() == "id") + continue; + keep.push_back(attr); + } + return mlir::DictionaryAttr::get(op.getContext(), keep); +} + static uint64_t internAttr(PTOBCFile& f, mlir::DictionaryAttr dict) { if (!dict || dict.empty()) return 0; std::string s = printAttrDict(dict); @@ -560,6 +590,9 @@ void Encoder::encodeKnownOp(mlir::Operation &op, Buffer &out, out.appendU16LE(variantInfo.opcode); mlir::DictionaryAttr dict = op.getAttrDictionary(); dict = stripKnownImmediateAttrs(op.getContext(), dict, info); + // The assembly printer omits default pipe ids on transfer/free ops. Keep v0 + // byte roundtrips stable by using the same representation while encoding. + dict = dropDefaultZeroPipeIdForV0Encoding(op, dict); writeULEB128(internAttr(file, dict), out.bytes); if (info.has_variant_u8) From 5e69dbb919fc790db6932816d2fc8d4f0f7e596b Mon Sep 17 00:00:00 2001 From: TaoTao-real Date: Fri, 26 Jun 2026 10:37:08 +0800 Subject: [PATCH 39/51] fix: rely on generated fusion plan options ctor --- lib/PTO/Transforms/TileFusion/PTOFusionPlan.cpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/lib/PTO/Transforms/TileFusion/PTOFusionPlan.cpp b/lib/PTO/Transforms/TileFusion/PTOFusionPlan.cpp index b52f5cff50..cacd31565c 100644 --- a/lib/PTO/Transforms/TileFusion/PTOFusionPlan.cpp +++ b/lib/PTO/Transforms/TileFusion/PTOFusionPlan.cpp @@ -528,9 +528,6 @@ struct FusionPlanPass : public pto::impl::FusionPlanBase { using pto::impl::FusionPlanBase::FusionPlanBase; FusionPlanPass() = default; - FusionPlanPass(const pto::FusionPlanOptions &options) { - enableShapeInference = options.enableShapeInference; - } void runOnOperation() override { func::FuncOp func = getOperation(); From e5913eb01eeeb33e61f6422b1b6091161eab664d Mon Sep 17 00:00:00 2001 From: TaoTao-real Date: Fri, 26 Jun 2026 10:40:50 +0800 Subject: [PATCH 40/51] ci: ignore unrelated microsoft apt sources --- .github/workflows/ci.yml | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 12012cdb46..ce02ff3378 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -114,6 +114,15 @@ jobs: - name: Install dependencies run: | + # GitHub-hosted runners may include Microsoft apt sources unrelated to + # PTOAS; these can intermittently return 403 and break apt-get update. + sudo rm -f \ + /etc/apt/sources.list.d/azure-cli*.list \ + /etc/apt/sources.list.d/azure-cli*.sources \ + /etc/apt/sources.list.d/microsoft-prod*.list \ + /etc/apt/sources.list.d/microsoft-prod*.sources \ + /etc/apt/sources.list.d/packages-microsoft-prod*.list \ + /etc/apt/sources.list.d/packages-microsoft-prod*.sources sudo apt-get update sudo apt-get install -y \ cmake git ninja-build \ From 67bb0ba2f66f79b026ea7f81f5d419c1c71c430a Mon Sep 17 00:00:00 2001 From: TaoTao-real Date: Sat, 27 Jun 2026 16:48:21 +0800 Subject: [PATCH 41/51] fix: preserve simt callsite annotation check --- lib/PTO/Transforms/VPTOLLVMEmitterHelper.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/lib/PTO/Transforms/VPTOLLVMEmitterHelper.cpp b/lib/PTO/Transforms/VPTOLLVMEmitterHelper.cpp index 73919b9ae8..8b5eb7a165 100644 --- a/lib/PTO/Transforms/VPTOLLVMEmitterHelper.cpp +++ b/lib/PTO/Transforms/VPTOLLVMEmitterHelper.cpp @@ -569,14 +569,13 @@ void attachHIVMKernelAnnotations(llvm::Module &llvmModule, simtConfigByName[symName] = {maxThreads, maxRegisters}; }); - auto callsSimtEntry = [&](llvm::Function &function) { + auto callsSimtEntry = [](llvm::Function &function) { for (llvm::BasicBlock &block : function) { for (llvm::Instruction &inst : block) { auto *call = llvm::dyn_cast(&inst); if (!call) continue; - auto *callee = call->getCalledFunction(); - if (callee && simtConfigByName.contains(callee->getName())) + if (call->getCallingConv() == llvm::CallingConv::SimtEntry) return true; } } From 3eab87c3e5c63ab2fec01703a755688131010b10 Mon Sep 17 00:00:00 2001 From: TaoTao-real Date: Sat, 27 Jun 2026 17:34:34 +0800 Subject: [PATCH 42/51] fix: rely on generated MX python bindings --- python/pto/dialects/pto.py | 127 ------------------------------------- 1 file changed, 127 deletions(-) diff --git a/python/pto/dialects/pto.py b/python/pto/dialects/pto.py index 0774adb426..908c089d8a 100644 --- a/python/pto/dialects/pto.py +++ b/python/pto/dialects/pto.py @@ -986,133 +986,6 @@ def _matches_src_type(value): __all__.append("TScatter") -_GeneratedMteL1L0aMxOp = MteL1L0aMxOp -_GeneratedMteL1L0bMxOp = MteL1L0bMxOp -_MX_OFFSET_UNSET = object() - - -def _require_mx_offset(name, value): - if value is _MX_OFFSET_UNSET: - raise TypeError(f"missing required argument: {name}") - return value - - -class MteL1L0aMxOp(_GeneratedMteL1L0aMxOp): - """Compatibility wrapper for LLVM21 Python bindings missing MX offsets.""" - - def __init__( - self, - source, - destination, - m, - k, - start_row=_MX_OFFSET_UNSET, - start_col=_MX_OFFSET_UNSET, - *, - loc=None, - ip=None, - ): - operands = [ - _get_op_result_or_value(source), - _get_op_result_or_value(destination), - _get_op_result_or_value(m), - _get_op_result_or_value(k), - _get_op_result_or_value(_require_mx_offset("start_row", start_row)), - _get_op_result_or_value(_require_mx_offset("start_col", start_col)), - ] - op = self.build_generic( - attributes={}, - results=[], - operands=operands, - successors=None, - regions=None, - loc=loc, - ip=ip, - ) - _ods_ir.OpView.__init__(self, op) - - @property - def start_row(self): - return self.operation.operands[4] - - @property - def start_col(self): - return self.operation.operands[5] - - -def mte_l1_l0a_mx( - source, destination, m, k, start_row, start_col, *, loc=None, ip=None -): - return MteL1L0aMxOp( - source=source, - destination=destination, - m=m, - k=k, - start_row=start_row, - start_col=start_col, - loc=loc, - ip=ip, - ) - - -class MteL1L0bMxOp(_GeneratedMteL1L0bMxOp): - """Compatibility wrapper for LLVM21 Python bindings missing MX offsets.""" - - def __init__( - self, - source, - destination, - k, - n, - start_row=_MX_OFFSET_UNSET, - start_col=_MX_OFFSET_UNSET, - *, - loc=None, - ip=None, - ): - operands = [ - _get_op_result_or_value(source), - _get_op_result_or_value(destination), - _get_op_result_or_value(k), - _get_op_result_or_value(n), - _get_op_result_or_value(_require_mx_offset("start_row", start_row)), - _get_op_result_or_value(_require_mx_offset("start_col", start_col)), - ] - op = self.build_generic( - attributes={}, - results=[], - operands=operands, - successors=None, - regions=None, - loc=loc, - ip=ip, - ) - _ods_ir.OpView.__init__(self, op) - - @property - def start_row(self): - return self.operation.operands[4] - - @property - def start_col(self): - return self.operation.operands[5] - - -def mte_l1_l0b_mx( - source, destination, k, n, start_row, start_col, *, loc=None, ip=None -): - return MteL1L0bMxOp( - source=source, - destination=destination, - k=k, - n=n, - start_row=start_row, - start_col=start_col, - loc=loc, - ip=ip, - ) - - # ----------------------------------------------------------------------------- # Op aliases without "Op" suffix (user-facing) # ----------------------------------------------------------------------------- From 06ab30a1085958575371de18cb5646f936f5679f Mon Sep 17 00:00:00 2001 From: TaoTao-real Date: Sun, 28 Jun 2026 13:51:58 +0800 Subject: [PATCH 43/51] fix: support vpto memref stride API variants --- include/PTO/IR/PTOTypeUtils.h | 22 +++++++++++++++++++ lib/PTO/IR/PTO.cpp | 4 ++-- lib/PTO/Transforms/ExpandTileOp.cpp | 6 +++-- lib/PTO/Transforms/InferPTOLayout.cpp | 6 +++-- .../InsertSync/InsertSyncAnalysis.cpp | 3 ++- .../Transforms/InsertSync/PTOIRTranslator.cpp | 9 +++++--- .../Transforms/PTOMaterializeTileHandles.cpp | 7 +++--- lib/PTO/Transforms/PTOToEmitC.cpp | 13 ++++++----- lib/PTO/Transforms/PTOViewToMemref.cpp | 6 +++-- lib/PTO/Transforms/VPTOPtrNormalize.cpp | 4 +++- 10 files changed, 59 insertions(+), 21 deletions(-) diff --git a/include/PTO/IR/PTOTypeUtils.h b/include/PTO/IR/PTOTypeUtils.h index 45d0740487..f060c03912 100644 --- a/include/PTO/IR/PTOTypeUtils.h +++ b/include/PTO/IR/PTOTypeUtils.h @@ -9,11 +9,33 @@ #ifndef PTO_IR_PTOTYPEUTILS_H #define PTO_IR_PTOTYPEUTILS_H +#include "llvm/ADT/SmallVector.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Operation.h" +#include "mlir/Support/LLVM.h" namespace mlir::pto { +namespace detail { +template +inline auto getPTOMemRefStridesAndOffsetImpl( + MemRefT memTy, SmallVectorImpl &strides, int64_t &offset, int) + -> decltype(memTy.getStridesAndOffset(strides, offset)) { + return memTy.getStridesAndOffset(strides, offset); +} + +template +inline LogicalResult getPTOMemRefStridesAndOffsetImpl( + MemRefT memTy, SmallVectorImpl &strides, int64_t &offset, long) { + return getStridesAndOffset(memTy, strides, offset); +} +} // namespace detail + +inline LogicalResult getPTOMemRefStridesAndOffset( + MemRefType memTy, SmallVectorImpl &strides, int64_t &offset) { + return detail::getPTOMemRefStridesAndOffsetImpl(memTy, strides, offset, 0); +} + bool isPTOFloat8Type(Type t); bool isPTOFloat8E4M3LikeType(Type t); bool isPTOFloat8E5M2LikeType(Type t); diff --git a/lib/PTO/IR/PTO.cpp b/lib/PTO/IR/PTO.cpp index a5f2cee98a..94aeb52bab 100644 --- a/lib/PTO/IR/PTO.cpp +++ b/lib/PTO/IR/PTO.cpp @@ -3655,7 +3655,7 @@ static LogicalResult verifyAsyncFlatContiguous1DGMMemRef(Operation *op, SmallVector strides; int64_t offset = 0; - if (failed(memTy.getStridesAndOffset(strides, offset))) + if (failed(getPTOMemRefStridesAndOffset(memTy, strides, offset))) return op->emitOpError() << "expects " << name << " to be a strided memref with a known layout"; @@ -12769,7 +12769,7 @@ mlir::LogicalResult mlir::pto::SimdTileToMemrefOp::verify() { SmallVector memStrides; int64_t memOffset = ShapedType::kDynamic; - if (failed(memTy.getStridesAndOffset(memStrides, memOffset))) + if (failed(getPTOMemRefStridesAndOffset(memTy, memStrides, memOffset))) return emitOpError("expects memref to use strided layout"); if (memOffset != 0) return emitOpError("expects memref offset to be 0"); diff --git a/lib/PTO/Transforms/ExpandTileOp.cpp b/lib/PTO/Transforms/ExpandTileOp.cpp index fabca3d29c..28765794e1 100644 --- a/lib/PTO/Transforms/ExpandTileOp.cpp +++ b/lib/PTO/Transforms/ExpandTileOp.cpp @@ -598,7 +598,8 @@ static void populateViewShapeAndStrides(Value value, shape.assign(memrefTy.getShape().begin(), memrefTy.getShape().end()); if (strides.empty()) { int64_t offset = ShapedType::kDynamic; - if (succeeded(memrefTy.getStridesAndOffset(strides, offset))) { + if (succeeded( + mlir::pto::getPTOMemRefStridesAndOffset(memrefTy, strides, offset))) { // strides populated — dynamic dims remain ShapedType::kDynamic. } } @@ -646,7 +647,8 @@ static std::optional buildOperandTypeInfo(Value value) { info.viewShape.assign(mrTy.getShape().begin(), mrTy.getShape().end()); if (info.viewStrides.empty()) { int64_t offset = ShapedType::kDynamic; - if (succeeded(mrTy.getStridesAndOffset(info.viewStrides, offset))) { + if (succeeded(mlir::pto::getPTOMemRefStridesAndOffset( + mrTy, info.viewStrides, offset))) { // strides populated — dynamic dims remain ShapedType::kDynamic. } } diff --git a/lib/PTO/Transforms/InferPTOLayout.cpp b/lib/PTO/Transforms/InferPTOLayout.cpp index b25e37d7a3..572995a870 100644 --- a/lib/PTO/Transforms/InferPTOLayout.cpp +++ b/lib/PTO/Transforms/InferPTOLayout.cpp @@ -288,7 +288,8 @@ static std::optional inferFromStaticMemRefTy(MemRefType mrTy) { return std::nullopt; SmallVector strideInts; int64_t offset = ShapedType::kDynamic; - if (failed(mrTy.getStridesAndOffset(strideInts, offset))) + if (failed( + mlir::pto::getPTOMemRefStridesAndOffset(mrTy, strideInts, offset))) return std::nullopt; if (offset == ShapedType::kDynamic || llvm::any_of(strideInts, @@ -633,7 +634,8 @@ struct InferPTOLayoutPass SmallVector strideInts; int64_t offset = ShapedType::kDynamic; - if (failed(srcTy.getStridesAndOffset(strideInts, offset)) || + if (failed(mlir::pto::getPTOMemRefStridesAndOffset(srcTy, strideInts, + offset)) || offset == ShapedType::kDynamic || llvm::any_of(strideInts, [](int64_t s) { return s == ShapedType::kDynamic; })) { diff --git a/lib/PTO/Transforms/InsertSync/InsertSyncAnalysis.cpp b/lib/PTO/Transforms/InsertSync/InsertSyncAnalysis.cpp index 564094de20..3853179427 100644 --- a/lib/PTO/Transforms/InsertSync/InsertSyncAnalysis.cpp +++ b/lib/PTO/Transforms/InsertSync/InsertSyncAnalysis.cpp @@ -113,7 +113,8 @@ static std::optional getKnownBLayout(Type ty) { if (auto memRefTy = dyn_cast(ty)) { SmallVector strides; int64_t offset = 0; - if (failed(memRefTy.getStridesAndOffset(strides, offset)) || + if (failed(mlir::pto::getPTOMemRefStridesAndOffset(memRefTy, strides, + offset)) || strides.size() != 2) { return std::nullopt; } diff --git a/lib/PTO/Transforms/InsertSync/PTOIRTranslator.cpp b/lib/PTO/Transforms/InsertSync/PTOIRTranslator.cpp index 63d8b34eca..f6c2bcf710 100644 --- a/lib/PTO/Transforms/InsertSync/PTOIRTranslator.cpp +++ b/lib/PTO/Transforms/InsertSync/PTOIRTranslator.cpp @@ -153,7 +153,8 @@ getMemrefSubViewBaseAddresses(memref::SubViewOp op, MemRefType sourceType, SmallVector strides; int64_t baseOffset = ShapedType::kDynamic; - if (failed(sourceType.getStridesAndOffset(strides, baseOffset)) || + if (failed(mlir::pto::getPTOMemRefStridesAndOffset(sourceType, strides, + baseOffset)) || strides.size() != 2 || llvm::is_contained(strides, ShapedType::kDynamic)) return std::nullopt; @@ -257,7 +258,8 @@ static std::pair getStaticOffsetAndSize(Operation *op, Value s if (auto subView = dyn_cast(op)) { int64_t baseOffset; StrideVector strides; - if (failed(srcType.getStridesAndOffset(strides, baseOffset))) { + if (failed(mlir::pto::getPTOMemRefStridesAndOffset(srcType, strides, + baseOffset))) { return {-1, -1}; } @@ -944,7 +946,8 @@ void PTOIRTranslator::UpdateMemrefSubViewAliasBufferInfo(memref::SubViewOp op) { SmallVector strides; int64_t baseOffset = ShapedType::kDynamic; - if (failed(sourceType.getStridesAndOffset(strides, baseOffset)) || + if (failed(mlir::pto::getPTOMemRefStridesAndOffset(sourceType, strides, + baseOffset)) || strides.size() != 2) { UpdateConservativeAliasBufferInfo(result, source); return; diff --git a/lib/PTO/Transforms/PTOMaterializeTileHandles.cpp b/lib/PTO/Transforms/PTOMaterializeTileHandles.cpp index d57b3a7740..a2560791b1 100644 --- a/lib/PTO/Transforms/PTOMaterializeTileHandles.cpp +++ b/lib/PTO/Transforms/PTOMaterializeTileHandles.cpp @@ -366,8 +366,8 @@ getMaterializedTileShape(MemRefType memTy, const TileHandleMetadata &meta) { SmallVector inheritedStrides; int64_t inheritedOffset = ShapedType::kDynamic; - if (failed(sourceMrTy.getStridesAndOffset(inheritedStrides, - inheritedOffset)) || + if (failed(mlir::pto::getPTOMemRefStridesAndOffset( + sourceMrTy, inheritedStrides, inheritedOffset)) || inheritedStrides.size() < 2) return shape; @@ -481,7 +481,8 @@ static Value computeSubviewAddress(memref::SubViewOp subview, SmallVector sourceStrides; int64_t sourceOffset = ShapedType::kDynamic; - if (failed(sourceTy.getStridesAndOffset(sourceStrides, sourceOffset))) + if (failed(mlir::pto::getPTOMemRefStridesAndOffset( + sourceTy, sourceStrides, sourceOffset))) return Value(); auto mixedOffsets = subview.getMixedOffsets(); diff --git a/lib/PTO/Transforms/PTOToEmitC.cpp b/lib/PTO/Transforms/PTOToEmitC.cpp index f73ae8733c..bbde3dceb7 100644 --- a/lib/PTO/Transforms/PTOToEmitC.cpp +++ b/lib/PTO/Transforms/PTOToEmitC.cpp @@ -437,7 +437,8 @@ getGatherScatterShapeLayoutInfo(Type ty) { SmallVector strides; int64_t offset = ShapedType::kDynamic; - if (failed(memRefTy.getStridesAndOffset(strides, offset)) || + if (failed(mlir::pto::getPTOMemRefStridesAndOffset(memRefTy, strides, + offset)) || strides.size() != 2) return std::nullopt; @@ -3592,7 +3593,8 @@ struct SubviewToEmitCPattern : public OpConversionPattern { SmallVector strideInts; int64_t offset = ShapedType::kDynamic; bool useTypeStrides = - succeeded(srcType.getStridesAndOffset(strideInts, offset)); + succeeded(mlir::pto::getPTOMemRefStridesAndOffset( + srcType, strideInts, offset)); (void)offset; if (useTypeStrides) { for (int64_t s : strideInts) { @@ -3963,7 +3965,8 @@ static bool hasStaticShape(MemRefType mrTy) { static bool getStaticMemrefLayout(MemRefType mrTy, SmallVectorImpl &strides, int64_t &offset) { - if (failed(mrTy.getStridesAndOffset(strides, offset))) { + if (failed( + mlir::pto::getPTOMemRefStridesAndOffset(mrTy, strides, offset))) { strides.clear(); int64_t stride = 1; ArrayRef shape = mrTy.getShape(); @@ -12122,8 +12125,8 @@ struct PTOBindTileToEmitC : public OpConversionPattern { if (!pto::isPTOFloat4PackedType(elemTy) && subRows != ShapedType::kDynamic && subCols != ShapedType::kDynamic && - succeeded(subMrTy.getStridesAndOffset(inheritedStrides, - inheritedOffset)) && + succeeded(mlir::pto::getPTOMemRefStridesAndOffset( + subMrTy, inheritedStrides, inheritedOffset)) && inheritedStrides.size() >= 2) { int64_t childRowStride = 0; int64_t childColStride = 0; diff --git a/lib/PTO/Transforms/PTOViewToMemref.cpp b/lib/PTO/Transforms/PTOViewToMemref.cpp index 941a889924..3b8669c0db 100644 --- a/lib/PTO/Transforms/PTOViewToMemref.cpp +++ b/lib/PTO/Transforms/PTOViewToMemref.cpp @@ -1442,7 +1442,8 @@ static LogicalResult lowerSubViewOps(func::FuncOp func, MLIRContext *ctx) { SmallVector srcStrides; int64_t srcOffset = ShapedType::kDynamic; - if (failed(srcMrTy.getStridesAndOffset(srcStrides, srcOffset))) + if (failed(mlir::pto::getPTOMemRefStridesAndOffset(srcMrTy, srcStrides, + srcOffset))) srcStrides = computeCompactStrides(srcMrTy.getShape()); // Keep parent physical shape + strides for bound tile semantics. @@ -1996,7 +1997,8 @@ struct PTOViewToMemrefPass SmallVector staticStrides; int64_t offset = ShapedType::kDynamic; - if (succeeded(mrTy.getStridesAndOffset(staticStrides, offset)) && + if (succeeded(mlir::pto::getPTOMemRefStridesAndOffset( + mrTy, staticStrides, offset)) && dimIndex < (int64_t)staticStrides.size() && staticStrides[dimIndex] != ShapedType::kDynamic) { rewriter.replaceOpWithNewOp( diff --git a/lib/PTO/Transforms/VPTOPtrNormalize.cpp b/lib/PTO/Transforms/VPTOPtrNormalize.cpp index 7ca37cbb40..c9bfd36fe9 100644 --- a/lib/PTO/Transforms/VPTOPtrNormalize.cpp +++ b/lib/PTO/Transforms/VPTOPtrNormalize.cpp @@ -10,6 +10,7 @@ #pragma GCC diagnostic ignored "-Woverloaded-virtual" #include "PTO/IR/PTO.h" +#include "PTO/IR/PTOTypeUtils.h" #include "PTO/Transforms/Passes.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -117,7 +118,8 @@ static LogicalResult computeSubviewElementOffset(memref::SubViewOp op, SmallVector strides; int64_t baseOffset = 0; - if (failed(sourceType.getStridesAndOffset(strides, baseOffset))) + if (failed(mlir::pto::getPTOMemRefStridesAndOffset(sourceType, strides, + baseOffset))) return failure(); // The SSA source already names the base address after bind_tile/pointer_cast // normalization. A dynamic memref layout offset here is metadata we can From 5c3d53ebda34b90e8018a45ceda1d69b6facbdbf Mon Sep 17 00:00:00 2001 From: TaoTao-real Date: Mon, 29 Jun 2026 10:11:13 +0800 Subject: [PATCH 44/51] build: require LLVM 21 at configure time --- CMakeLists.txt | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index 4908c3cbb1..6e44eb2f5f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -52,6 +52,15 @@ find_package(LLVM REQUIRED CONFIG) message(STATUS "Found LLVM ${LLVM_PACKAGE_VERSION}") message(STATUS "LLVM CMake Dir: ${LLVM_CMAKE_DIR}") message(STATUS "LLVM Include Dir: ${LLVM_INCLUDE_DIRS}") +if(NOT DEFINED LLVM_VERSION_MAJOR) + string(REGEX MATCH "^[0-9]+" LLVM_VERSION_MAJOR "${LLVM_PACKAGE_VERSION}") +endif() +if(NOT LLVM_VERSION_MAJOR STREQUAL "21") + message(FATAL_ERROR + "PTOAS requires LLVM 21 after the LLVM21 upgrade, but found " + "LLVM ${LLVM_PACKAGE_VERSION}. Please point LLVM_DIR/MLIR_DIR to an " + "LLVM 21 build, for example vpto-dev/llvm-project:feature-vpto-llvm21.") +endif() get_filename_component(PTO_LLVM_BUILD_LIBRARY_DIR "${LLVM_BUILD_LIBRARY_DIR}" REALPATH) From 1960a995ad8f92db2f7d52fb61bd40ddc28421bb Mon Sep 17 00:00:00 2001 From: TaoTao-real Date: Mon, 29 Jun 2026 15:00:43 +0800 Subject: [PATCH 45/51] ci: harden llvm21 validation jobs --- .github/workflows/build_wheel.yml | 3 +++ .github/workflows/ci_sim.yml | 24 ++++++++++++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/.github/workflows/build_wheel.yml b/.github/workflows/build_wheel.yml index c38015e781..ddc51116c4 100644 --- a/.github/workflows/build_wheel.yml +++ b/.github/workflows/build_wheel.yml @@ -162,6 +162,9 @@ jobs: - name: Build PTOAS run: | export PATH="${PY_PATH}/bin:$PATH" + if [ "${{ matrix.arch }}" = "aarch64" ]; then + export CMAKE_BUILD_PARALLEL_LEVEL=2 + fi PTOAS_RELEASE_VERSION_OVERRIDE="${PTOAS_VERSION}" \ pip install . --no-build-isolation diff --git a/.github/workflows/ci_sim.yml b/.github/workflows/ci_sim.yml index 7396c9a12d..20862b519e 100644 --- a/.github/workflows/ci_sim.yml +++ b/.github/workflows/ci_sim.yml @@ -316,6 +316,7 @@ jobs: ref: ${{ env.PYPTO_REF }} path: ${{ env.PYPTO_WORKSPACE }} fetch-depth: 1 + submodules: recursive persist-credentials: false - name: Checkout PTO-ISA @@ -385,6 +386,29 @@ jobs: # auto-detecting onboard platforms from the runner's CANN environment. env -u ASCEND_HOME_PATH python3 -m pip install --no-build-isolation --no-deps "${PYPTO_WORKSPACE}/runtime" + python3 - <<'PY' + from pathlib import Path + + from simpler_setup.environment import PROJECT_ROOT + from simpler_setup.kernel_compiler import KernelCompiler + + required = PROJECT_ROOT / "src" / "common" / "task_interface" / "arg_direction.h" + if not required.is_file(): + raise SystemExit(f"missing required simpler runtime header: {required}") + + include_dirs = KernelCompiler(platform="a5sim").get_orchestration_include_dirs( + "tensormap_and_ringbuffer" + ) + if str(required.parent) not in include_dirs: + raise SystemExit( + "simpler orchestration include dirs do not contain " + f"{required.parent}: {include_dirs}" + ) + + print(f"simpler PROJECT_ROOT={PROJECT_ROOT}") + print(f"simpler arg_direction.h={required}") + PY + - name: Run PyPTO PTOAS end-to-end simulator smoke shell: bash env: From 70a7cb54e61cf555ce23bc2fc425f99ca9fe9a19 Mon Sep 17 00:00:00 2001 From: TaoTao-real Date: Wed, 1 Jul 2026 10:20:03 +0800 Subject: [PATCH 46/51] build: inherit LLVM compiler for PTOAS configure --- CMakeLists.txt | 85 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 85 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index 6e44eb2f5f..f857052bbc 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -7,6 +7,67 @@ # See LICENSE in the root of the software repository for the full text of the License. cmake_minimum_required(VERSION 3.20.0) + +function(ptoas_get_cmake_cache_value cache_file variable out_var) + if(NOT EXISTS "${cache_file}") + set(${out_var} "" PARENT_SCOPE) + return() + endif() + + file(STRINGS "${cache_file}" cache_lines + REGEX "^${variable}:[^=]*=") + if(cache_lines) + list(GET cache_lines 0 cache_line) + string(REGEX REPLACE "^[^=]*=" "" cache_value "${cache_line}") + set(${out_var} "${cache_value}" PARENT_SCOPE) + else() + set(${out_var} "" PARENT_SCOPE) + endif() +endfunction() + +function(ptoas_get_llvm_build_dir_from_llvm_dir out_var) + if(NOT DEFINED LLVM_DIR OR LLVM_DIR STREQUAL "") + set(${out_var} "" PARENT_SCOPE) + return() + endif() + + get_filename_component(llvm_cmake_dir "${LLVM_DIR}" REALPATH) + get_filename_component(llvm_build_dir "${llvm_cmake_dir}/../../.." REALPATH) + if(EXISTS "${llvm_build_dir}/CMakeCache.txt") + set(${out_var} "${llvm_build_dir}" PARENT_SCOPE) + else() + set(${out_var} "" PARENT_SCOPE) + endif() +endfunction() + +function(ptoas_preseed_compilers_from_llvm_cache) + ptoas_get_llvm_build_dir_from_llvm_dir(llvm_build_dir) + if(llvm_build_dir STREQUAL "") + return() + endif() + + set(llvm_cache "${llvm_build_dir}/CMakeCache.txt") + if(NOT DEFINED CMAKE_C_COMPILER AND "$ENV{CC}" STREQUAL "") + ptoas_get_cmake_cache_value("${llvm_cache}" "CMAKE_C_COMPILER" + llvm_c_compiler) + if(llvm_c_compiler AND EXISTS "${llvm_c_compiler}") + set(CMAKE_C_COMPILER "${llvm_c_compiler}" CACHE FILEPATH + "C compiler inherited from the configured LLVM build" FORCE) + endif() + endif() + + if(NOT DEFINED CMAKE_CXX_COMPILER AND "$ENV{CXX}" STREQUAL "") + ptoas_get_cmake_cache_value("${llvm_cache}" "CMAKE_CXX_COMPILER" + llvm_cxx_compiler) + if(llvm_cxx_compiler AND EXISTS "${llvm_cxx_compiler}") + set(CMAKE_CXX_COMPILER "${llvm_cxx_compiler}" CACHE FILEPATH + "C++ compiler inherited from the configured LLVM build" FORCE) + endif() + endif() +endfunction() + +ptoas_preseed_compilers_from_llvm_cache() + project(ptoas VERSION 0.47) set(PTOAS_RELEASE_VERSION_OVERRIDE "" @@ -63,6 +124,30 @@ if(NOT LLVM_VERSION_MAJOR STREQUAL "21") endif() get_filename_component(PTO_LLVM_BUILD_LIBRARY_DIR "${LLVM_BUILD_LIBRARY_DIR}" REALPATH) +ptoas_get_llvm_build_dir_from_llvm_dir(PTO_LLVM_BUILD_DIR) +if(PTO_LLVM_BUILD_DIR) + ptoas_get_cmake_cache_value("${PTO_LLVM_BUILD_DIR}/CMakeCache.txt" + "CMAKE_CXX_COMPILER" + PTO_LLVM_CXX_COMPILER) +endif() +if(PTO_LLVM_CXX_COMPILER AND EXISTS "${PTO_LLVM_CXX_COMPILER}") + execute_process( + COMMAND "${PTO_LLVM_CXX_COMPILER}" --version + OUTPUT_VARIABLE PTO_LLVM_CXX_VERSION_TEXT + ERROR_VARIABLE PTO_LLVM_CXX_VERSION_TEXT + OUTPUT_STRIP_TRAILING_WHITESPACE + ERROR_STRIP_TRAILING_WHITESPACE) + if(PTO_LLVM_CXX_VERSION_TEXT MATCHES "[Cc]lang" + AND CMAKE_CXX_COMPILER_ID STREQUAL "GNU") + message(FATAL_ERROR + "PTOAS is being configured with GNU C++ (${CMAKE_CXX_COMPILER}), " + "but the selected LLVM build was compiled with a Clang-compatible " + "compiler (${PTO_LLVM_CXX_COMPILER}). MLIR TypeID fallback strings are " + "compiler-family sensitive across shared libraries; use " + "-DCMAKE_CXX_COMPILER=${PTO_LLVM_CXX_COMPILER} or leave the compiler " + "unset so PTOAS can inherit it from LLVM_DIR.") + endif() +endif() # 将 LLVM 模块路径加入 CMake 搜索路径 list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}") From 2e74ec1f51e05e9c7e6d96cedc152de6dda5ccd3 Mon Sep 17 00:00:00 2001 From: TaoTao-real Date: Wed, 1 Jul 2026 11:07:10 +0800 Subject: [PATCH 47/51] test: align llvm21 lit expectations --- test/lit/pto/mgather_mscatter_a5_attrs_emitc.pto | 10 ++++++---- test/lit/pto/tci_ui16_emitc.pto | 4 +--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/test/lit/pto/mgather_mscatter_a5_attrs_emitc.pto b/test/lit/pto/mgather_mscatter_a5_attrs_emitc.pto index d35962be36..a676cd3dda 100644 --- a/test/lit/pto/mgather_mscatter_a5_attrs_emitc.pto +++ b/test/lit/pto/mgather_mscatter_a5_attrs_emitc.pto @@ -86,12 +86,14 @@ module attributes {pto.target_arch = "a5"} { } } +// CHECK-LABEL: AICORE void mgather_emitc +// CHECK-DAG: __ubuf__ int32_t* [[MGATHER_IDX:v[0-9]+]] = {{.*}}; +// CHECK-DAG: __ubuf__ float* [[MGATHER_DST:v[0-9]+]] = {{.*}}; // CHECK: TLOAD({{v[0-9]+}}, {{v[0-9]+}}); -// CHECK: __ubuf__ int32_t* [[MGATHER_IDX:v[0-9]+]] = {{.*}}; -// CHECK: __ubuf__ float* [[MGATHER_DST:v[0-9]+]] = {{.*}}; // CHECK: MGATHER([[MGATHER_DST]], {{v[0-9]+}}, [[MGATHER_IDX]]); +// CHECK-LABEL: AICORE void mscatter_emitc +// CHECK-DAG: __ubuf__ float* [[MSCATTER_SRC:v[0-9]+]] = {{.*}}; +// CHECK-DAG: __ubuf__ int32_t* [[MSCATTER_IDX:v[0-9]+]] = {{.*}}; // CHECK: TLOAD({{v[0-9]+}}, {{v[0-9]+}}); // CHECK: TLOAD({{v[0-9]+}}, {{v[0-9]+}}); -// CHECK: __ubuf__ float* [[MSCATTER_SRC:v[0-9]+]] = {{.*}}; -// CHECK: __ubuf__ int32_t* [[MSCATTER_IDX:v[0-9]+]] = {{.*}}; // CHECK: MSCATTER({{v[0-9]+}}, [[MSCATTER_SRC]], [[MSCATTER_IDX]]); diff --git a/test/lit/pto/tci_ui16_emitc.pto b/test/lit/pto/tci_ui16_emitc.pto index e5ee180e0b..17c7eec7f1 100644 --- a/test/lit/pto/tci_ui16_emitc.pto +++ b/test/lit/pto/tci_ui16_emitc.pto @@ -5,9 +5,7 @@ module { %c7_i16 = arith.constant 7 : i16 %c7_ui16 = builtin.unrealized_conversion_cast %c7_i16 : i16 to ui16 %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c16 = arith.constant 16 : index - %src = memref.reinterpret_cast %dst to offset: [%c0], sizes: [%c1, %c16], strides: [%c16, %c1] {layout = #pto.layout} : memref<16xui16, #pto.address_space> to memref<1x16xui16, strided<[16, 1], offset: ?>, #pto.address_space> + %src = memref.reinterpret_cast %dst to offset: [0], sizes: [1, 16], strides: [16, 1] {layout = #pto.layout} : memref<16xui16, #pto.address_space> to memref<1x16xui16, strided<[16, 1], offset: ?>, #pto.address_space> %tile = pto.alloc_tile : !pto.tile_buf pto.tci ins(%c7_ui16 : ui16) outs(%tile : !pto.tile_buf) pto.tstore ins(%tile : !pto.tile_buf) outs(%src : memref<1x16xui16, strided<[16, 1], offset: ?>, #pto.address_space>) {layout = #pto.layout, pto.inferred_layout = true} From 2e20a1a2ee83c0d411cd2f70febc235d42ebd9b2 Mon Sep 17 00:00:00 2001 From: TaoTao-real Date: Wed, 1 Jul 2026 11:28:26 +0800 Subject: [PATCH 48/51] test: accept llvm21 tile lvalue emitc form --- test/lit/pto/multi_tile_const_preload_dyn_loop_select.pto | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/lit/pto/multi_tile_const_preload_dyn_loop_select.pto b/test/lit/pto/multi_tile_const_preload_dyn_loop_select.pto index 31f8d27073..eab2597534 100644 --- a/test/lit/pto/multi_tile_const_preload_dyn_loop_select.pto +++ b/test/lit/pto/multi_tile_const_preload_dyn_loop_select.pto @@ -75,7 +75,8 @@ module { // EMITC: AICORE void const_preload_dyn_loop_select // EMITC: for ( -// EMITC-NEXT: Tile [[TILE:v[0-9]+]]; +// EMITC-NEXT: Tile [[TILE_STORAGE:v[0-9]+]]; +// EMITC-NEXT: Tile [[TILE:v[0-9]+]] = [[TILE_STORAGE]]; // EMITC-NEXT: uint64_t [[ADDR:v[0-9]+]] = (uint64_t) // EMITC-NEXT: TASSIGN([[TILE]], [[ADDR]]); // EMITC: TLOAD([[TILE]], From 6c3dcc97874e46585932346bd6e924fc0a742773 Mon Sep 17 00:00:00 2001 From: TaoTao-real Date: Wed, 1 Jul 2026 15:58:03 +0800 Subject: [PATCH 49/51] fix: address LLVM21 review comments --- lib/PTO/IR/PTO.cpp | 6 +--- lib/PTO/IR/PTOTypeDefs.cpp | 23 +++++++++++- lib/PTO/Transforms/ExpandTileOp.cpp | 4 +-- .../Transforms/PTOMaterializeTileHandles.cpp | 30 ++++++++++++++++ lib/PTO/Transforms/PTOToEmitC.cpp | 36 +++++++++++++++++-- .../pto/compact_left_blayout_parser_a5.pto | 2 +- .../emitc_tile_data_sink_after_tassign.pto | 10 +++--- test/lit/pto/left_blayout_parser_a3.pto | 4 +-- test/lit/pto/left_blayout_parser_a5.pto | 2 +- ...alize_tile_handles_control_flow_result.pto | 14 +++----- tools/ptoas/ptoas.cpp | 3 +- 11 files changed, 104 insertions(+), 30 deletions(-) diff --git a/lib/PTO/IR/PTO.cpp b/lib/PTO/IR/PTO.cpp index 03173019ba..40e7701653 100644 --- a/lib/PTO/IR/PTO.cpp +++ b/lib/PTO/IR/PTO.cpp @@ -6979,11 +6979,7 @@ static bool isA5Fp8LikeType(Type ty) { } static bool isA5MxFp8InputType(Type ty) { - std::string text; - llvm::raw_string_ostream os(text); - ty.print(os); - os.flush(); - return text == "f8E4M3FN" || text == "f8E5M2"; + return ty && isa(ty); } static bool isA5MxInputTypePair(Type lhsTy, Type rhsTy) { diff --git a/lib/PTO/IR/PTOTypeDefs.cpp b/lib/PTO/IR/PTOTypeDefs.cpp index d4a69f6b74..c7594a7da7 100644 --- a/lib/PTO/IR/PTOTypeDefs.cpp +++ b/lib/PTO/IR/PTOTypeDefs.cpp @@ -171,6 +171,23 @@ static std::optional resolveTileBufMemorySpace(StringRef locStr) { .Default(::std::nullopt); } +static BLayout resolveTileBufBLayout(MLIRContext *context, + AddressSpace memorySpace, + BLayout parsedLayout) { + if (memorySpace != AddressSpace::LEFT) + return parsedLayout; + + switch (getPTOParserTargetArch(context)) { + case PTOParserTargetArch::A3: + return BLayout::RowMajor; + case PTOParserTargetArch::A5: + return BLayout::ColMajor; + case PTOParserTargetArch::Unspecified: + return parsedLayout; + } + return parsedLayout; +} + TileBufConfigAttr TileBufType::getConfigAttr() const { // 情况 A:getConfig() 已经是 TileBufConfigAttr if constexpr (std::is_same_v) { @@ -500,10 +517,14 @@ static Type buildTileBufType(AsmParser &parser, return Type(); } + BLayout effectiveBLayout = + resolveTileBufBLayout(parser.getContext(), memorySpace.value(), + bl.value()); + // (32-byte alignment and boxed layout divisibility checks removed // - not general hardware requirements; validation handled elsewhere) - auto blAttr = BLayoutAttr::get(ctx, bl.value()); + auto blAttr = BLayoutAttr::get(ctx, effectiveBLayout); auto slAttr = SLayoutAttr::get(ctx, sl.value()); auto fractalAttr = IntegerAttr::get(IntegerType::get(ctx, kI32BitWidth), fields.fractal); diff --git a/lib/PTO/Transforms/ExpandTileOp.cpp b/lib/PTO/Transforms/ExpandTileOp.cpp index 34daedfe31..087d8d62c6 100644 --- a/lib/PTO/Transforms/ExpandTileOp.cpp +++ b/lib/PTO/Transforms/ExpandTileOp.cpp @@ -200,8 +200,8 @@ static std::string getDtypeString(Type elemTy) { if (elemTy.isF32()) return "f32"; if (elemTy.isF16()) return "f16"; if (elemTy.isBF16()) return "bf16"; - if (pto::isPTOFloat8E4M3LikeType(elemTy)) return "f8e4m3"; - if (pto::isPTOFloat8E5M2LikeType(elemTy)) return "f8e5m2"; + if (isa(elemTy)) return "f8e4m3"; + if (isa(elemTy)) return "f8e5m2"; if (isa(elemTy)) return "hif8"; if (isa(elemTy)) return "f4e1m2x2"; if (isa(elemTy)) return "f4e2m1x2"; diff --git a/lib/PTO/Transforms/PTOMaterializeTileHandles.cpp b/lib/PTO/Transforms/PTOMaterializeTileHandles.cpp index 984bf86b04..10cf9da2bd 100644 --- a/lib/PTO/Transforms/PTOMaterializeTileHandles.cpp +++ b/lib/PTO/Transforms/PTOMaterializeTileHandles.cpp @@ -632,6 +632,34 @@ static void cloneBlockWithoutTerminator(Block *from, Block *to, } } +static bool isOperationNestedIn(Operation *op, Operation *ancestor) { + for (; op; op = op->getParentOp()) { + if (op == ancestor) + return true; + } + return false; +} + +static bool isValueOwnedByOperation(Value value, Operation *owner) { + if (auto result = dyn_cast(value)) + return isOperationNestedIn(result.getOwner(), owner); + if (auto arg = dyn_cast(value)) + return isOperationNestedIn(arg.getOwner()->getParentOp(), owner); + return false; +} + +static void eraseTileHandlesOwnedBy(Operation *owner, + DenseMap &tileHandles) { + SmallVector keysToErase; + for (auto &entry : tileHandles) { + if (isValueOwnedByOperation(entry.first, owner) || + isValueOwnedByOperation(entry.second, owner)) + keysToErase.push_back(entry.first); + } + for (Value key : keysToErase) + tileHandles.erase(key); +} + static FailureOr materializeSCFIfResults(ModuleOp module, DenseMap &tileHandles) { bool changed = false; @@ -720,6 +748,7 @@ materializeSCFIfResults(ModuleOp module, DenseMap &tileHandles) { tileHandles[newIf.getResult(idx)] = newIf.getResult(idx); } + eraseTileHandlesOwnedBy(ifOp, tileHandles); ifOp.erase(); changed = true; } @@ -818,6 +847,7 @@ materializeSCFForResults(ModuleOp module, DenseMap &tileHandles) { tileHandles[newFor.getResult(idx)] = newFor.getResult(idx); } + eraseTileHandlesOwnedBy(forOp, tileHandles); forOp.erase(); changed = true; } diff --git a/lib/PTO/Transforms/PTOToEmitC.cpp b/lib/PTO/Transforms/PTOToEmitC.cpp index eeae14f65b..0e8a8a6717 100644 --- a/lib/PTO/Transforms/PTOToEmitC.cpp +++ b/lib/PTO/Transforms/PTOToEmitC.cpp @@ -133,6 +133,17 @@ static Value loadEmitCVariableIfNeeded(OpBuilder &builder, Location loc, return value; } +static Value getSourceEmitCVariable(Value value) { + if (value.getDefiningOp()) + return value; + if (auto loadOp = value.getDefiningOp()) { + Value operand = loadOp.getOperand(); + if (operand.getDefiningOp()) + return operand; + } + return {}; +} + [[maybe_unused]] static constexpr llvm::StringLiteral kLoweredSetValidShapeAttrName = "__pto.lowered_set_validshape"; [[maybe_unused]] static constexpr llvm::StringLiteral kLoweredSetValidShapeConfigAttrName = @@ -14325,7 +14336,7 @@ static AICORE inline void ptoas_auto_sync_tail( return; if (callOp.getNumOperands() != 1 || callOp.getNumResults() != 1) return; - if (callOp.getOperand(0).getDefiningOp()) + if (getSourceEmitCVariable(callOp.getOperand(0))) tileDataReadsToSink.push_back(callOp); }); @@ -14376,6 +14387,20 @@ static AICORE inline void ptoas_auto_sync_tail( continue; // 这是一个赋值操作,不算有效使用 } } + if (auto load = dyn_cast(user)) { + bool loadOnlyFeedsDeadTileDataReads = true; + for (Operation *loadUser : load.getResult().getUsers()) { + auto call = dyn_cast(loadUser); + if (!call || call.getCallee() != "PTOAS__TILE_DATA" || + call.getNumResults() != 1 || + !call.getResult(0).use_empty()) { + loadOnlyFeedsDeadTileDataReads = false; + break; + } + } + if (loadOnlyFeedsDeadTileDataReads) + continue; + } // 如果还有其他用途(如 TLOAD, TMOV, TMATMUL),则该变量有用 isRead = true; break; @@ -14390,7 +14415,14 @@ static AICORE inline void ptoas_auto_sync_tail( // 1. 先删除所有使用该变量的 TASSIGN llvm::SmallVector usersToErase; for (Operation* user : varOp.getResult().getUsers()) { - // 我们上面已经确认过,剩下的 user 只能是 TASSIGN + if (auto load = dyn_cast(user)) { + llvm::SmallVector loadUsersToErase; + for (Operation *loadUser : load.getResult().getUsers()) + loadUsersToErase.push_back(loadUser); + for (Operation *loadUser : loadUsersToErase) + loadUser->erase(); + } + // 我们上面已经确认过,剩下的 user 只能是 TASSIGN 或死 emitc.load usersToErase.push_back(user); } for (auto u : usersToErase) u->erase(); diff --git a/test/lit/pto/compact_left_blayout_parser_a5.pto b/test/lit/pto/compact_left_blayout_parser_a5.pto index 554c9fa6a9..4dad24114e 100644 --- a/test/lit/pto/compact_left_blayout_parser_a5.pto +++ b/test/lit/pto/compact_left_blayout_parser_a5.pto @@ -8,4 +8,4 @@ module attributes {"pto.target_arch" = "a5"} { } // CHECK-LABEL: func.func @compact_left_blayout_parser_a5() { -// CHECK: #pto.tile_buf_config, slayout=#pto.slayout, s_fractal_size=512, pad=#pto.pad_value, compact=#pto.compact_mode> +// CHECK: #pto.tile_buf_config, slayout=#pto.slayout, s_fractal_size=512, pad=#pto.pad_value, compact=#pto.compact_mode> diff --git a/test/lit/pto/emitc_tile_data_sink_after_tassign.pto b/test/lit/pto/emitc_tile_data_sink_after_tassign.pto index 72c58c9ed9..953c70ed44 100644 --- a/test/lit/pto/emitc_tile_data_sink_after_tassign.pto +++ b/test/lit/pto/emitc_tile_data_sink_after_tassign.pto @@ -30,8 +30,8 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind [[SINK_BASE:v[0-9]+]]; // CHECK-NEXT: Tile [[SINK_TILE:v[0-9]+]] = [[SINK_BASE]]; -// CHECK: __ubuf__ float* [[SINK_PTR:v[0-9]+]] = [[SINK_TILE]].data(); -// CHECK: __ubuf__ float* [[ADDR_PTR:v[0-9]+]] = [[ADDR_TILE]].data(); -// CHECK: uint64_t [[ADDR_BITS:v[0-9]+]] = reinterpret_cast([[ADDR_PTR]]); -// CHECK: TASSIGN([[SINK_TILE]], [[ADDR_BITS]]); -// CHECK: sink_ptr([[SINK_PTR]]); +// CHECK-NEXT: __ubuf__ float* [[ADDR_PTR:v[0-9]+]] = [[ADDR_TILE]].data(); +// CHECK-NEXT: uint64_t [[ADDR_BITS:v[0-9]+]] = reinterpret_cast([[ADDR_PTR]]); +// CHECK-NEXT: TASSIGN([[SINK_TILE]], [[ADDR_BITS]]); +// CHECK-NEXT: __ubuf__ float* [[SINK_PTR:v[0-9]+]] = [[SINK_TILE]].data(); +// CHECK-NEXT: sink_ptr([[SINK_PTR]]); diff --git a/test/lit/pto/left_blayout_parser_a3.pto b/test/lit/pto/left_blayout_parser_a3.pto index 41744472a5..21ee0af7fd 100644 --- a/test/lit/pto/left_blayout_parser_a3.pto +++ b/test/lit/pto/left_blayout_parser_a3.pto @@ -4,8 +4,8 @@ module attributes {"pto.target_arch" = "a3"} { func.func @left_blayout_parser_a3() { %c0 = arith.constant 0 : index %src = pto.alloc_tile : !pto.tile_buf - %dst = pto.alloc_tile : !pto.tile_buf - pto.textract ins(%src, %c0, %c0 : !pto.tile_buf, index, index) outs(%dst : !pto.tile_buf) + %dst = pto.alloc_tile : !pto.tile_buf + pto.textract ins(%src, %c0, %c0 : !pto.tile_buf, index, index) outs(%dst : !pto.tile_buf) return } } diff --git a/test/lit/pto/left_blayout_parser_a5.pto b/test/lit/pto/left_blayout_parser_a5.pto index a689597add..a8c7db18c2 100644 --- a/test/lit/pto/left_blayout_parser_a5.pto +++ b/test/lit/pto/left_blayout_parser_a5.pto @@ -11,4 +11,4 @@ module attributes {"pto.target_arch" = "a5"} { } // CHECK-LABEL: AICORE void left_blayout_parser_a5() { -// CHECK: Tile +// CHECK: Tile diff --git a/test/lit/pto/materialize_tile_handles_control_flow_result.pto b/test/lit/pto/materialize_tile_handles_control_flow_result.pto index f927d47f86..4fba96e31b 100644 --- a/test/lit/pto/materialize_tile_handles_control_flow_result.pto +++ b/test/lit/pto/materialize_tile_handles_control_flow_result.pto @@ -1,7 +1,4 @@ -// RUN: rm -f %t.seam -// RUN: ptoas --pto-arch=a3 --pto-seam-ir-file=%t.seam %s -o %t.cpp -// RUN: FileCheck %s --check-prefix=IR < %t.seam -// RUN: FileCheck %s --check-prefix=SEAM-EMITC < %t.cpp +// RUN: ptoas --pto-arch=a3 --mlir-print-ir-after=pto-materialize-tile-handles %s -o /dev/null 2>&1 | FileCheck %s --check-prefix=IR // RUN: ptoas --pto-arch=a3 %s -o - 2>&1 | FileCheck %s --check-prefix=EMITC module { @@ -51,12 +48,12 @@ module { } } -// IR-LABEL: func.func @control_flow_result_tile_materialization -// IR: %[[SEL:.*]] = arith.select %{{.*}}, %{{.*}}, %{{.*}} : !pto.tile_buf -// IR: scf.if {{.*}} { +// IR: %[[SEL:.*]] = scf.if {{.*}} -> (!pto.tile_buf) { // IR: pto.tadd +// IR: scf.yield %{{.*}} : !pto.tile_buf // IR: } else { // IR: pto.tmul +// IR: scf.yield %{{.*}} : !pto.tile_buf // IR: } // IR: pto.tadd ins(%[[SEL]], %[[SEL]] // IR-LABEL: func.func @loop_carried_tile_materialization @@ -66,9 +63,6 @@ module { // IR: } // IR: pto.tmul ins(%[[RES]], %[[RES]] -// SEAM-EMITC-LABEL: AICORE void control_flow_result_tile_materialization -// SEAM-EMITC: TADD( - // EMITC-LABEL: AICORE void control_flow_result_tile_materialization // EMITC: if ( // EMITC: TADD( diff --git a/tools/ptoas/ptoas.cpp b/tools/ptoas/ptoas.cpp index 58d4544242..22f53794c0 100644 --- a/tools/ptoas/ptoas.cpp +++ b/tools/ptoas/ptoas.cpp @@ -1874,7 +1874,8 @@ int mlir::pto::compilePTOASModule( char **argv = context.getArgv(); if (effectiveBackend != PTOBackend::VPTO && - (emitVPTO || emitVPTOLLVMDialect)) { + (emitVPTO || emitVPTOLLVMDialect || ptoPrintSeamIR || + !ptoSeamIRFile.empty())) { llvm::errs() << "Error: VPTO-specific flags require " "--pto-backend=vpto or pto.backend = \"vpto\".\n"; return 1; From 15eb9af7c56ed62a957ca781fc48e0c095a2247f Mon Sep 17 00:00:00 2001 From: TaoTao-real Date: Wed, 1 Jul 2026 16:57:55 +0800 Subject: [PATCH 50/51] test: update mgather mscatter emitc checks for LLVM21 --- test/lit/pto/mgather_mscatter_a5_attrs_emitc.pto | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/lit/pto/mgather_mscatter_a5_attrs_emitc.pto b/test/lit/pto/mgather_mscatter_a5_attrs_emitc.pto index a676cd3dda..e15e4c82b0 100644 --- a/test/lit/pto/mgather_mscatter_a5_attrs_emitc.pto +++ b/test/lit/pto/mgather_mscatter_a5_attrs_emitc.pto @@ -87,13 +87,13 @@ module attributes {pto.target_arch = "a5"} { } // CHECK-LABEL: AICORE void mgather_emitc +// CHECK: TLOAD({{v[0-9]+}}, {{v[0-9]+}}); // CHECK-DAG: __ubuf__ int32_t* [[MGATHER_IDX:v[0-9]+]] = {{.*}}; // CHECK-DAG: __ubuf__ float* [[MGATHER_DST:v[0-9]+]] = {{.*}}; -// CHECK: TLOAD({{v[0-9]+}}, {{v[0-9]+}}); // CHECK: MGATHER([[MGATHER_DST]], {{v[0-9]+}}, [[MGATHER_IDX]]); // CHECK-LABEL: AICORE void mscatter_emitc -// CHECK-DAG: __ubuf__ float* [[MSCATTER_SRC:v[0-9]+]] = {{.*}}; -// CHECK-DAG: __ubuf__ int32_t* [[MSCATTER_IDX:v[0-9]+]] = {{.*}}; // CHECK: TLOAD({{v[0-9]+}}, {{v[0-9]+}}); // CHECK: TLOAD({{v[0-9]+}}, {{v[0-9]+}}); +// CHECK-DAG: __ubuf__ float* [[MSCATTER_SRC:v[0-9]+]] = {{.*}}; +// CHECK-DAG: __ubuf__ int32_t* [[MSCATTER_IDX:v[0-9]+]] = {{.*}}; // CHECK: MSCATTER({{v[0-9]+}}, [[MSCATTER_SRC]], [[MSCATTER_IDX]]); From ed75c3637c9e4aa3288b0004116f7260ad1c68ea Mon Sep 17 00:00:00 2001 From: TaoTao-real Date: Thu, 2 Jul 2026 10:31:14 +0800 Subject: [PATCH 51/51] fix: restore A5 left layout verifier contract --- lib/PTO/IR/PTO.cpp | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/lib/PTO/IR/PTO.cpp b/lib/PTO/IR/PTO.cpp index 40e7701653..c63e216e24 100644 --- a/lib/PTO/IR/PTO.cpp +++ b/lib/PTO/IR/PTO.cpp @@ -4815,10 +4815,8 @@ static LogicalResult verifyMatTileOperandsA5(Operation *op, Type lhsTy, if (!lhsTb || !rhsTb || !dstTb) return success(); - if (lhsTb.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor) && - lhsTb.getBLayoutValueI32() != static_cast(pto::BLayout::ColMajor)) - return op->emitOpError( - "expects lhs to use the row_major or col_major blayout on A5"); + if (lhsTb.getBLayoutValueI32() != static_cast(pto::BLayout::ColMajor)) + return op->emitOpError("expects lhs to use the col_major blayout on A5"); if (rhsTb.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor)) return op->emitOpError("expects rhs to use the row_major blayout on A5"); if (dstTb.getBLayoutValueI32() != static_cast(pto::BLayout::ColMajor)) @@ -6623,11 +6621,9 @@ mlir::LogicalResult mlir::pto::TExtractOp::verify() { if (!hasMatExtractSourceLayoutA5(srcTb, *dstSpace)) return emitOpError("expects A5 textract src to use a supported mat blayout/slayout combination"); if (*dstSpace == pto::AddressSpace::LEFT) { - if ((dstTb.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor) && - dstTb.getBLayoutValueI32() != static_cast(pto::BLayout::ColMajor)) || + if (dstTb.getBLayoutValueI32() != static_cast(pto::BLayout::ColMajor) || dstTb.getSLayoutValueI32() != static_cast(pto::SLayout::RowMajor)) - return emitOpError( - "expects A5 left dst to use row_major or col_major blayout and row_major slayout"); + return emitOpError("expects A5 left dst to use col_major blayout and row_major slayout"); } else if (*dstSpace == pto::AddressSpace::RIGHT) { if (dstTb.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor) || dstTb.getSLayoutValueI32() != static_cast(pto::SLayout::ColMajor))