diff --git a/.github/workflows/build_wheel.yml b/.github/workflows/build_wheel.yml index 209af90f95..ddc51116c4 100644 --- a/.github/workflows/build_wheel.yml +++ b/.github/workflows/build_wheel.yml @@ -25,8 +25,8 @@ permissions: env: LLVM_REPO: https://github.com/vpto-dev/llvm-project.git - LLVM_REF: feature-vpto - LLVM_CACHE_FLAVOR: release-hardening-v1 + LLVM_REF: feature-vpto-llvm21 + LLVM_CACHE_FLAVOR: llvm21-vpto-release-hardening-v1 jobs: build_wheel: @@ -143,6 +143,9 @@ jobs: -DLLVM_ENABLE_ASSERTIONS=OFF \ -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ -DPython3_EXECUTABLE=${PY_PATH}/bin/python \ + -DPython_EXECUTABLE=${PY_PATH}/bin/python \ + -Dpybind11_DIR="$(${PY_PATH}/bin/python -m pybind11 --cmakedir)" \ + -Dnanobind_DIR="$(${PY_PATH}/bin/python -m nanobind --cmake_dir)" \ -DCMAKE_BUILD_TYPE=Release \ -DLLVM_TARGETS_TO_BUILD="host" ninja -C $LLVM_BUILD_DIR @@ -159,6 +162,9 @@ jobs: - name: Build PTOAS run: | export PATH="${PY_PATH}/bin:$PATH" + if [ "${{ matrix.arch }}" = "aarch64" ]; then + export CMAKE_BUILD_PARALLEL_LEVEL=2 + fi PTOAS_RELEASE_VERSION_OVERRIDE="${PTOAS_VERSION}" \ pip install . --no-build-isolation diff --git a/.github/workflows/build_wheel_mac.yml b/.github/workflows/build_wheel_mac.yml index 028f8c74fa..37d9ad5468 100644 --- a/.github/workflows/build_wheel_mac.yml +++ b/.github/workflows/build_wheel_mac.yml @@ -24,8 +24,8 @@ permissions: env: LLVM_REPO: https://github.com/vpto-dev/llvm-project.git - LLVM_REF: feature-vpto - LLVM_CACHE_FLAVOR: release-v2 + LLVM_REF: feature-vpto-llvm21 + LLVM_CACHE_FLAVOR: llvm21-vpto-release-v1 jobs: build_wheel: @@ -145,6 +145,9 @@ jobs: -DLLVM_ENABLE_ASSERTIONS=OFF \ -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ -DPython3_EXECUTABLE=$(which python) \ + -DPython_EXECUTABLE=$(which python) \ + -Dpybind11_DIR="$(python -m pybind11 --cmakedir)" \ + -Dnanobind_DIR="$(python -m nanobind --cmake_dir)" \ -DCMAKE_BUILD_TYPE=Release \ -DLLVM_TARGETS_TO_BUILD="host" ninja -C $LLVM_BUILD_DIR diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7d5dbd9b88..1e6696fe12 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -98,13 +98,13 @@ jobs: env: PTOAS_CLANG_MAJOR: "15" LLVM_REPO: https://github.com/vpto-dev/llvm-project.git - LLVM_REF: feature-vpto + LLVM_REF: feature-vpto-llvm21 LLVM_BUILD_DIR: ${{ github.workspace }}/llvm-project/llvm/build-assert LLVM_DIR: ${{ github.workspace }}/llvm-project/llvm/build-assert PTO_BUILD_DIR: ${{ github.workspace }}/build-assert PTO_INSTALL_DIR: ${{ github.workspace }}/install-assert MLIR_PYTHONPATH: ${{ github.workspace }}/llvm-project/llvm/build-assert/tools/mlir/python_packages/mlir_core - LLVM_CACHE_FLAVOR: assert-shared-mlirpy-hardening-v2 + LLVM_CACHE_FLAVOR: llvm21-assert-shared-mlirpy-hardening-v1 steps: - name: Checkout uses: actions/checkout@v4 @@ -114,6 +114,15 @@ jobs: - name: Install dependencies run: | + # GitHub-hosted runners may include Microsoft apt sources unrelated to + # PTOAS; these can intermittently return 403 and break apt-get update. + sudo rm -f \ + /etc/apt/sources.list.d/azure-cli*.list \ + /etc/apt/sources.list.d/azure-cli*.sources \ + /etc/apt/sources.list.d/microsoft-prod*.list \ + /etc/apt/sources.list.d/microsoft-prod*.sources \ + /etc/apt/sources.list.d/packages-microsoft-prod*.list \ + /etc/apt/sources.list.d/packages-microsoft-prod*.sources sudo apt-get update sudo apt-get install -y \ cmake git ninja-build \ @@ -122,7 +131,7 @@ jobs: libedit-dev zlib1g-dev libxml2-dev libzstd-dev python3 -m pip install --upgrade pip # LLVM/MLIR Python bindings are not yet compatible with pybind11 3.x. - python3 -m pip install setuptools wheel 'pybind11<3' numpy + python3 -m pip install setuptools wheel 'pybind11<3' nanobind numpy - name: Define payload paths shell: bash @@ -200,6 +209,9 @@ jobs: -DLLVM_ENABLE_ASSERTIONS=ON \ -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ -DPython3_EXECUTABLE=python3 \ + -DPython_EXECUTABLE=python3 \ + -Dpybind11_DIR="$(python3 -m pybind11 --cmakedir)" \ + -Dnanobind_DIR="$(python3 -m nanobind --cmake_dir)" \ -DCMAKE_BUILD_TYPE=Release \ -DCMAKE_C_COMPILER="${PTOAS_CMAKE_C_COMPILER}" \ -DCMAKE_CXX_COMPILER="${PTOAS_CMAKE_CXX_COMPILER}" \ diff --git a/.github/workflows/ci_sim.yml b/.github/workflows/ci_sim.yml index f8b361b273..20862b519e 100644 --- a/.github/workflows/ci_sim.yml +++ b/.github/workflows/ci_sim.yml @@ -98,7 +98,7 @@ jobs: vpto-sim-validation: needs: detect-vpto-sim-changes runs-on: [self-hosted, Linux, X64, label-1] - timeout-minutes: 120 + timeout-minutes: 300 concurrency: group: vpto-sim-${{ github.event.pull_request.number || github.ref }} cancel-in-progress: true @@ -108,7 +108,7 @@ jobs: }} env: LLVM_REPO: https://github.com/vpto-dev/llvm-project.git - LLVM_REF: feature-vpto + LLVM_REF: feature-vpto-llvm21 PTO_INSTALL_DIR: ${{ github.workspace }}/install VPTO_SIM_WORKSPACE: ${{ github.workspace }}/.work/vpto-sim-ci TILELANG_DSL_WORKSPACE: ${{ github.workspace }}/.work/tilelang-dsl-ci @@ -193,11 +193,12 @@ jobs: python3 -c "import numpy" >/dev/null 2>&1 || need_pip_install=1 python3 -c "import setuptools, wheel" >/dev/null 2>&1 || need_pip_install=1 python3 -m pybind11 --cmakedir >/dev/null 2>&1 || need_pip_install=1 + python3 -m nanobind --cmake_dir >/dev/null 2>&1 || need_pip_install=1 python3 -c "import ml_dtypes" >/dev/null 2>&1 || need_pip_install=1 if [[ "${need_pip_install}" -eq 1 ]]; then python3 -m pip install --upgrade pip - python3 -m pip install setuptools wheel 'pybind11<3' numpy ml-dtypes + python3 -m pip install setuptools wheel 'pybind11<3' nanobind numpy ml-dtypes fi - name: Clean CI work dirs @@ -251,6 +252,9 @@ jobs: -DLLVM_ENABLE_ASSERTIONS=ON \ -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ -DPython3_EXECUTABLE=python3 \ + -DPython_EXECUTABLE=python3 \ + -Dpybind11_DIR="$(python3 -m pybind11 --cmakedir)" \ + -Dnanobind_DIR="$(python3 -m nanobind --cmake_dir)" \ -DCMAKE_BUILD_TYPE=Release \ -DLLVM_TARGETS_TO_BUILD="host" @@ -312,6 +316,7 @@ jobs: ref: ${{ env.PYPTO_REF }} path: ${{ env.PYPTO_WORKSPACE }} fetch-depth: 1 + submodules: recursive persist-credentials: false - name: Checkout PTO-ISA @@ -381,6 +386,29 @@ jobs: # auto-detecting onboard platforms from the runner's CANN environment. env -u ASCEND_HOME_PATH python3 -m pip install --no-build-isolation --no-deps "${PYPTO_WORKSPACE}/runtime" + python3 - <<'PY' + from pathlib import Path + + from simpler_setup.environment import PROJECT_ROOT + from simpler_setup.kernel_compiler import KernelCompiler + + required = PROJECT_ROOT / "src" / "common" / "task_interface" / "arg_direction.h" + if not required.is_file(): + raise SystemExit(f"missing required simpler runtime header: {required}") + + include_dirs = KernelCompiler(platform="a5sim").get_orchestration_include_dirs( + "tensormap_and_ringbuffer" + ) + if str(required.parent) not in include_dirs: + raise SystemExit( + "simpler orchestration include dirs do not contain " + f"{required.parent}: {include_dirs}" + ) + + print(f"simpler PROJECT_ROOT={PROJECT_ROOT}") + print(f"simpler arg_direction.h={required}") + PY + - name: Run PyPTO PTOAS end-to-end simulator smoke shell: bash env: diff --git a/CMakeLists.txt b/CMakeLists.txt index bb59460e98..80670d891c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -18,6 +18,66 @@ if(POLICY CMP0169) set(CMAKE_POLICY_DEFAULT_CMP0169 OLD) endif() +function(ptoas_get_cmake_cache_value cache_file variable out_var) + if(NOT EXISTS "${cache_file}") + set(${out_var} "" PARENT_SCOPE) + return() + endif() + + file(STRINGS "${cache_file}" cache_lines + REGEX "^${variable}:[^=]*=") + if(cache_lines) + list(GET cache_lines 0 cache_line) + string(REGEX REPLACE "^[^=]*=" "" cache_value "${cache_line}") + set(${out_var} "${cache_value}" PARENT_SCOPE) + else() + set(${out_var} "" PARENT_SCOPE) + endif() +endfunction() + +function(ptoas_get_llvm_build_dir_from_llvm_dir out_var) + if(NOT DEFINED LLVM_DIR OR LLVM_DIR STREQUAL "") + set(${out_var} "" PARENT_SCOPE) + return() + endif() + + get_filename_component(llvm_cmake_dir "${LLVM_DIR}" REALPATH) + get_filename_component(llvm_build_dir "${llvm_cmake_dir}/../../.." REALPATH) + if(EXISTS "${llvm_build_dir}/CMakeCache.txt") + set(${out_var} "${llvm_build_dir}" PARENT_SCOPE) + else() + set(${out_var} "" PARENT_SCOPE) + endif() +endfunction() + +function(ptoas_preseed_compilers_from_llvm_cache) + ptoas_get_llvm_build_dir_from_llvm_dir(llvm_build_dir) + if(llvm_build_dir STREQUAL "") + return() + endif() + + set(llvm_cache "${llvm_build_dir}/CMakeCache.txt") + if(NOT DEFINED CMAKE_C_COMPILER AND "$ENV{CC}" STREQUAL "") + ptoas_get_cmake_cache_value("${llvm_cache}" "CMAKE_C_COMPILER" + llvm_c_compiler) + if(llvm_c_compiler AND EXISTS "${llvm_c_compiler}") + set(CMAKE_C_COMPILER "${llvm_c_compiler}" CACHE FILEPATH + "C compiler inherited from the configured LLVM build" FORCE) + endif() + endif() + + if(NOT DEFINED CMAKE_CXX_COMPILER AND "$ENV{CXX}" STREQUAL "") + ptoas_get_cmake_cache_value("${llvm_cache}" "CMAKE_CXX_COMPILER" + llvm_cxx_compiler) + if(llvm_cxx_compiler AND EXISTS "${llvm_cxx_compiler}") + set(CMAKE_CXX_COMPILER "${llvm_cxx_compiler}" CACHE FILEPATH + "C++ compiler inherited from the configured LLVM build" FORCE) + endif() + endif() +endfunction() + +ptoas_preseed_compilers_from_llvm_cache() + project(ptoas VERSION 0.49) option(PTOAS_VALIDATE_CMAKE4_FETCHCONTENT_COMPAT @@ -72,7 +132,7 @@ endif() message(STATUS "PTOAS CLI version: ${PTOAS_CLI_VERSION}") # ========================================================= -# [新增] 强制设置 C++17 标准 (LLVM 19 必需) +# [新增] 强制设置 C++17 标准 (LLVM/MLIR 必需) # ========================================================= set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) @@ -85,6 +145,41 @@ find_package(LLVM REQUIRED CONFIG) message(STATUS "Found LLVM ${LLVM_PACKAGE_VERSION}") message(STATUS "LLVM CMake Dir: ${LLVM_CMAKE_DIR}") message(STATUS "LLVM Include Dir: ${LLVM_INCLUDE_DIRS}") +if(NOT DEFINED LLVM_VERSION_MAJOR) + string(REGEX MATCH "^[0-9]+" LLVM_VERSION_MAJOR "${LLVM_PACKAGE_VERSION}") +endif() +if(NOT LLVM_VERSION_MAJOR STREQUAL "21") + message(FATAL_ERROR + "PTOAS requires LLVM 21 after the LLVM21 upgrade, but found " + "LLVM ${LLVM_PACKAGE_VERSION}. Please point LLVM_DIR/MLIR_DIR to an " + "LLVM 21 build, for example vpto-dev/llvm-project:feature-vpto-llvm21.") +endif() +get_filename_component(PTO_LLVM_BUILD_LIBRARY_DIR + "${LLVM_BUILD_LIBRARY_DIR}" REALPATH) +ptoas_get_llvm_build_dir_from_llvm_dir(PTO_LLVM_BUILD_DIR) +if(PTO_LLVM_BUILD_DIR) + ptoas_get_cmake_cache_value("${PTO_LLVM_BUILD_DIR}/CMakeCache.txt" + "CMAKE_CXX_COMPILER" + PTO_LLVM_CXX_COMPILER) +endif() +if(PTO_LLVM_CXX_COMPILER AND EXISTS "${PTO_LLVM_CXX_COMPILER}") + execute_process( + COMMAND "${PTO_LLVM_CXX_COMPILER}" --version + OUTPUT_VARIABLE PTO_LLVM_CXX_VERSION_TEXT + ERROR_VARIABLE PTO_LLVM_CXX_VERSION_TEXT + OUTPUT_STRIP_TRAILING_WHITESPACE + ERROR_STRIP_TRAILING_WHITESPACE) + if(PTO_LLVM_CXX_VERSION_TEXT MATCHES "[Cc]lang" + AND CMAKE_CXX_COMPILER_ID STREQUAL "GNU") + message(FATAL_ERROR + "PTOAS is being configured with GNU C++ (${CMAKE_CXX_COMPILER}), " + "but the selected LLVM build was compiled with a Clang-compatible " + "compiler (${PTO_LLVM_CXX_COMPILER}). MLIR TypeID fallback strings are " + "compiler-family sensitive across shared libraries; use " + "-DCMAKE_CXX_COMPILER=${PTO_LLVM_CXX_COMPILER} or leave the compiler " + "unset so PTOAS can inherit it from LLVM_DIR.") + endif() +endif() # 将 LLVM 模块路径加入 CMake 搜索路径 list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}") @@ -131,7 +226,7 @@ include_directories(${MLIR_INCLUDE_DIRS}) include_directories(${PROJECT_SOURCE_DIR}/include) include_directories(${PROJECT_BINARY_DIR}/include) -link_directories(${LLVM_BUILD_LIBRARY_DIR}) +link_directories(${PTO_LLVM_BUILD_LIBRARY_DIR}) add_definitions(${LLVM_DEFINITIONS}) # ========================================================= diff --git a/README.md b/README.md index 0d8399783f..3457d7eda6 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ ## 1. 项目简介 (Introduction) -**ptoas** (`ptoas`) 是一个基于 **LLVM/MLIR (llvmorg-19.1.7)***(Commit cd708029e0b2869e80abe31ddb175f7c35361f90)* 框架构建的专用编译器工具链,专为 **PTO Bytecode** (Programming Tiling Operator Bytecode) 设计。 +**ptoas** (`ptoas`) 是一个基于 **LLVM/MLIR LLVM21 VPTO 分支 (`vpto-dev/llvm-project:feature-vpto-llvm21`)** 框架构建的专用编译器工具链,专为 **PTO Bytecode** (Programming Tiling Operator Bytecode) 设计。 作为连接上层 AI 框架与底层各类NPU/GPGPU/CPU硬件,`ptoas` 采用 **Out-of-Tree** 架构构建,提供了完整的 C++ 与 Python 接口,主要职责包括: @@ -37,7 +37,7 @@ PTOAS/ ## 3. 构建指南 (Build Instructions) -⚠️ **重要提示**:本项目严格依赖 **LLVM llvmorg-19.1.7** 版本。 +⚠️ **重要提示**:本项目严格依赖 **LLVM21 VPTO 分支 `vpto-dev/llvm-project:feature-vpto-llvm21`**。 ### 3.0 环境变量配置 (Configuration) @@ -51,11 +51,11 @@ export WORKSPACE_DIR=$HOME/llvm-workspace # LLVM 源码与构建路径 export LLVM_SOURCE_DIR=$WORKSPACE_DIR/llvm-project -export LLVM_BUILD_DIR=$LLVM_SOURCE_DIR/build-shared +export LLVM_BUILD_DIR=$LLVM_SOURCE_DIR/build-shared-21 # PTOAS 源码与安装路径 export PTO_SOURCE_DIR=$WORKSPACE_DIR/PTOAS -export PTO_INSTALL_DIR=$PTO_SOURCE_DIR/install +export PTO_INSTALL_DIR=$PTO_SOURCE_DIR/install-llvm21 # ======================================================= # 创建工作目录 @@ -69,30 +69,31 @@ mkdir -p $WORKSPACE_DIR * **Compiler**: GCC >= 9 或 Clang (支持 C++17) * **Build System**: CMake >= 3.20, Ninja * **Python**: 3.8+ -* **Python Packages**: `pybind11`, `numpy` +* **Python Packages**: `pybind11<3`, `nanobind`, `numpy` ```bash -python3 -m pip install pybind11==2.12.0 numpy +python3 -m pip install 'pybind11<3' nanobind numpy ``` -> 说明:当前 LLVM/MLIR Python 绑定与 `pybind11` 3.x 不兼容。 +> 说明:当前 PTOAS Python 扩展继续使用 `pybind11`,LLVM21 的 MLIR Python 绑定构建需要 `nanobind`。 +> 当前 LLVM/MLIR Python 绑定与 `pybind11` 3.x 不兼容。 > 如果编译 LLVM 时遇到 `def_property family does not currently support keep_alive` 等报错, -> 请先执行上面的降级命令。 +> 请确认使用上面的 `pybind11<3` 依赖。 ### 3.2 第一步:构建 LLVM/MLIR (Dependency) -我们需要下载 LLVM 源码,切换到 `llvmorg-19.1.7` 标签,并以**动态库 (Shared Libs)** 模式编译,以确保 Python Binding 的正确链接。 +我们需要下载 VPTO 适配后的 LLVM 源码,切换到 `feature-vpto-llvm21` 分支,并以**动态库 (Shared Libs)** 模式编译,以确保 Python Binding 的正确链接。 ```bash # 1. 下载 LLVM 源码 cd $WORKSPACE_DIR -git clone https://github.com/llvm/llvm-project.git +git clone https://github.com/vpto-dev/llvm-project.git cd $LLVM_SOURCE_DIR -# 2. [关键] 切换到 llvmorg-19.1.7 -git checkout llvmorg-19.1.7 +# 2. [关键] 切换到 VPTO 适配分支 +git checkout feature-vpto-llvm21 # 3. 配置 CMake (构建动态库并启用 Python 绑定) cmake -G Ninja -S llvm -B $LLVM_BUILD_DIR \ @@ -100,6 +101,9 @@ cmake -G Ninja -S llvm -B $LLVM_BUILD_DIR \ -DBUILD_SHARED_LIBS=ON \ -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ -DPython3_EXECUTABLE=$(which python3) \ + -DPython_EXECUTABLE=$(which python3) \ + -Dpybind11_DIR=$(python3 -m pybind11 --cmakedir) \ + -Dnanobind_DIR=$(python3 -m nanobind --cmake_dir) \ -DCMAKE_BUILD_TYPE=Release \ -DLLVM_TARGETS_TO_BUILD="host" @@ -110,7 +114,7 @@ ninja -C $LLVM_BUILD_DIR ### 3.3 第二步:构建 PTOAS (Out-of-Tree) -下载 PTOAS 源码并基于刚刚编译好的 LLVM 19 进行构建。 +下载 PTOAS 源码并基于刚刚编译好的 LLVM 21 进行构建。 ```bash # 1. 下载 PTOAS 源码 @@ -125,7 +129,7 @@ export PYBIND11_CMAKE_DIR=$(python3 -m pybind11 --cmakedir) # 注意:此处直接使用了 3.0 章节中定义的变量,无需手动修改 cmake -G Ninja \ -S . \ - -B build \ + -B build-llvm21 \ -DLLVM_DIR=$LLVM_BUILD_DIR/lib/cmake/llvm \ -DMLIR_DIR=$LLVM_BUILD_DIR/lib/cmake/mlir \ -DPython3_EXECUTABLE=$(which python3) \ @@ -136,12 +140,12 @@ cmake -G Ninja \ -DCMAKE_INSTALL_PREFIX="$PTO_INSTALL_DIR" # 4. 编译并安装 -ninja -C build -ninja -C build install +ninja -C build-llvm21 +ninja -C build-llvm21 install # 5. 检查构建产物 # build 输出(便于本地开发/调试) -$PTO_SOURCE_DIR/build/python/ +$PTO_SOURCE_DIR/build-llvm21/python/ ├── mlir │ ├── _mlir_libs │ │ └── _pto.cpython-*.so @@ -159,8 +163,8 @@ $PTO_INSTALL_DIR/ └── _pto.cpython-*.so # CLI 工具 -$PTO_SOURCE_DIR/build/tools/ptoas/ptoas -$PTO_SOURCE_DIR/build/tools/ptobc/ptobc +$PTO_SOURCE_DIR/build-llvm21/tools/ptoas/ptoas +$PTO_SOURCE_DIR/build-llvm21/tools/ptobc/ptobc ``` @@ -183,7 +187,7 @@ export PYTHONPATH=$PTO_PYTHON_ROOT:$MLIR_PYTHON_ROOT:$PYTHONPATH export LD_LIBRARY_PATH=$LLVM_BUILD_DIR/lib:$PTO_INSTALL_DIR/lib:$LD_LIBRARY_PATH # 3. PATH: 将 ptoas / ptobc 添加到命令行路径 -export PATH=$PTO_SOURCE_DIR/build/tools/ptoas:$PTO_SOURCE_DIR/build/tools/ptobc:$PATH +export PATH=$PTO_SOURCE_DIR/build-llvm21/tools/ptoas:$PTO_SOURCE_DIR/build-llvm21/tools/ptobc:$PATH ``` diff --git a/README_en.md b/README_en.md index b7a060a0fd..2832349715 100644 --- a/README_en.md +++ b/README_en.md @@ -2,7 +2,7 @@ ## 1. Introduction -**ptoas** is a specialized compiler toolchain built on top of **LLVM/MLIR (llvmorg-19.1.7)** *(Commit cd708029e0b2869e80abe31ddb175f7c35361f90)*, designed specifically for **PTO Bytecode** (Programming Tiling Operator Bytecode). +**ptoas** is a specialized compiler toolchain built on top of the **LLVM21 VPTO branch (`vpto-dev/llvm-project:feature-vpto-llvm21`)**, designed specifically for **PTO Bytecode** (Programming Tiling Operator Bytecode). Acting as the bridge between upper-level AI frameworks and underlying NPU/GPGPU/CPU hardware, `ptoas` is built in an **Out-of-Tree** architecture and provides complete C++ and Python interfaces. Its primary responsibilities include: @@ -36,7 +36,7 @@ PTOAS/ ## 3. Build Instructions -⚠️ **Important**: This project strictly requires **LLVM llvmorg-19.1.7**. +⚠️ **Important**: This project strictly requires the **LLVM21 VPTO branch `vpto-dev/llvm-project:feature-vpto-llvm21`**. ### 3.0 Environment Variable Configuration @@ -67,10 +67,10 @@ mkdir -p $WORKSPACE_DIR * **Compiler**: GCC >= 9 or Clang (C++17 support required) * **Build System**: CMake >= 3.20, Ninja * **Python**: 3.8+ -* **Python Packages**: `pybind11`, `numpy` +* **Python Packages**: `pybind11<3`, `nanobind`, `numpy` ```bash -python3 -m pip install pybind11==2.12.0 numpy +python3 -m pip install "pybind11<3" nanobind numpy ``` > **Note**: The current LLVM/MLIR Python bindings are not compatible with `pybind11` 3.x. @@ -79,16 +79,16 @@ python3 -m pip install pybind11==2.12.0 numpy ### 3.2 Step 1: Build LLVM/MLIR (Dependency) -Download the LLVM source, check out the `llvmorg-19.1.7` tag, and build with **shared libraries** to ensure correct linking for Python bindings. +Download the VPTO-adapted LLVM source, check out the `feature-vpto-llvm21` branch, and build with **shared libraries** to ensure correct linking for Python bindings. ```bash # 1. Clone LLVM cd $WORKSPACE_DIR -git clone https://github.com/llvm/llvm-project.git +git clone https://github.com/vpto-dev/llvm-project.git cd $LLVM_SOURCE_DIR -# 2. [Critical] Check out llvmorg-19.1.7 -git checkout llvmorg-19.1.7 +# 2. [Critical] Check out the VPTO adaptation branch +git checkout feature-vpto-llvm21 # 3. Configure CMake (build shared libs with Python bindings enabled) cmake -G Ninja -S llvm -B $LLVM_BUILD_DIR \ @@ -105,7 +105,7 @@ ninja -C $LLVM_BUILD_DIR ### 3.3 Step 2: Build PTOAS (Out-of-Tree) -Clone the PTOAS source and build against the LLVM 19 you just compiled. +Clone the PTOAS source and build against the LLVM 21 you just compiled. ```bash # 1. Clone PTOAS diff --git a/ReleaseNotes.md b/ReleaseNotes.md index f94982946c..3beaf43603 100644 --- a/ReleaseNotes.md +++ b/ReleaseNotes.md @@ -8,7 +8,7 @@ - PTOAS 首次发布 ## 概述 -PTOAS(PTO Assembler & Optimizer)是面向 PTO Bytecode 的编译器工具链,基于 LLVM/MLIR release/19.x 构建。它提供 PTO Dialect 的定义、解析、验证、优化与代码生成能力,并输出可调用 `pto-isa` 的 C++ 代码。 +PTOAS(PTO Assembler & Optimizer)是面向 PTO Bytecode 的编译器工具链,基于 LLVM/MLIR LLVM21 VPTO 分支 `vpto-dev/llvm-project:feature-vpto-llvm21` 构建。它提供 PTO Dialect 的定义、解析、验证、优化与代码生成能力,并输出可调用 `pto-isa` 的 C++ 代码。 PTOAS很快将集成到以下框架中,敬请期待 - PyPTO diff --git a/docker/Dockerfile b/docker/Dockerfile index 553ec9d791..c1979b563b 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -10,7 +10,8 @@ ARG ARCH # NOTE: change $PY_VER for different Python versions (3.8 - 3.14 available) ARG PY_VER=cp311-cp311 -ARG LLVM_TAG=llvmorg-19.1.7 +ARG LLVM_REF=feature-vpto-llvm21 +ARG LLVM_REPO=https://github.com/vpto-dev/llvm-project.git ## -- usually no need to change below -- @@ -19,7 +20,7 @@ ENV PATH="${PY_PATH}/bin:${PATH}" # dependency RUN dnf install -y ninja-build cmake git ccache gcc-c++ lld zip binutils patchelf chrpath && dnf clean all -RUN pip install --no-cache-dir numpy pybind11 nanobind setuptools wheel auditwheel +RUN pip install --no-cache-dir numpy 'pybind11<3' nanobind setuptools wheel auditwheel COPY cmake/LinuxHardeningCache.cmake /tmp/LinuxHardeningCache.cmake @@ -34,7 +35,7 @@ ENV PTO_INSTALL_DIR=$PTO_SOURCE_DIR/install-release # build LLVM WORKDIR $WORKSPACE_DIR -RUN git clone --depth 1 --branch ${LLVM_TAG} https://github.com/llvm/llvm-project.git +RUN git clone --depth 1 --branch ${LLVM_REF} ${LLVM_REPO} llvm-project WORKDIR $LLVM_SOURCE_DIR RUN cmake -C /tmp/LinuxHardeningCache.cmake -G Ninja -S llvm -B $LLVM_BUILD_DIR \ @@ -43,6 +44,9 @@ RUN cmake -C /tmp/LinuxHardeningCache.cmake -G Ninja -S llvm -B $LLVM_BUILD_DIR -DLLVM_ENABLE_ASSERTIONS=OFF \ -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ -DPython3_EXECUTABLE=${PY_PATH}/bin/python \ + -DPython_EXECUTABLE=${PY_PATH}/bin/python \ + -Dpybind11_DIR=$(${PY_PATH}/bin/python -m pybind11 --cmakedir) \ + -Dnanobind_DIR=$(${PY_PATH}/bin/python -m nanobind --cmake_dir) \ -DCMAKE_BUILD_TYPE=Release \ -DLLVM_TARGETS_TO_BUILD="host" diff --git a/docs/build_with_installed_llvm.md b/docs/build_with_installed_llvm.md index 6c7495a5ca..7b09173b27 100644 --- a/docs/build_with_installed_llvm.md +++ b/docs/build_with_installed_llvm.md @@ -2,7 +2,7 @@ 本文档按 [README.md](../README.md) 第 3 章的逻辑整理,适用于: -- LLVM/MLIR `19.1.7` 已经构建并安装完成。 +- LLVM/MLIR LLVM21 VPTO 分支 `vpto-dev/llvm-project:feature-vpto-llvm21` 已经构建并安装完成。 - LLVM 安装路径固定为 `/opt/llvm`。 - `/opt/llvm` 是共享目录,不希望 `ptoas` 的安装步骤写入其中。 @@ -42,11 +42,12 @@ mkdir -p "$WORKSPACE_DIR" - CMake >= 3.20 - Ninja - Python 3.8+ -- `pybind11` +- `pybind11<3` +- `nanobind` - `numpy` ```bash -pip3 install pybind11 numpy +pip3 install "pybind11<3" nanobind numpy ``` ## 跳过 3.2 @@ -62,9 +63,11 @@ README 第 3.2 节是 LLVM/MLIR 的下载和编译步骤。当前场景下 LLVM 输出为: ```text -19.1.7 +21.1.8 ``` +实际源码基线应来自 `https://github.com/vpto-dev/llvm-project.git` 的 `feature-vpto-llvm21` 分支。 + ## 3.3 第二步:构建 ptoas 这里沿用 README 第 3.3 节的流程,但 `LLVM_DIR` 和 `MLIR_DIR` 需要改为 diff --git a/docs/designs/ci-board-validation-guide.md b/docs/designs/ci-board-validation-guide.md index b87feefd9f..bb963841d9 100644 --- a/docs/designs/ci-board-validation-guide.md +++ b/docs/designs/ci-board-validation-guide.md @@ -81,9 +81,9 @@ ### 3.1 构建 LLVM/MLIR ```bash -git clone https://github.com/llvm/llvm-project.git +git clone https://github.com/vpto-dev/llvm-project.git cd llvm-project -git checkout llvmorg-19.1.7 +git checkout feature-vpto-llvm21 cmake -G Ninja -S llvm -B llvm/build-shared \ -DLLVM_ENABLE_PROJECTS="mlir;clang" \ diff --git a/include/PTO/IR/PTOTypeUtils.h b/include/PTO/IR/PTOTypeUtils.h index f7db223fbc..f060c03912 100644 --- a/include/PTO/IR/PTOTypeUtils.h +++ b/include/PTO/IR/PTOTypeUtils.h @@ -9,12 +9,36 @@ #ifndef PTO_IR_PTOTYPEUTILS_H #define PTO_IR_PTOTYPEUTILS_H +#include "llvm/ADT/SmallVector.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Operation.h" +#include "mlir/Support/LLVM.h" namespace mlir::pto { +namespace detail { +template +inline auto getPTOMemRefStridesAndOffsetImpl( + MemRefT memTy, SmallVectorImpl &strides, int64_t &offset, int) + -> decltype(memTy.getStridesAndOffset(strides, offset)) { + return memTy.getStridesAndOffset(strides, offset); +} + +template +inline LogicalResult getPTOMemRefStridesAndOffsetImpl( + MemRefT memTy, SmallVectorImpl &strides, int64_t &offset, long) { + return getStridesAndOffset(memTy, strides, offset); +} +} // namespace detail + +inline LogicalResult getPTOMemRefStridesAndOffset( + MemRefType memTy, SmallVectorImpl &strides, int64_t &offset) { + return detail::getPTOMemRefStridesAndOffsetImpl(memTy, strides, offset, 0); +} + bool isPTOFloat8Type(Type t); +bool isPTOFloat8E4M3LikeType(Type t); +bool isPTOFloat8E5M2LikeType(Type t); bool isPTOHiFloat8Type(Type t); bool isPTOF8E8M0Type(Type t); bool isPTOHiFloat8x2Type(Type t); diff --git a/include/PTO/IR/VPTOOps.td b/include/PTO/IR/VPTOOps.td index 253e36a22a..2cd71fed22 100644 --- a/include/PTO/IR/VPTOOps.td +++ b/include/PTO/IR/VPTOOps.td @@ -72,13 +72,13 @@ def PTO_V2F8E4M3FNType : Type< CPred<"::llvm::isa<::mlir::VectorType>($_self) && " "::llvm::cast<::mlir::VectorType>($_self).getRank() == 1 && " "::llvm::cast<::mlir::VectorType>($_self).getDimSize(0) == 2 && " - "::llvm::cast<::mlir::VectorType>($_self).getElementType().isFloat8E4M3FN()">, + "::llvm::isa<::mlir::Float8E4M3FNType>(::llvm::cast<::mlir::VectorType>($_self).getElementType())">, "vector<2xf8E4M3FN>">; def PTO_V2F8E5M2Type : Type< CPred<"::llvm::isa<::mlir::VectorType>($_self) && " "::llvm::cast<::mlir::VectorType>($_self).getRank() == 1 && " "::llvm::cast<::mlir::VectorType>($_self).getDimSize(0) == 2 && " - "::llvm::cast<::mlir::VectorType>($_self).getElementType().isFloat8E5M2()">, + "::llvm::isa<::mlir::Float8E5M2Type>(::llvm::cast<::mlir::VectorType>($_self).getElementType())">, "vector<2xf8E5M2>">; def PTO_V2HiF8Type : Type< CPred<"::llvm::isa<::mlir::pto::HiF8x2Type>($_self)">, diff --git a/lib/Bindings/Python/CMakeLists.txt b/lib/Bindings/Python/CMakeLists.txt index c87bdb30c5..4e484668bd 100644 --- a/lib/Bindings/Python/CMakeLists.txt +++ b/lib/Bindings/Python/CMakeLists.txt @@ -19,6 +19,9 @@ find_package(pybind11 CONFIG REQUIRED) include(TableGen) include(AddMLIR) +get_filename_component(PTO_LLVM_BUILD_LIBRARY_DIR + "${LLVM_BUILD_LIBRARY_DIR}" REALPATH) + # ---- 1) Python native extension: mlir._mlir_libs._pto ---- pybind11_add_module(_pto MODULE PTOModule.cpp @@ -54,8 +57,8 @@ target_link_libraries(_pto PRIVATE # 关键:放到 mlir/_mlir_libs 下(匹配 MLIR dialect python 的 import 习惯) set_target_properties(_pto PROPERTIES LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/python/mlir/_mlir_libs" - BUILD_RPATH "${LLVM_BUILD_LIBRARY_DIR}" - INSTALL_RPATH "${LLVM_BUILD_LIBRARY_DIR}" + BUILD_RPATH "${PTO_LLVM_BUILD_LIBRARY_DIR}" + INSTALL_RPATH "${PTO_LLVM_BUILD_LIBRARY_DIR}" ) # macOS: 避免 ld 警告 "-undefined suppress is deprecated",仅保留 -undefined dynamic_lookup if(APPLE) diff --git a/lib/PTO/IR/PTO.cpp b/lib/PTO/IR/PTO.cpp index 3b50fff4c7..c63e216e24 100644 --- a/lib/PTO/IR/PTO.cpp +++ b/lib/PTO/IR/PTO.cpp @@ -3776,7 +3776,7 @@ static LogicalResult verifyAsyncFlatContiguous1DGMMemRef(Operation *op, SmallVector strides; int64_t offset = 0; - if (failed(getStridesAndOffset(memTy, strides, offset))) + if (failed(getPTOMemRefStridesAndOffset(memTy, strides, offset))) return op->emitOpError() << "expects " << name << " to be a strided memref with a known layout"; @@ -3997,10 +3997,7 @@ static bool isSupportedMGatherMScatterPayloadElemType(Operation *op, Type ty) { return true; if (!isTargetArchA5(op)) return false; - if (isPTOHiFloat8Type(ty)) - return true; - return ty.isFloat8E4M3() || ty.isFloat8E4M3FN() || ty.isFloat8E4M3FNUZ() || - ty.isFloat8E4M3B11FNUZ() || ty.isFloat8E5M2() || ty.isFloat8E5M2FNUZ(); + return isPTOHiFloat8Type(ty) || isPTOFloat8Type(ty); } static bool isSupportedMScatterAtomicPayloadElemType(Type ty, @@ -4277,8 +4274,7 @@ static bool isA5AccStorePreQuantDstType(Type srcElem, Type dstElem) { return false; return dstElem.isInteger(8) || dstElem.isF16() || dstElem.isBF16() || dstElem.isF32() || isPTOHiFloat8Type(dstElem) || - dstElem.isFloat8E4M3() || dstElem.isFloat8E4M3FN() || - dstElem.isFloat8E4M3FNUZ() || dstElem.isFloat8E4M3B11FNUZ(); + isPTOFloat8E4M3LikeType(dstElem); } static bool isA5LowPrecisionTCvtPair(Type srcElem, Type dstElem) { @@ -5202,11 +5198,7 @@ static LogicalResult verifyMatmulTypeTriple(Operation *op, Type lhsElemTy, return success(); auto isA5TMatmulFp8Type = [](Type ty) { - if (auto ft = mlir::dyn_cast(ty)) - return ft.isFloat8E4M3() || ft.isFloat8E4M3FN() || - ft.isFloat8E4M3FNUZ() || ft.isFloat8E4M3B11FNUZ() || - ft.isFloat8E5M2() || ft.isFloat8E5M2FNUZ(); - return false; + return isPTOFloat8Type(ty); }; if (isA5 && dstElemTy.isF32()) { if (isA5TMatmulFp8Type(lhsElemTy) && isA5TMatmulFp8Type(rhsElemTy)) @@ -6983,9 +6975,7 @@ static bool isA5Fp8LikeType(Type ty) { } static bool isA5MxFp8InputType(Type ty) { - if (auto ft = dyn_cast(ty)) - return ft.isFloat8E4M3FN() || ft.isFloat8E5M2(); - return false; + return ty && isa(ty); } static bool isA5MxInputTypePair(Type lhsTy, Type rhsTy) { @@ -13445,7 +13435,7 @@ mlir::LogicalResult mlir::pto::SimdTileToMemrefOp::verify() { SmallVector memStrides; int64_t memOffset = ShapedType::kDynamic; - if (failed(getStridesAndOffset(memTy, memStrides, memOffset))) + if (failed(getPTOMemRefStridesAndOffset(memTy, memStrides, memOffset))) return emitOpError("expects memref to use strided layout"); if (memOffset != 0) return emitOpError("expects memref offset to be 0"); @@ -14710,8 +14700,7 @@ static void printFrontendInitializePipeOp(InitOpT op, OpAsmPrinter &p) { needsComma = true; }; - if (op.getId() != 0) - printClause("id", op.getId()); + printClause("id", op.getId()); printClause("dir_mask", static_cast(op.getDirMask())); printClause("slot_size", op.getSlotSize()); if (auto slotNumAttr = op.getSlotNumAttr()) diff --git a/lib/PTO/IR/PTOTypeUtils.cpp b/lib/PTO/IR/PTOTypeUtils.cpp index 610248b954..3591c05986 100644 --- a/lib/PTO/IR/PTOTypeUtils.cpp +++ b/lib/PTO/IR/PTOTypeUtils.cpp @@ -18,8 +18,16 @@ constexpr unsigned kBitsPerByte = 8; } // namespace bool mlir::pto::isPTOFloat8Type(Type t) { - return t.isFloat8E4M3() || t.isFloat8E4M3FN() || t.isFloat8E4M3FNUZ() || - t.isFloat8E4M3B11FNUZ() || t.isFloat8E5M2() || t.isFloat8E5M2FNUZ(); + return isPTOFloat8E4M3LikeType(t) || isPTOFloat8E5M2LikeType(t); +} + +bool mlir::pto::isPTOFloat8E4M3LikeType(Type t) { + return isa(t); +} + +bool mlir::pto::isPTOFloat8E5M2LikeType(Type t) { + return isa(t); } bool mlir::pto::isPTOHiFloat8Type(Type t) { return isa(t); } diff --git a/lib/PTO/IR/VPTO.cpp b/lib/PTO/IR/VPTO.cpp index 6c30ccca71..908c604a16 100644 --- a/lib/PTO/IR/VPTO.cpp +++ b/lib/PTO/IR/VPTO.cpp @@ -1420,10 +1420,9 @@ static VcvtElemKind classifyVcvtElemType(Type type) { return VcvtElemKind::BF16; if (type.isF32()) return VcvtElemKind::F32; - if (type.isFloat8E4M3() || type.isFloat8E4M3FN() || - type.isFloat8E4M3FNUZ() || type.isFloat8E4M3B11FNUZ()) + if (pto::isPTOFloat8E4M3LikeType(type)) return VcvtElemKind::F8E4M3; - if (type.isFloat8E5M2() || type.isFloat8E5M2FNUZ()) + if (pto::isPTOFloat8E5M2LikeType(type)) return VcvtElemKind::F8E5M2; if (pto::isPTOHiFloat8Type(type)) return VcvtElemKind::HiF8; diff --git a/lib/PTO/Transforms/BufferizableOpInterfaceImpl.cpp b/lib/PTO/Transforms/BufferizableOpInterfaceImpl.cpp index 21d4b98305..f58a15a376 100644 --- a/lib/PTO/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/lib/PTO/Transforms/BufferizableOpInterfaceImpl.cpp @@ -25,7 +25,7 @@ namespace { static LogicalResult bufferizeDestinationStyleOpInterface( RewriterBase &rewriter, DestinationStyleOpInterface op, - const BufferizationOptions &options, + const BufferizationOptions &options, const BufferizationState &state, bool supportMixedTensorBufferMode = true); template { LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options) const { + const BufferizationOptions &options, + const BufferizationState &state) const { return bufferizeDestinationStyleOpInterface( - rewriter, cast(op), options, + rewriter, cast(op), options, state, supportMixedTensorBufferMode); } }; @@ -59,7 +60,8 @@ struct PTOReadWriteDpsOpInterfaceBase /// Generic conversion for any DestinationStyleOpInterface on tensors. static LogicalResult bufferizeDestinationStyleOpInterface( RewriterBase &rewriter, DestinationStyleOpInterface op, - const BufferizationOptions &options, bool supportMixedTensorBufferMode) { + const BufferizationOptions &options, const BufferizationState &state, + bool supportMixedTensorBufferMode) { // Take a guard before anything else. OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(op); @@ -83,7 +85,8 @@ static LogicalResult bufferizeDestinationStyleOpInterface( newOperands.push_back(opOperand.get()); continue; } - FailureOr buffer = getBuffer(rewriter, opOperand.get(), options); + FailureOr buffer = + getBuffer(rewriter, opOperand.get(), options, state); if (failed(buffer)) { return failure(); } @@ -95,7 +98,7 @@ static LogicalResult bufferizeDestinationStyleOpInterface( for (OpResult opResult : op->getOpResults()) { OpOperand *opOperand = op.getDpsInitOperand(opResult.getResultNumber()); FailureOr resultBuffer = - getBuffer(rewriter, opOperand->get(), options); + getBuffer(rewriter, opOperand->get(), options, state); if (failed(resultBuffer)) { return failure(); } @@ -120,13 +123,15 @@ struct PTOStoreOpInterface : public DstBufferizableOpInterfaceExternalModel { LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options) const { + const BufferizationOptions &options, + const BufferizationState &state) const { auto dpsOp = cast(op); if (dpsOp.hasPureBufferSemantics()) { return success(); } if (dpsOp.hasPureTensorSemantics()) { - return bufferizeDestinationStyleOpInterface(rewriter, dpsOp, options); + return bufferizeDestinationStyleOpInterface(rewriter, dpsOp, options, + state); } // We only handle the case where fixpipe op's input is a tensor from // mmad and fixpipe op's output is a memref type. @@ -141,7 +146,7 @@ struct PTOStoreOpInterface OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(op); - FailureOr buffer = getBuffer(rewriter, srcOp->get(), options); + FailureOr buffer = getBuffer(rewriter, srcOp->get(), options, state); if (failed(buffer)) { return failure(); } diff --git a/lib/PTO/Transforms/ConvertToPTOOp.cpp b/lib/PTO/Transforms/ConvertToPTOOp.cpp index ca3c18ce2a..7e72901f56 100644 --- a/lib/PTO/Transforms/ConvertToPTOOp.cpp +++ b/lib/PTO/Transforms/ConvertToPTOOp.cpp @@ -53,11 +53,11 @@ std::optional getLeftPadNum(PatternRewriter &rewriter, if (auto subviewOp = llvm::dyn_cast(user)) { auto offsets = subviewOp.getMixedOffsets(); auto offset = offsets[offsets.size() - 1]; - Value offsetValue = - offset.is() - ? dyn_cast(offset) - : rewriter.create( - subviewOp->getLoc(), getConstantIntValue(offset).value()); + Value offsetValue = dyn_cast(offset); + if (!offsetValue) { + offsetValue = rewriter.create( + subviewOp->getLoc(), getConstantIntValue(offset).value()); + } return offsetValue; } } @@ -199,7 +199,7 @@ void ConvertToPTOOpPass::runOnOperation() { moduleOp->walk([&](func::FuncOp funcOp) { RewritePatternSet patterns(ctx); populatePTOOpRewritingRule(patterns); - (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); + (void)applyPatternsGreedily(funcOp, std::move(patterns)); }); } diff --git a/lib/PTO/Transforms/ExpandTileOp.cpp b/lib/PTO/Transforms/ExpandTileOp.cpp index 91769e6298..087d8d62c6 100644 --- a/lib/PTO/Transforms/ExpandTileOp.cpp +++ b/lib/PTO/Transforms/ExpandTileOp.cpp @@ -26,6 +26,7 @@ // #include "PTO/IR/PTO.h" +#include "PTO/IR/PTOTypeUtils.h" #include "PTO/Transforms/Passes.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -199,8 +200,8 @@ static std::string getDtypeString(Type elemTy) { if (elemTy.isF32()) return "f32"; if (elemTy.isF16()) return "f16"; if (elemTy.isBF16()) return "bf16"; - if (elemTy.isFloat8E4M3FN()) return "f8e4m3"; - if (elemTy.isFloat8E5M2()) return "f8e5m2"; + if (isa(elemTy)) return "f8e4m3"; + if (isa(elemTy)) return "f8e5m2"; if (isa(elemTy)) return "hif8"; if (isa(elemTy)) return "f4e1m2x2"; if (isa(elemTy)) return "f4e2m1x2"; @@ -521,12 +522,13 @@ static bool getStaticIntFromValue(Value value, int64_t &out) { } static int64_t getStaticIntOrDynamic(OpFoldResult ofr) { - if (auto attr = ofr.dyn_cast()) { + if (isa(ofr)) { + Attribute attr = cast(ofr); if (auto intAttr = dyn_cast(attr)) return intAttr.getInt(); return ShapedType::kDynamic; } - auto value = llvm::cast(ofr); + Value value = cast(ofr); int64_t result = ShapedType::kDynamic; if (getStaticIntFromValue(value, result)) return result; @@ -596,7 +598,8 @@ static void populateViewShapeAndStrides(Value value, shape.assign(memrefTy.getShape().begin(), memrefTy.getShape().end()); if (strides.empty()) { int64_t offset = ShapedType::kDynamic; - if (succeeded(getStridesAndOffset(memrefTy, strides, offset))) { + if (succeeded( + mlir::pto::getPTOMemRefStridesAndOffset(memrefTy, strides, offset))) { // strides populated — dynamic dims remain ShapedType::kDynamic. } } @@ -644,7 +647,8 @@ static std::optional buildOperandTypeInfo(Value value) { info.viewShape.assign(mrTy.getShape().begin(), mrTy.getShape().end()); if (info.viewStrides.empty()) { int64_t offset = ShapedType::kDynamic; - if (succeeded(getStridesAndOffset(mrTy, info.viewStrides, offset))) { + if (succeeded(mlir::pto::getPTOMemRefStridesAndOffset( + mrTy, info.viewStrides, offset))) { // strides populated — dynamic dims remain ShapedType::kDynamic. } } diff --git a/lib/PTO/Transforms/GraphSyncSolver/EventIdSolver.cpp b/lib/PTO/Transforms/GraphSyncSolver/EventIdSolver.cpp index ba12ff60b7..a21ebfd316 100644 --- a/lib/PTO/Transforms/GraphSyncSolver/EventIdSolver.cpp +++ b/lib/PTO/Transforms/GraphSyncSolver/EventIdSolver.cpp @@ -32,7 +32,7 @@ int64_t EventIdSolver::getEventIdsNum(bool dontCalcEventIds) { calcEventIds(); } assert(!needRecalculateEventIds); - llvm::SmallDenseSet usedEventIds; + llvm::DenseSet usedEventIds; for (auto &node : nodes) { auto &eventIds = node->getEventIds(); assert(!eventIds.empty()); @@ -215,7 +215,7 @@ void EventIdSolver::addConflicts( llvm::SmallVector EventIdSolver::getAdjNodesUsedEventIds(EventIdNode *node) { - llvm::SmallDenseSet usedEventIds; + llvm::DenseSet usedEventIds; for (auto [otherNode, frq] : adjList[node]) { assert(frq > 0); auto &otherEventIds = otherNode->getEventIds(); diff --git a/lib/PTO/Transforms/InferPTOLayout.cpp b/lib/PTO/Transforms/InferPTOLayout.cpp index e6272cc46c..572995a870 100644 --- a/lib/PTO/Transforms/InferPTOLayout.cpp +++ b/lib/PTO/Transforms/InferPTOLayout.cpp @@ -64,12 +64,13 @@ static std::optional getConstInt(Value v) { } static std::optional getConstInt(OpFoldResult ofr) { - if (auto attr = ofr.dyn_cast()) { + if (isa(ofr)) { + Attribute attr = cast(ofr); if (auto ia = dyn_cast(attr)) return ia.getInt(); return std::nullopt; } - return getConstInt(ofr.get()); + return getConstInt(cast(ofr)); } static unsigned elemByteSize(Type ty) { @@ -287,7 +288,8 @@ static std::optional inferFromStaticMemRefTy(MemRefType mrTy) { return std::nullopt; SmallVector strideInts; int64_t offset = ShapedType::kDynamic; - if (failed(getStridesAndOffset(mrTy, strideInts, offset))) + if (failed( + mlir::pto::getPTOMemRefStridesAndOffset(mrTy, strideInts, offset))) return std::nullopt; if (offset == ShapedType::kDynamic || llvm::any_of(strideInts, @@ -632,7 +634,8 @@ struct InferPTOLayoutPass SmallVector strideInts; int64_t offset = ShapedType::kDynamic; - if (failed(getStridesAndOffset(srcTy, strideInts, offset)) || + if (failed(mlir::pto::getPTOMemRefStridesAndOffset(srcTy, strideInts, + offset)) || offset == ShapedType::kDynamic || llvm::any_of(strideInts, [](int64_t s) { return s == ShapedType::kDynamic; })) { diff --git a/lib/PTO/Transforms/InsertSync/InsertSyncAnalysis.cpp b/lib/PTO/Transforms/InsertSync/InsertSyncAnalysis.cpp index 0fbc35b66f..2fd3a44526 100644 --- a/lib/PTO/Transforms/InsertSync/InsertSyncAnalysis.cpp +++ b/lib/PTO/Transforms/InsertSync/InsertSyncAnalysis.cpp @@ -115,7 +115,8 @@ static std::optional getKnownBLayout(Type ty) { if (auto memRefTy = dyn_cast(ty)) { SmallVector strides; int64_t offset = 0; - if (failed(getStridesAndOffset(memRefTy, strides, offset)) || + if (failed(mlir::pto::getPTOMemRefStridesAndOffset(memRefTy, strides, + offset)) || strides.size() != 2) { return std::nullopt; } diff --git a/lib/PTO/Transforms/InsertSync/PTOIRTranslator.cpp b/lib/PTO/Transforms/InsertSync/PTOIRTranslator.cpp index b28d6c2412..cdf87a29b7 100644 --- a/lib/PTO/Transforms/InsertSync/PTOIRTranslator.cpp +++ b/lib/PTO/Transforms/InsertSync/PTOIRTranslator.cpp @@ -153,7 +153,8 @@ getMemrefSubViewBaseAddresses(memref::SubViewOp op, MemRefType sourceType, SmallVector strides; int64_t baseOffset = ShapedType::kDynamic; - if (failed(mlir::getStridesAndOffset(sourceType, strides, baseOffset)) || + if (failed(mlir::pto::getPTOMemRefStridesAndOffset(sourceType, strides, + baseOffset)) || strides.size() != 2 || llvm::is_contained(strides, ShapedType::kDynamic)) return std::nullopt; @@ -257,7 +258,8 @@ static std::pair getStaticOffsetAndSize(Operation *op, Value s if (auto subView = dyn_cast(op)) { int64_t baseOffset; StrideVector strides; - if (failed(mlir::getStridesAndOffset(srcType, strides, baseOffset))) { + if (failed(mlir::pto::getPTOMemRefStridesAndOffset(srcType, strides, + baseOffset))) { return {-1, -1}; } @@ -1028,7 +1030,8 @@ void PTOIRTranslator::UpdateMemrefSubViewAliasBufferInfo(memref::SubViewOp op) { SmallVector strides; int64_t baseOffset = ShapedType::kDynamic; - if (failed(mlir::getStridesAndOffset(sourceType, strides, baseOffset)) || + if (failed(mlir::pto::getPTOMemRefStridesAndOffset(sourceType, strides, + baseOffset)) || strides.size() != 2) { UpdateConservativeAliasBufferInfo(result, source); return; diff --git a/lib/PTO/Transforms/LoweringSyncToPipe.cpp b/lib/PTO/Transforms/LoweringSyncToPipe.cpp index bc2672a0ba..15d02d7a98 100644 --- a/lib/PTO/Transforms/LoweringSyncToPipe.cpp +++ b/lib/PTO/Transforms/LoweringSyncToPipe.cpp @@ -138,7 +138,7 @@ struct LoweringSyncToPipe patterns.add( context); - if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) + if (failed(applyPatternsGreedily(func, std::move(patterns)))) signalPassFailure(); } }; diff --git a/lib/PTO/Transforms/PTOMaterializeTileHandles.cpp b/lib/PTO/Transforms/PTOMaterializeTileHandles.cpp index 64220acd97..10cf9da2bd 100644 --- a/lib/PTO/Transforms/PTOMaterializeTileHandles.cpp +++ b/lib/PTO/Transforms/PTOMaterializeTileHandles.cpp @@ -24,6 +24,7 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/IRMapping.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" @@ -53,6 +54,11 @@ namespace { static constexpr llvm::StringLiteral kForceDynamicValidShapeAttrName = "__pto.force_dynamic_valid_shape"; +static int64_t getIntegerAttrSignedValue(IntegerAttr attr) { + const APInt &value = attr.getValue(); + return value.getBitWidth() == 0 ? 0 : value.getSExtValue(); +} + struct TileHandleMetadata { Value source; Value validRow; @@ -269,13 +275,13 @@ static bool getTilePointerStrides(TileBufConfigAttr configAttr, Type elemTy, if (auto blAttr = dyn_cast(configAttr.getBLayout())) blVal = static_cast(blAttr.getValue()); else if (auto intAttr = dyn_cast(configAttr.getBLayout())) - blVal = static_cast(intAttr.getInt()); + blVal = static_cast(getIntegerAttrSignedValue(intAttr)); int32_t slVal = 0; if (auto slAttr = dyn_cast(configAttr.getSLayout())) slVal = static_cast(slAttr.getValue()); else if (auto intAttr = dyn_cast(configAttr.getSLayout())) - slVal = static_cast(intAttr.getInt()); + slVal = static_cast(getIntegerAttrSignedValue(intAttr)); bool boxed = slVal != 0; int64_t innerRows = 1; @@ -283,7 +289,7 @@ static bool getTilePointerStrides(TileBufConfigAttr configAttr, Type elemTy, if (boxed) { int32_t fractal = 512; if (auto frAttr = dyn_cast(configAttr.getSFractalSize())) - fractal = static_cast(frAttr.getInt()); + fractal = static_cast(getIntegerAttrSignedValue(frAttr)); unsigned elemBytes = pto::getPTOStorageElemByteSize(elemTy); if (elemBytes == 0) @@ -360,8 +366,8 @@ getMaterializedTileShape(MemRefType memTy, const TileHandleMetadata &meta) { SmallVector inheritedStrides; int64_t inheritedOffset = ShapedType::kDynamic; - if (failed(getStridesAndOffset(sourceMrTy, inheritedStrides, - inheritedOffset)) || + if (failed(mlir::pto::getPTOMemRefStridesAndOffset( + sourceMrTy, inheritedStrides, inheritedOffset)) || inheritedStrides.size() < 2) return shape; @@ -429,12 +435,13 @@ static Value ensureI64(Value value, OpBuilder &builder, Location loc) { static Value materializeOffset(OpFoldResult ofr, OpBuilder &builder, Location loc) { - if (auto attr = ofr.dyn_cast()) { + if (isa(ofr)) { + Attribute attr = cast(ofr); if (auto intAttr = dyn_cast(attr)) - return makeI64Constant(builder, loc, intAttr.getInt()); + return makeI64Constant(builder, loc, getIntegerAttrSignedValue(intAttr)); return Value(); } - return ensureI64(ofr.get(), builder, loc); + return ensureI64(cast(ofr), builder, loc); } static Value addI64(Value lhs, Value rhs, OpBuilder &builder, Location loc) { @@ -474,7 +481,8 @@ static Value computeSubviewAddress(memref::SubViewOp subview, SmallVector sourceStrides; int64_t sourceOffset = ShapedType::kDynamic; - if (failed(getStridesAndOffset(sourceTy, sourceStrides, sourceOffset))) + if (failed(mlir::pto::getPTOMemRefStridesAndOffset( + sourceTy, sourceStrides, sourceOffset))) return Value(); auto mixedOffsets = subview.getMixedOffsets(); @@ -613,6 +621,45 @@ static Value lookupMaterializedTileHandle( return it->second; } +static void cloneBlockWithoutTerminator(Block *from, Block *to, + IRMapping &mapping) { + OpBuilder builder(to, to->end()); + Operation *terminator = from->getTerminator(); + for (Operation &op : *from) { + if (&op == terminator) + break; + builder.clone(op, mapping); + } +} + +static bool isOperationNestedIn(Operation *op, Operation *ancestor) { + for (; op; op = op->getParentOp()) { + if (op == ancestor) + return true; + } + return false; +} + +static bool isValueOwnedByOperation(Value value, Operation *owner) { + if (auto result = dyn_cast(value)) + return isOperationNestedIn(result.getOwner(), owner); + if (auto arg = dyn_cast(value)) + return isOperationNestedIn(arg.getOwner()->getParentOp(), owner); + return false; +} + +static void eraseTileHandlesOwnedBy(Operation *owner, + DenseMap &tileHandles) { + SmallVector keysToErase; + for (auto &entry : tileHandles) { + if (isValueOwnedByOperation(entry.first, owner) || + isValueOwnedByOperation(entry.second, owner)) + keysToErase.push_back(entry.first); + } + for (Value key : keysToErase) + tileHandles.erase(key); +} + static FailureOr materializeSCFIfResults(ModuleOp module, DenseMap &tileHandles) { bool changed = false; @@ -629,6 +676,19 @@ materializeSCFIfResults(ModuleOp module, DenseMap &tileHandles) { if (!thenYield || !elseYield) continue; + if (thenYield.getNumOperands() != ifOp.getNumResults() || + elseYield.getNumOperands() != ifOp.getNumResults()) { + ifOp.emitOpError("result count does not match branch yield operands"); + return failure(); + } + + SmallVector resultTypes(ifOp->getResultTypes()); + SmallVector thenYieldOperands(thenYield.getOperands().begin(), + thenYield.getOperands().end()); + SmallVector elseYieldOperands(elseYield.getOperands().begin(), + elseYield.getOperands().end()); + SmallVector materializedResults; + for (auto [idx, result] : llvm::enumerate(ifOp.getResults())) { if (!isLocalTileMemRef(result.getType())) continue; @@ -649,12 +709,48 @@ materializeSCFIfResults(ModuleOp module, DenseMap &tileHandles) { } Type tileTy = thenTile.getType(); - thenYield->setOperand(idx, thenTile); - elseYield->setOperand(idx, elseTile); - result.setType(tileTy); - tileHandles[result] = result; - changed = true; + resultTypes[idx] = tileTy; + thenYieldOperands[idx] = thenTile; + elseYieldOperands[idx] = elseTile; + materializedResults.push_back(idx); + } + + if (materializedResults.empty()) + continue; + + OpBuilder builder(ifOp); + auto newIf = builder.create( + ifOp.getLoc(), TypeRange(resultTypes), ifOp.getCondition(), + /*addThenBlock=*/true, /*addElseBlock=*/true); + newIf->setAttrs(ifOp->getAttrs()); + + IRMapping thenMapping; + cloneBlockWithoutTerminator(ifOp.thenBlock(), newIf.thenBlock(), + thenMapping); + builder.setInsertionPointToEnd(newIf.thenBlock()); + for (Value &operand : thenYieldOperands) + operand = thenMapping.lookupOrDefault(operand); + builder.create(thenYield.getLoc(), thenYieldOperands); + + IRMapping elseMapping; + cloneBlockWithoutTerminator(ifOp.elseBlock(), newIf.elseBlock(), + elseMapping); + builder.setInsertionPointToEnd(newIf.elseBlock()); + for (Value &operand : elseYieldOperands) + operand = elseMapping.lookupOrDefault(operand); + builder.create(elseYield.getLoc(), elseYieldOperands); + + for (auto [oldResult, newResult] : + llvm::zip_equal(ifOp.getResults(), newIf.getResults())) + oldResult.replaceAllUsesWith(newResult); + + for (unsigned idx : materializedResults) { + tileHandles[newIf.getResult(idx)] = newIf.getResult(idx); } + + eraseTileHandlesOwnedBy(ifOp, tileHandles); + ifOp.erase(); + changed = true; } return changed; @@ -675,6 +771,18 @@ materializeSCFForResults(ModuleOp module, DenseMap &tileHandles) { if (!yield) continue; + if (yield.getNumOperands() != forOp.getNumResults() || + forOp.getInitArgs().size() != forOp.getNumResults()) { + forOp.emitOpError("result count does not match iter/yield operands"); + return failure(); + } + + SmallVector initArgs(forOp.getInitArgs().begin(), + forOp.getInitArgs().end()); + SmallVector yieldOperands(yield.getOperands().begin(), + yield.getOperands().end()); + SmallVector materializedResults; + for (auto [idx, result] : llvm::enumerate(forOp.getResults())) { if (!isLocalTileMemRef(result.getType())) continue; @@ -702,15 +810,46 @@ materializeSCFForResults(ModuleOp module, DenseMap &tileHandles) { return failure(); } - Type tileTy = initTile.getType(); - forOp->setOperand(forOp.getNumControlOperands() + idx, initTile); - iterArg.setType(tileTy); - yield->setOperand(idx, yieldTile); - result.setType(tileTy); - tileHandles[iterArg] = iterArg; - tileHandles[result] = result; - changed = true; + initArgs[idx] = initTile; + yieldOperands[idx] = yieldTile; + materializedResults.push_back(idx); } + + if (materializedResults.empty()) + continue; + + OpBuilder builder(forOp); + auto newFor = builder.create( + forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), + forOp.getStep(), initArgs); + newFor->setAttrs(forOp->getAttrs()); + + IRMapping mapping; + mapping.map(forOp.getInductionVar(), newFor.getInductionVar()); + for (auto [oldArg, newArg] : + llvm::zip_equal(forOp.getBody()->getArguments().drop_front(), + newFor.getBody()->getArguments().drop_front())) + mapping.map(oldArg, newArg); + + cloneBlockWithoutTerminator(forOp.getBody(), newFor.getBody(), mapping); + builder.setInsertionPointToEnd(newFor.getBody()); + for (Value &operand : yieldOperands) + operand = mapping.lookupOrDefault(operand); + builder.create(yield.getLoc(), yieldOperands); + + for (auto [oldResult, newResult] : + llvm::zip_equal(forOp.getResults(), newFor.getResults())) + oldResult.replaceAllUsesWith(newResult); + + for (unsigned idx : materializedResults) { + BlockArgument newIterArg = newFor.getRegionIterArg(idx); + tileHandles[newIterArg] = newIterArg; + tileHandles[newFor.getResult(idx)] = newFor.getResult(idx); + } + + eraseTileHandlesOwnedBy(forOp, tileHandles); + forOp.erase(); + changed = true; } return changed; diff --git a/lib/PTO/Transforms/PTOPlanMemory.cpp b/lib/PTO/Transforms/PTOPlanMemory.cpp index a97d16158e..5b34a030e2 100644 --- a/lib/PTO/Transforms/PTOPlanMemory.cpp +++ b/lib/PTO/Transforms/PTOPlanMemory.cpp @@ -2439,7 +2439,7 @@ void PlanMemoryPass::runOnOperation() { RewritePatternSet patterns(&getContext()); populateBufferAddressToAllocOp(patterns, memPlan.GetBuffer2Offsets()); - if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { + if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) { return signalPassFailure(); } } diff --git a/lib/PTO/Transforms/PTOToEmitC.cpp b/lib/PTO/Transforms/PTOToEmitC.cpp index f872e571ef..b2c0ca59b3 100644 --- a/lib/PTO/Transforms/PTOToEmitC.cpp +++ b/lib/PTO/Transforms/PTOToEmitC.cpp @@ -20,10 +20,6 @@ #include "PTO/IR/PTOSyncUtils.h" #include "PTO/Transforms/Passes.h" -#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" -#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" -#include "mlir/Analysis/DataFlowFramework.h" - #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/EmitC/IR/EmitC.h" @@ -123,6 +119,31 @@ static pto::AddressSpace getAddressSpaceOrGM(Attribute memorySpace) { return pto::AddressSpace::GM; } +static Type getEmitCVariableResultType(Type valueType) { + if (isa(valueType)) + return valueType; + return emitc::LValueType::get(valueType); +} + +static Value loadEmitCVariableIfNeeded(OpBuilder &builder, Location loc, + Value value) { + if (auto lvalueTy = dyn_cast(value.getType())) + return builder.create(loc, lvalueTy.getValueType(), value) + .getResult(); + return value; +} + +static Value getSourceEmitCVariable(Value value) { + if (value.getDefiningOp()) + return value; + if (auto loadOp = value.getDefiningOp()) { + Value operand = loadOp.getOperand(); + if (operand.getDefiningOp()) + return operand; + } + return {}; +} + [[maybe_unused]] static constexpr llvm::StringLiteral kLoweredSetValidShapeAttrName = "__pto.lowered_set_validshape"; [[maybe_unused]] static constexpr llvm::StringLiteral kLoweredSetValidShapeConfigAttrName = @@ -143,6 +164,18 @@ enum TNotifyMteDrainMask : unsigned { static constexpr llvm::StringLiteral kLastUseAttrName = "pto.last_use"; static constexpr llvm::StringLiteral kLastUseMarkerPrefix = "PTOAS__LAST_USE__"; +static int64_t getAPIntSignedValue(const APInt &value) { + return value.getBitWidth() == 0 ? 0 : value.getSExtValue(); +} + +static uint64_t getAPIntUnsignedValue(const APInt &value) { + return value.getBitWidth() == 0 ? 0 : value.getZExtValue(); +} + +static int64_t getIntegerAttrSignedValue(IntegerAttr attr) { + return getAPIntSignedValue(attr.getValue()); +} + static SmallVector collectTileOperandNumbers(Operation *op) { SmallVector tileOperandNumbers; for (OpOperand &operand : op->getOpOperands()) { @@ -448,12 +481,9 @@ static bool isF8E8M0ElemType(Type elemTy) { } static std::string getEmitCScalarTypeToken(Type elemTy) { - if (pto::isPTOFloat8Type(elemTy) && - (elemTy.isFloat8E4M3() || elemTy.isFloat8E4M3FN() || - elemTy.isFloat8E4M3FNUZ() || elemTy.isFloat8E4M3B11FNUZ())) + if (pto::isPTOFloat8E4M3LikeType(elemTy)) return "float8_e4m3_t"; - if (pto::isPTOFloat8Type(elemTy) && - (elemTy.isFloat8E5M2() || elemTy.isFloat8E5M2FNUZ())) + if (pto::isPTOFloat8E5M2LikeType(elemTy)) return "float8_e5m2_t"; if (isF8E8M0ElemType(elemTy)) return "float8_e8m0_t"; @@ -771,7 +801,7 @@ static std::optional getEmitCTileTypeString(pto::TileBufType type) int32_t fractal = 512; if (auto frAttr = dyn_cast(configAttr.getSFractalSize())) - fractal = frAttr.getInt(); + fractal = static_cast(getIntegerAttrSignedValue(frAttr)); return std::string("Tile<") + tileRoleToken(type.getMemorySpace(), elemTy, type.getConfigAttr()) + ", " + @@ -795,10 +825,9 @@ class PTOToEmitCTypeConverter : public TypeConverter { // 1. 基本类型 (f32, i32, index) // --------------------------------------------------------- addConversion([Ctx](FloatType type) -> Type { - if (type.isFloat8E4M3() || type.isFloat8E4M3FN() || - type.isFloat8E4M3FNUZ() || type.isFloat8E4M3B11FNUZ()) + if (pto::isPTOFloat8E4M3LikeType(type)) return emitc::OpaqueType::get(Ctx, "float8_e4m3_t"); - if (type.isFloat8E5M2() || type.isFloat8E5M2FNUZ()) + if (pto::isPTOFloat8E5M2LikeType(type)) return emitc::OpaqueType::get(Ctx, "float8_e5m2_t"); if (type.isF32()) return emitc::OpaqueType::get(Ctx, "float"); if (type.isF16()) return emitc::OpaqueType::get(Ctx, "half"); @@ -1003,8 +1032,6 @@ class PTOToEmitCTypeConverter : public TypeConverter { addSourceMaterialization(materializeCast); addTargetMaterialization(materializeCast); - // Needed for region/block signature conversions (e.g. CFG block args). - addArgumentMaterialization(materializeCast); } }; @@ -1086,14 +1113,15 @@ static FailureOr buildTPipeTokenFromInitOp(Operation *op, getTPipeDirectionToken(/*isL2G2L=*/true, initOp.getDirMask(), targetArch); if (failed(dirTok)) return failure(); - int32_t localSlotNum = initOp.getLocalSlotNumAttr() - ? initOp.getLocalSlotNumAttr().getInt() - : initOp.getSlotNum(); - return buildTPipeToken(initOp.getFlagBaseAttr().getInt(), *dirTok, - initOp.getSlotSize(), initOp.getSlotNum(), - localSlotNum, - initOp.getNosplitAttr() && - initOp.getNosplitAttr().getValue()); + int32_t localSlotNum = + initOp.getLocalSlotNumAttr() + ? static_cast( + getIntegerAttrSignedValue(initOp.getLocalSlotNumAttr())) + : initOp.getSlotNum(); + return buildTPipeToken( + static_cast(getIntegerAttrSignedValue(initOp.getFlagBaseAttr())), + *dirTok, initOp.getSlotSize(), initOp.getSlotNum(), localSlotNum, + initOp.getNosplitAttr() && initOp.getNosplitAttr().getValue()); } if (auto initOp = dyn_cast(op)) { @@ -1103,10 +1131,10 @@ static FailureOr buildTPipeTokenFromInitOp(Operation *op, getTPipeDirectionToken(/*isL2G2L=*/false, initOp.getDirMask(), targetArch); if (failed(dirTok)) return failure(); - return buildTPipeToken(initOp.getFlagBaseAttr().getInt(), *dirTok, - initOp.getSlotSize(), initOp.getSlotNum(), 2, - initOp.getNosplitAttr() && - initOp.getNosplitAttr().getValue()); + return buildTPipeToken( + static_cast(getIntegerAttrSignedValue(initOp.getFlagBaseAttr())), + *dirTok, initOp.getSlotSize(), initOp.getSlotNum(), 2, + initOp.getNosplitAttr() && initOp.getNosplitAttr().getValue()); } return failure(); @@ -1277,7 +1305,8 @@ static InterCoreSyncCallDesc buildInterCoreSyncSetCall( if (targetArch == PTOArch::A3) { auto indexTy = emitc::OpaqueType::get(ctx, "int64_t"); Value eventVal = - makeEmitCIntConstant(rewriter, loc, indexTy, eventIdAttr.getInt()); + makeEmitCIntConstant(rewriter, loc, indexTy, + getIntegerAttrSignedValue(eventIdAttr)); Value msgVal = createFFTSMsg(rewriter, loc, eventVal, fftsMode); InterCoreSyncCallDesc desc; @@ -2780,7 +2809,7 @@ static FailureOr buildEmitCOpaqueConstantLiteral(Type targetType, if (!first) os << ", "; first = false; - os << elem.getZExtValue(); + os << getAPIntUnsignedValue(elem); } os << "}"; os.flush(); @@ -3057,7 +3086,7 @@ struct ArithConstantToEmitC : public OpConversionPattern { } if (auto intAttr = dyn_cast_or_null(valueAttr)) { - std::string valStr = std::to_string(intAttr.getValue().getSExtValue()); + std::string valStr = std::to_string(getIntegerAttrSignedValue(intAttr)); auto constAttr = emitc::OpaqueAttr::get(rewriter.getContext(), valStr); rewriter.replaceOpWithNewOp(op, newType, constAttr); return success(); @@ -3419,14 +3448,15 @@ struct SubviewToEmitCPattern : public OpConversionPattern { // 辅助函数:尝试从 OpFoldResult 中提取静态整数值 std::optional extractStaticInt(OpFoldResult ofr) const { - if (auto attr = ofr.dyn_cast()) { + if (isa(ofr)) { + Attribute attr = cast(ofr); if (auto intAttr = dyn_cast(attr)) - return intAttr.getInt(); + return getIntegerAttrSignedValue(intAttr); } else { - Value v = ofr.get(); + Value v = cast(ofr); if (auto cOp = v.getDefiningOp()) { if (auto iAttr = dyn_cast(cOp.getValue())) - return iAttr.getInt(); + return getIntegerAttrSignedValue(iAttr); } else if (auto idxOp = v.getDefiningOp()) { return idxOp.value(); } @@ -3494,13 +3524,15 @@ struct SubviewToEmitCPattern : public OpConversionPattern { // Helper: 将 OpFoldResult 转为 EmitC Value (用于计算) auto ofrToEmitCValue = [&](OpFoldResult ofr) -> Value { - if (auto v = ofr.dyn_cast()) { + if (isa(ofr)) { + Value v = cast(ofr); Value rv = rewriter.getRemappedValue(v); return asIndex(rv); } - if (auto attr = ofr.dyn_cast()) { - if (auto ia = dyn_cast(attr)) - return mkIndex(ia.getValue().getSExtValue()); + if (isa(ofr)) { + Attribute attr = cast(ofr); + if (auto ia = dyn_cast(attr)) + return mkIndex(getIntegerAttrSignedValue(ia)); } return mkIndex(0); }; @@ -3513,7 +3545,9 @@ struct SubviewToEmitCPattern : public OpConversionPattern { } else { SmallVector strideInts; int64_t offset = ShapedType::kDynamic; - bool useTypeStrides = succeeded(getStridesAndOffset(srcType, strideInts, offset)); + bool useTypeStrides = + succeeded(mlir::pto::getPTOMemRefStridesAndOffset( + srcType, strideInts, offset)); (void)offset; if (useTypeStrides) { for (int64_t s : strideInts) { @@ -3884,7 +3918,8 @@ static bool hasStaticShape(MemRefType mrTy) { static bool getStaticMemrefLayout(MemRefType mrTy, SmallVectorImpl &strides, int64_t &offset) { - if (failed(getStridesAndOffset(mrTy, strides, offset))) { + if (failed( + mlir::pto::getPTOMemRefStridesAndOffset(mrTy, strides, offset))) { strides.clear(); int64_t stride = 1; ArrayRef shape = mrTy.getShape(); @@ -4275,7 +4310,7 @@ static FailureOr buildAsyncScratchTileValue( int32_t fractal = 512; if (auto frAttr = dyn_cast(configAttr.getSFractalSize())) - fractal = frAttr.getInt(); + fractal = static_cast(getIntegerAttrSignedValue(frAttr)); Type elemTy = memTy.getElementType(); pto::BLayout blayout = getTileBufBLayoutValue(configAttr); @@ -4292,9 +4327,11 @@ static FailureOr buildAsyncScratchTileValue( Value tile = rewriter .create( - loc, emitc::OpaqueType::get(ctx, tileTypeStr), + loc, getEmitCVariableResultType( + emitc::OpaqueType::get(ctx, tileTypeStr)), emitc::OpaqueAttr::get(ctx, "")) .getResult(); + tile = loadEmitCVariableIfNeeded(rewriter, loc, tile); auto addr = rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); Value scratchAddr = rewriter @@ -4354,9 +4391,11 @@ static FailureOr buildSyncAllWorkspaceTileValue( auto tileEmitTy = emitc::OpaqueType::get(ctx, *tileTypeString); Value tile = rewriter - .create(loc, tileEmitTy, + .create(loc, + getEmitCVariableResultType(tileEmitTy), emitc::OpaqueAttr::get(ctx, "")) .getResult(); + tile = loadEmitCVariableIfNeeded(rewriter, loc, tile); Value rawPtr = workspace; auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); @@ -4385,7 +4424,7 @@ struct PointerCastConversion : public OpConversionPattern { static bool getIndexConst(Value v, int64_t &out) { if (auto cst = v.getDefiningOp()) { if (auto ia = dyn_cast(cst.getValue())) { - out = ia.getValue().getSExtValue(); + out = getIntegerAttrSignedValue(ia); return true; } } @@ -4481,7 +4520,7 @@ struct PointerCastConversion : public OpConversionPattern { if (!v) return false; if (auto cst = v.getDefiningOp()) { if (auto attr = dyn_cast(cst.getValue())) { - outVal = attr.getInt(); + outVal = getIntegerAttrSignedValue(attr); return true; } } @@ -4548,7 +4587,8 @@ struct PointerCastConversion : public OpConversionPattern { std::string slStr = (slVal == 1) ? "SLayout::RowMajor" : (slVal == 2) ? "SLayout::ColMajor" : "SLayout::NoneBox"; int32_t frVal = 0; - if (auto attr = dyn_cast(config.getSFractalSize())) frVal = attr.getInt(); + if (auto attr = dyn_cast(config.getSFractalSize())) + frVal = static_cast(getIntegerAttrSignedValue(attr)); int32_t padVal = 0; if (auto attr = dyn_cast(config.getPad())) @@ -4695,10 +4735,10 @@ struct PointerCastConversion : public OpConversionPattern { // 静态情况 (Tile v;) auto varOp = rewriter.create( loc, - tileType, + getEmitCVariableResultType(tileType), emitc::OpaqueAttr::get(ctx, "") ); - resultValue = varOp.getResult(); + resultValue = loadEmitCVariableIfNeeded(rewriter, loc, varOp.getResult()); } // TASSIGN: pto-isa expects an integral address. @@ -5933,7 +5973,7 @@ struct PTOSyncSetToEmitC : public OpConversionPattern { Value eventIdDyn = adaptor.getEventIdDyn(); int64_t fftsMode = 2; if (IntegerAttr fftsModeAttr = op.getFftsModeAttr()) - fftsMode = fftsModeAttr.getInt(); + fftsMode = getIntegerAttrSignedValue(fftsModeAttr); if ((eventIdAttr != nullptr) == static_cast(eventIdDyn)) return rewriter.notifyMatchFailure( @@ -5973,7 +6013,7 @@ struct PTOSyncSetToEmitC : public OpConversionPattern { emitSet(Value{}, eventIdAttr, /*isDynamic=*/false); if (needsMirrorPlus16) { auto plus16 = IntegerAttr::get(eventIdAttr.getType(), - eventIdAttr.getInt() + 16); + getIntegerAttrSignedValue(eventIdAttr) + 16); emitSet(Value{}, plus16, /*isDynamic=*/false); } } else { @@ -6306,7 +6346,7 @@ struct PTOHistogramToEmitC : public OpConversionPattern { int64_t byte = 1; auto byteAttr = op.getByteAttr(); if (byteAttr) - byte = byteAttr.getInt(); + byte = getIntegerAttrSignedValue(byteAttr); if (auto legacyIsMSB = op->getAttrOfType("isMSB")) { int64_t legacyByte = legacyIsMSB.getValue() ? 1 : 0; if (byteAttr && byte != legacyByte) @@ -6816,8 +6856,10 @@ struct PTOBuildAsyncSessionToEmitC Value session = rewriter .create( - loc, sessionTy, emitc::OpaqueAttr::get(ctx, "")) + loc, getEmitCVariableResultType(sessionTy), + emitc::OpaqueAttr::get(ctx, "")) .getResult(); + session = loadEmitCVariableIfNeeded(rewriter, loc, session); auto u32Ty = emitc::OpaqueType::get(ctx, "uint32_t"); @@ -6825,14 +6867,28 @@ struct PTOBuildAsyncSessionToEmitC return makeEmitCOpaqueConstant(rewriter, loc, u32Ty, std::to_string(value) + "u"); }; - uint64_t syncId = op.getSyncIdAttr() ? op.getSyncIdAttr().getInt() : 0; + uint64_t syncId = op.getSyncIdAttr() + ? static_cast( + getIntegerAttrSignedValue(op.getSyncIdAttr())) + : 0; uint64_t blockBytes = - op.getBlockBytesAttr() ? op.getBlockBytesAttr().getInt() : 32 * 1024; + op.getBlockBytesAttr() + ? static_cast( + getIntegerAttrSignedValue(op.getBlockBytesAttr())) + : 32 * 1024; uint64_t commBlockOffset = - op.getCommBlockOffsetAttr() ? op.getCommBlockOffsetAttr().getInt() : 0; - uint64_t queueNum = op.getQueueNumAttr() ? op.getQueueNumAttr().getInt() : 1; + op.getCommBlockOffsetAttr() + ? static_cast( + getIntegerAttrSignedValue(op.getCommBlockOffsetAttr())) + : 0; + uint64_t queueNum = op.getQueueNumAttr() + ? static_cast( + getIntegerAttrSignedValue(op.getQueueNumAttr())) + : 1; uint64_t channelGroupIdx = op.getChannelGroupIdxAttr() - ? op.getChannelGroupIdxAttr().getInt() + ? static_cast( + getIntegerAttrSignedValue( + op.getChannelGroupIdxAttr())) : UINT32_MAX; Value syncIdVal = makeU32Const(syncId); @@ -6846,12 +6902,13 @@ struct PTOBuildAsyncSessionToEmitC Value baseConfig = rewriter .create( - loc, baseConfigTy, + loc, getEmitCVariableResultType(baseConfigTy), emitc::OpaqueAttr::get( ctx, "{" + std::to_string(blockBytes) + "ULL, " + std::to_string(commBlockOffset) + "ULL, " + std::to_string(queueNum) + "u}")) .getResult(); + baseConfig = loadEmitCVariableIfNeeded(rewriter, loc, baseConfig); rewriter.create( loc, TypeRange{}, "pto::comm::BuildAsyncSession", @@ -6981,7 +7038,7 @@ static FailureOr buildCollectiveParallelGroup( firstTy); auto groupArray = cast>( rewriter - .create(loc, arrayTy, + .create(loc, getEmitCVariableResultType(arrayTy), emitc::OpaqueAttr::get(ctx, "{}")) .getResult()); @@ -7385,9 +7442,10 @@ struct PTODeclareGlobalToEmitC } } auto var = rewriter.create( - op.getLoc(), convertedType, + op.getLoc(), getEmitCVariableResultType(convertedType), emitc::OpaqueAttr::get(rewriter.getContext(), "")); - rewriter.replaceOp(op, var.getResult()); + rewriter.replaceOp( + op, loadEmitCVariableIfNeeded(rewriter, op.getLoc(), var.getResult())); return success(); } }; @@ -7408,9 +7466,10 @@ struct PTODeclareEventIdArrayToEmitC auto array = rewriter .create( - op.getLoc(), arrayTy, + op.getLoc(), getEmitCVariableResultType(arrayTy), emitc::OpaqueAttr::get(rewriter.getContext(), "")) .getResult(); + array = loadEmitCVariableIfNeeded(rewriter, op.getLoc(), array); rewriter.replaceOp(op, array); return success(); } @@ -7432,9 +7491,10 @@ struct PTOEventIdArrayGetToEmitC return rewriter.notifyMatchFailure(op, "failed to map eventid_array get result type"); - auto load = - rewriter.create(op.getLoc(), resultTy, array, index); - rewriter.replaceOp(op, load.getResult()); + auto subscript = rewriter.create( + op.getLoc(), emitc::LValueType::get(resultTy), array, + ValueRange{index}); + rewriter.replaceOpWithNewOp(op, resultTy, subscript); return success(); } }; @@ -7451,9 +7511,12 @@ struct PTOEventIdArraySetToEmitC Value index = peelUnrealized(adaptor.getIndex()); Value value = peelUnrealized(adaptor.getValue()); - rewriter.create( - op.getLoc(), TypeRange{}, "PTOAS__EVENTID_ARRAY_STORE", - ArrayAttr{}, ArrayAttr{}, ValueRange{array, index, value}); + Value slot = rewriter + .create( + op.getLoc(), emitc::LValueType::get(value.getType()), + array, ValueRange{index}) + .getResult(); + rewriter.create(op.getLoc(), slot, value); rewriter.eraseOp(op); return success(); } @@ -7477,9 +7540,10 @@ struct PTODeclareLocalArrayToEmitC auto var = rewriter .create( - op.getLoc(), arrayTy, + op.getLoc(), getEmitCVariableResultType(arrayTy), emitc::OpaqueAttr::get(rewriter.getContext(), "")) .getResult(); + var = loadEmitCVariableIfNeeded(rewriter, op.getLoc(), var); rewriter.replaceOp(op, var); return success(); } @@ -7509,15 +7573,8 @@ struct PTOLocalArrayGetToEmitC indices.push_back(peelUnrealized(index)); auto sub = rewriter.create( - op.getLoc(), resultTy, array, indices); - auto snapshot = - rewriter - .create( - op.getLoc(), resultTy, - emitc::OpaqueAttr::get(rewriter.getContext(), "")) - .getResult(); - rewriter.create(op.getLoc(), snapshot, sub.getResult()); - rewriter.replaceOp(op, snapshot); + op.getLoc(), emitc::LValueType::get(resultTy), array, indices); + rewriter.replaceOpWithNewOp(op, resultTy, sub); return success(); } }; @@ -7537,9 +7594,9 @@ struct PTOLocalArraySetToEmitC Type elemTy = value.getType(); Value slot = rewriter - .create(op.getLoc(), elemTy, - adaptor.getArray(), - adaptor.getIndices()) + .create( + op.getLoc(), emitc::LValueType::get(elemTy), + adaptor.getArray(), adaptor.getIndices()) .getResult(); rewriter.create(op.getLoc(), slot, value); rewriter.eraseOp(op); @@ -7556,7 +7613,7 @@ static std::optional getStaticIndexLikeValue(Value value) { return cst.value(); if (auto cst = value.getDefiningOp()) { if (auto intAttr = dyn_cast(cst.getValue())) - return intAttr.getInt(); + return getIntegerAttrSignedValue(intAttr); } return std::nullopt; } @@ -8039,9 +8096,11 @@ struct ReinterpretCastToEmitC : public OpConversionPattern(loc, tileType, + .create(loc, + getEmitCVariableResultType(tileType), emitc::OpaqueAttr::get(ctx, "")) .getResult(); + tile = loadEmitCVariableIfNeeded(rewriter, loc, tile); // Compute an integer address and assign it to the new tile. // NOTE: pto-isa TASSIGN requires an integral address (not a pointer). @@ -9402,7 +9461,7 @@ struct PTOGatherToEmitC : public OpConversionPattern { std::string cmpTok = cmpAttr ? cmpModeTok(cmpAttr) : "CmpMode::EQ"; int64_t offset = 0; if (auto offsetAttr = op.getOffsetAttr()) - offset = offsetAttr.getInt(); + offset = getIntegerAttrSignedValue(offsetAttr); auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); Value offsetVal = makeEmitCIntConstant(rewriter, loc, i32Ty, offset); @@ -9810,16 +9869,6 @@ struct PTOQuantToEmitC : public OpConversionPattern { Value src = peelUnrealized(adaptor.getSrc()); Value fp = peelUnrealized(adaptor.getFp()); - auto makeOpaquePtr = [&](Value value) -> Value { - auto opaqueTy = mlir::dyn_cast(value.getType()); - if (!opaqueTy) - return {}; - return rewriter - .create(loc, emitc::PointerType::get(opaqueTy), "&", - value) - .getResult(); - }; - auto quantTypeTok = [&]() -> StringRef { switch (op.getQuantType()) { case pto::QuantType::INT8_SYM: @@ -9839,7 +9888,24 @@ struct PTOQuantToEmitC : public OpConversionPattern { Value offsetPtr; if (op.getOffset()) { Value offset = peelUnrealized(adaptor.getOffset()); - offsetPtr = makeOpaquePtr(offset); + Type offsetValueTy = offset.getType(); + Value offsetLValue = offset; + if (auto lvalueTy = dyn_cast(offsetValueTy)) { + offsetValueTy = lvalueTy.getValueType(); + } else { + offsetLValue = + rewriter + .create( + loc, getEmitCVariableResultType(offsetValueTy), + emitc::OpaqueAttr::get(ctx, "")) + .getResult(); + rewriter.create(loc, offsetLValue, offset); + } + offsetPtr = + rewriter + .create( + loc, emitc::PointerType::get(offsetValueTy), "&", offsetLValue) + .getResult(); } ArrayAttr templateArgs; @@ -9911,8 +9977,20 @@ struct PTOQuantMxToEmitC : public OpConversionPattern { op, "expected exp_zz operand to be emitc::OpaqueType"); auto makePtr = [&](Value v, emitc::OpaqueType ot) -> Value { + if (auto lvalueTy = dyn_cast(v.getType())) + return rewriter.create( + loc, emitc::PointerType::get(lvalueTy.getValueType()), + "&", v) + .getResult(); + + Value tmp = rewriter + .create( + loc, getEmitCVariableResultType(ot), + emitc::OpaqueAttr::get(ctx, "")) + .getResult(); + rewriter.create(loc, tmp, v); return rewriter.create( - loc, emitc::PointerType::get(ot), "&", v) + loc, emitc::PointerType::get(ot), "&", tmp) .getResult(); }; @@ -11708,6 +11786,7 @@ struct PTOPrintToTPRINT : public OpConversionPattern { LogicalResult matchAndRewrite(pto::TPrintOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); auto printFormatTok = [&](pto::PrintFormat format) -> StringRef { switch (format) { case pto::PrintFormat::Width8_Precision4: @@ -11720,7 +11799,7 @@ struct PTOPrintToTPRINT : public OpConversionPattern { llvm_unreachable("unknown PrintFormat"); }; - Value src = adaptor.getSrc(); + Value src = peelUnrealized(adaptor.getSrc()); if (isa(op.getSrc().getType()) || isa(op.getSrc().getType())) { src = maybeWrapGlobalMemrefAsGlobalTensor( @@ -11731,6 +11810,7 @@ struct PTOPrintToTPRINT : public OpConversionPattern { if (Value tmp = op->getNumOperands() > 1 ? op->getOperand(1) : Value()) { Value tmpValue = adaptor.getOperands().size() > 1 ? adaptor.getOperands()[1] : Value(); + tmpValue = peelUnrealized(tmpValue); if (isa(tmp.getType()) || isa(tmp.getType())) { tmpValue = maybeWrapGlobalMemrefAsGlobalTensor( @@ -11744,7 +11824,7 @@ struct PTOPrintToTPRINT : public OpConversionPattern { dyn_cast_or_null( op.getProperties().printFormat)) { templateArgVec.push_back(emitc::OpaqueAttr::get( - rewriter.getContext(), printFormatTok(formatAttr.getValue()))); + ctx, printFormatTok(formatAttr.getValue()))); } ArrayAttr templateArgs = templateArgVec.empty() ? ArrayAttr{} : rewriter.getArrayAttr(templateArgVec); @@ -11838,7 +11918,7 @@ struct PTOBindTileToEmitC : public OpConversionPattern { return false; if (auto cst = v.getDefiningOp()) { if (auto ia = dyn_cast(cst.getValue())) { - out = ia.getValue().getSExtValue(); + out = getIntegerAttrSignedValue(ia); return true; } } @@ -11856,13 +11936,13 @@ struct PTOBindTileToEmitC : public OpConversionPattern { if (auto blAttr = dyn_cast(configAttr.getBLayout())) blVal = static_cast(blAttr.getValue()); else if (auto intAttr = dyn_cast(configAttr.getBLayout())) - blVal = static_cast(intAttr.getInt()); + blVal = static_cast(getIntegerAttrSignedValue(intAttr)); int32_t slVal = 0; if (auto slAttr = dyn_cast(configAttr.getSLayout())) slVal = static_cast(slAttr.getValue()); else if (auto intAttr = dyn_cast(configAttr.getSLayout())) - slVal = static_cast(intAttr.getInt()); + slVal = static_cast(getIntegerAttrSignedValue(intAttr)); bool boxed = slVal != 0; int64_t innerRows = 1; @@ -11870,7 +11950,7 @@ struct PTOBindTileToEmitC : public OpConversionPattern { if (boxed) { int32_t fractal = 512; if (auto frAttr = dyn_cast(configAttr.getSFractalSize())) - fractal = static_cast(frAttr.getInt()); + fractal = static_cast(getIntegerAttrSignedValue(frAttr)); unsigned elemBytes = pto::getPTOStorageElemByteSize(elemTy); if (elemBytes == 0) @@ -12101,8 +12181,8 @@ struct PTOBindTileToEmitC : public OpConversionPattern { if (!pto::isPTOFloat4PackedType(elemTy) && subRows != ShapedType::kDynamic && subCols != ShapedType::kDynamic && - succeeded(getStridesAndOffset(subMrTy, inheritedStrides, - inheritedOffset)) && + succeeded(mlir::pto::getPTOMemRefStridesAndOffset( + subMrTy, inheritedStrides, inheritedOffset)) && inheritedStrides.size() >= 2) { int64_t childRowStride = 0; int64_t childColStride = 0; @@ -12130,7 +12210,7 @@ struct PTOBindTileToEmitC : public OpConversionPattern { int32_t fractal = 512; if (auto frAttr = dyn_cast(configAttr.getSFractalSize())) - fractal = frAttr.getInt(); + fractal = static_cast(getIntegerAttrSignedValue(frAttr)); std::string padTok = "PadValue::Null"; if (auto padAttr = dyn_cast(configAttr.getPad())) { @@ -12293,9 +12373,12 @@ struct PTOBindTileToEmitC : public OpConversionPattern { .getResult(0); } - return rewriter - .create(loc, tileType, emitc::OpaqueAttr::get(ctx, "")) - .getResult(); + Value tile = rewriter + .create( + loc, getEmitCVariableResultType(tileType), + emitc::OpaqueAttr::get(ctx, "")) + .getResult(); + return loadEmitCVariableIfNeeded(rewriter, loc, tile); }; auto emitElemTypeToString = [&](Type elemTy) -> std::string { @@ -12523,8 +12606,10 @@ struct PTOAllocTileToEmitC tile = rewriter .create( - loc, convertedTy, emitc::OpaqueAttr::get(ctx, "")) + loc, getEmitCVariableResultType(convertedTy), + emitc::OpaqueAttr::get(ctx, "")) .getResult(); + tile = loadEmitCVariableIfNeeded(rewriter, loc, tile); } Value addr = adaptor.getAddr(); @@ -12567,10 +12652,12 @@ createEmitCTileVariable(ConversionPatternRewriter &rewriter, Location loc, if (!convertedTy) convertedTy = emitc::OpaqueType::get(rewriter.getContext(), *tileTypeString); - return rewriter - .create( - loc, convertedTy, emitc::OpaqueAttr::get(rewriter.getContext(), "")) - .getResult(); + Value tile = rewriter + .create( + loc, getEmitCVariableResultType(convertedTy), + emitc::OpaqueAttr::get(rewriter.getContext(), "")) + .getResult(); + return loadEmitCVariableIfNeeded(rewriter, loc, tile); } struct PTOTReshapeToEmitC : public OpConversionPattern { @@ -12767,10 +12854,12 @@ struct PTOMaterializeTileToEmitC .getResult(0); } - return rewriter - .create(loc, convertedTy, - emitc::OpaqueAttr::get(ctx, "")) - .getResult(); + Value tile = rewriter + .create( + loc, getEmitCVariableResultType(convertedTy), + emitc::OpaqueAttr::get(ctx, "")) + .getResult(); + return loadEmitCVariableIfNeeded(rewriter, loc, tile); }; if (!isSubview && !forceDynamicValid && isTileLike(source)) { @@ -13520,9 +13609,7 @@ struct CFSwitchToCondBr : public OpRewritePattern { static void populatePTOToEmitCPatterns(RewritePatternSet &patterns, TypeConverter &typeConverter, MLIRContext *ctx, - DataFlowSolver &solver, PTOArch targetArch) { - (void)solver; patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); @@ -13799,7 +13886,7 @@ static void populatePTOToEmitCPatterns(RewritePatternSet &patterns, patterns.add(typeConverter, ctx); - populateSCFToEmitCConversionPatterns(patterns); + populateSCFToEmitCConversionPatterns(patterns, typeConverter); // Keep CFG-style branches type-consistent when block argument types are // converted (e.g. after lowering scf.while to cf.br/cf.cond_br). populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter); @@ -14014,7 +14101,7 @@ static AICORE inline void ptoas_auto_sync_tail( scfLoweringPatterns.add(ctx); - (void)applyPatternsAndFoldGreedily(mop, std::move(scfLoweringPatterns)); + (void)applyPatternsGreedily(mop, std::move(scfLoweringPatterns)); bool hasUnsupportedSCF = false; mop.walk([&](Operation *op) { @@ -14083,14 +14170,8 @@ static AICORE inline void ptoas_auto_sync_tail( target.addLegalDialect(); target.addLegalOp(); - auto solver = std::make_unique(); - solver->load(); - solver->load(); - if (failed(solver->initializeAndRun(getOperation()))) - return signalPassFailure(); - RewritePatternSet patterns(ctx); - populatePTOToEmitCPatterns(patterns, typeConverter, ctx, *solver, targetArch); + populatePTOToEmitCPatterns(patterns, typeConverter, ctx, targetArch); // 4. 执行转换 if (failed(applyPartialConversion(mop, target, std::move(patterns)))) { @@ -14115,6 +14196,10 @@ static AICORE inline void ptoas_auto_sync_tail( StringRef value = opaqueTy.getValue(); return value.contains("Tile<") || value.contains("ConvTile<"); }; + auto isLoweredIndexType = [](Type ty) { + auto opaqueTy = dyn_cast(ty); + return opaqueTy && opaqueTy.getValue() == "int64_t"; + }; llvm::SmallVector castsToErase; bool castCleanupFailed = false; @@ -14144,6 +14229,17 @@ static AICORE inline void ptoas_auto_sync_tail( return; } + // IndexType is lowered to int64_t for EmitC. SCF structural conversion + // can still materialize temporary index<->int64_t bridges; keeping them + // as emitc.cast leaves illegal index-typed EmitC IR for LLVM 21's C++ + // emitter, so fold the bridge back to the lowered value. + if ((isa(inTy) && isLoweredIndexType(outTy)) || + (isLoweredIndexType(inTy) && isa(outTy))) { + output.replaceAllUsesWith(input); + castsToErase.push_back(cast); + return; + } + // SCF/CFG type conversion can transiently materialize pointer->memref // bridge casts. At this stage, the producing value is already in the // lowered EmitC pointer form; keep it and drop the bridge cast. @@ -14243,7 +14339,7 @@ static AICORE inline void ptoas_auto_sync_tail( return; if (callOp.getNumOperands() != 1 || callOp.getNumResults() != 1) return; - if (callOp.getOperand(0).getDefiningOp()) + if (getSourceEmitCVariable(callOp.getOperand(0))) tileDataReadsToSink.push_back(callOp); }); @@ -14294,6 +14390,20 @@ static AICORE inline void ptoas_auto_sync_tail( continue; // 这是一个赋值操作,不算有效使用 } } + if (auto load = dyn_cast(user)) { + bool loadOnlyFeedsDeadTileDataReads = true; + for (Operation *loadUser : load.getResult().getUsers()) { + auto call = dyn_cast(loadUser); + if (!call || call.getCallee() != "PTOAS__TILE_DATA" || + call.getNumResults() != 1 || + !call.getResult(0).use_empty()) { + loadOnlyFeedsDeadTileDataReads = false; + break; + } + } + if (loadOnlyFeedsDeadTileDataReads) + continue; + } // 如果还有其他用途(如 TLOAD, TMOV, TMATMUL),则该变量有用 isRead = true; break; @@ -14308,7 +14418,14 @@ static AICORE inline void ptoas_auto_sync_tail( // 1. 先删除所有使用该变量的 TASSIGN llvm::SmallVector usersToErase; for (Operation* user : varOp.getResult().getUsers()) { - // 我们上面已经确认过,剩下的 user 只能是 TASSIGN + if (auto load = dyn_cast(user)) { + llvm::SmallVector loadUsersToErase; + for (Operation *loadUser : load.getResult().getUsers()) + loadUsersToErase.push_back(loadUser); + for (Operation *loadUser : loadUsersToErase) + loadUser->erase(); + } + // 我们上面已经确认过,剩下的 user 只能是 TASSIGN 或死 emitc.load usersToErase.push_back(user); } for (auto u : usersToErase) u->erase(); diff --git a/lib/PTO/Transforms/PTOUnrollSIMTForPass.cpp b/lib/PTO/Transforms/PTOUnrollSIMTForPass.cpp index 3403b47215..7735520d3c 100644 --- a/lib/PTO/Transforms/PTOUnrollSIMTForPass.cpp +++ b/lib/PTO/Transforms/PTOUnrollSIMTForPass.cpp @@ -134,10 +134,10 @@ struct PTOUnrollSIMTFor : public pto::impl::PTOUnrollSIMTForBase(ctx); GreedyRewriteConfig config; - config.maxIterations = 10; // loops may nest - config.strictMode = GreedyRewriteStrictness::ExistingOps; + config.setMaxIterations(10); // loops may nest + config.setStrictness(GreedyRewriteStrictness::ExistingOps); - if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns), config))) + if (failed(applyPatternsGreedily(func, std::move(patterns), config))) signalPassFailure(); } }; diff --git a/lib/PTO/Transforms/PTOViewToMemref.cpp b/lib/PTO/Transforms/PTOViewToMemref.cpp index 2fa071e9f5..e86a0a3ad8 100644 --- a/lib/PTO/Transforms/PTOViewToMemref.cpp +++ b/lib/PTO/Transforms/PTOViewToMemref.cpp @@ -1456,7 +1456,8 @@ static LogicalResult lowerSubViewOps(func::FuncOp func, MLIRContext *ctx) { SmallVector srcStrides; int64_t srcOffset = ShapedType::kDynamic; - if (failed(getStridesAndOffset(srcMrTy, srcStrides, srcOffset))) + if (failed(mlir::pto::getPTOMemRefStridesAndOffset(srcMrTy, srcStrides, + srcOffset))) srcStrides = computeCompactStrides(srcMrTy.getShape()); // Keep parent physical shape + strides for bound tile semantics. @@ -2249,7 +2250,8 @@ struct PTOViewToMemrefPass SmallVector staticStrides; int64_t offset = ShapedType::kDynamic; - if (succeeded(getStridesAndOffset(mrTy, staticStrides, offset)) && + if (succeeded(mlir::pto::getPTOMemRefStridesAndOffset( + mrTy, staticStrides, offset)) && dimIndex < (int64_t)staticStrides.size() && staticStrides[dimIndex] != ShapedType::kDynamic) { rewriter.replaceOpWithNewOp( diff --git a/lib/PTO/Transforms/TileFusion/PTOFusionPlan.cpp b/lib/PTO/Transforms/TileFusion/PTOFusionPlan.cpp index b52f5cff50..cacd31565c 100644 --- a/lib/PTO/Transforms/TileFusion/PTOFusionPlan.cpp +++ b/lib/PTO/Transforms/TileFusion/PTOFusionPlan.cpp @@ -528,9 +528,6 @@ struct FusionPlanPass : public pto::impl::FusionPlanBase { using pto::impl::FusionPlanBase::FusionPlanBase; FusionPlanPass() = default; - FusionPlanPass(const pto::FusionPlanOptions &options) { - enableShapeInference = options.enableShapeInference; - } void runOnOperation() override { func::FuncOp func = getOperation(); diff --git a/lib/PTO/Transforms/Utils.cpp b/lib/PTO/Transforms/Utils.cpp index e74b589337..7ce735669f 100644 --- a/lib/PTO/Transforms/Utils.cpp +++ b/lib/PTO/Transforms/Utils.cpp @@ -113,12 +113,10 @@ std::optional> getOperationAliasInfo(Operation *op) { dyn_cast(op)) { return std::make_pair(extractStridedMetadataOp.getBaseBuffer(), extractStridedMetadataOp.getViewSource()); - } else if (auto toMemrefOp = dyn_cast(op)) { - return std::make_pair(toMemrefOp.getResult(), toMemrefOp.getOperand()); + } else if (auto toBufferOp = dyn_cast(op)) { + return std::make_pair(toBufferOp.getBuffer(), toBufferOp.getTensor()); } else if (auto toTensorOp = dyn_cast(op)) { return std::make_pair(toTensorOp.getResult(), toTensorOp.getOperand()); - } else if (auto toMemrefOp = dyn_cast(op)) { - return std::make_pair(toMemrefOp.getResult(), toMemrefOp.getOperand()); } return std::nullopt; } diff --git a/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp b/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp index 8362aea64b..0451203982 100644 --- a/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp +++ b/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp @@ -73,10 +73,9 @@ static Type getLowPrecisionLLVMType(Type type, MLIRContext *context) { return LLVM::LLVMFloat4E1M2x2Type::get(context); if (isa(type)) return LLVM::LLVMFloat4E2M1x2Type::get(context); - if (type.isFloat8E4M3() || type.isFloat8E4M3FN() || - type.isFloat8E4M3FNUZ() || type.isFloat8E4M3B11FNUZ()) + if (pto::isPTOFloat8E4M3LikeType(type)) return LLVM::LLVMFloat8E4M3Type::get(context); - if (type.isFloat8E5M2() || type.isFloat8E5M2FNUZ()) + if (pto::isPTOFloat8E5M2LikeType(type)) return LLVM::LLVMFloat8E5M2Type::get(context); return {}; } @@ -84,15 +83,13 @@ static Type getLowPrecisionLLVMType(Type type, MLIRContext *context) { static Type getLLVMCompatibleVectorType(ArrayRef shape, Type elementType, ArrayRef scalableDims = {}) { - if (shape.size() == 1 && !elementType.isIntOrIndexOrFloat()) - return LLVM::LLVMFixedVectorType::get(elementType, shape.front()); return VectorType::get(shape, elementType, scalableDims); } static Type normalizePayloadTypeForLLVMLowering(Type type, Builder &builder) { if (pto::isPTOHiFloat8x2Type(type)) - return LLVM::LLVMFixedVectorType::get( - LLVM::LLVMHiFloat8Type::get(builder.getContext()), 2); + return getLLVMCompatibleVectorType( + {2}, LLVM::LLVMHiFloat8Type::get(builder.getContext())); if (Type lowpType = getLowPrecisionLLVMType(type, builder.getContext())) return lowpType; @@ -135,16 +132,6 @@ static Type normalizeGEPElementTypeForLLVMLowering(Type type, vecType.getScalableDims()); } - if (auto vecType = dyn_cast(type)) { - Type normalizedElement = - normalizeGEPElementTypeForLLVMLowering(vecType.getElementType(), - builder); - if (normalizedElement == vecType.getElementType()) - return normalizePayloadTypeForLLVMLowering(type, builder); - return LLVM::LLVMFixedVectorType::get(normalizedElement, - vecType.getNumElements()); - } - return normalizePayloadTypeForLLVMLowering(type, builder); } @@ -177,12 +164,6 @@ static unsigned getNaturalByteAlignment(Type type) { elems *= dim; return elemAlign * static_cast(elems); } - if (auto vecType = dyn_cast(type)) { - unsigned elemAlign = getNaturalByteAlignment(vecType.getElementType()); - if (!elemAlign) - return 0; - return elemAlign * vecType.getNumElements(); - } if (auto intType = dyn_cast(type)) return llvm::divideCeil(unsigned(intType.getWidth()), 8u); if (pto::isPTOHiFloat8x2Type(type)) @@ -227,7 +208,6 @@ class VPTOTypeConverter final : public TypeConverter { }); addSourceMaterialization(materializeVPTOCast); addTargetMaterialization(materializeVPTOCast); - addArgumentMaterialization(materializeVPTOCast); } }; @@ -364,12 +344,11 @@ static std::string getMadRhsFragment(Type type) { } static bool isMadE4M3ElementType(Type type) { - return type.isFloat8E4M3() || type.isFloat8E4M3FN() || - type.isFloat8E4M3FNUZ() || type.isFloat8E4M3B11FNUZ(); + return pto::isPTOFloat8E4M3LikeType(type); } static bool isMadE5M2ElementType(Type type) { - return type.isFloat8E5M2() || type.isFloat8E5M2FNUZ(); + return pto::isPTOFloat8E5M2LikeType(type); } static std::string getMadDstFragment(Type type) { @@ -541,10 +520,9 @@ static std::string getLowPrecisionElementFragment(Type type) { return "f4e1m2x2"; if (isa(type)) return "f4e2m1x2"; - if (type.isFloat8E4M3() || type.isFloat8E4M3FN() || - type.isFloat8E4M3FNUZ() || type.isFloat8E4M3B11FNUZ()) + if (pto::isPTOFloat8E4M3LikeType(type)) return "f8e4m3"; - if (type.isFloat8E5M2() || type.isFloat8E5M2FNUZ()) + if (pto::isPTOFloat8E5M2LikeType(type)) return "f8e5m2"; return {}; } @@ -712,8 +690,6 @@ static Type getElementTypeFromVectorLike(Type type) { return vecType.getElementType(); if (auto vecType = dyn_cast(type)) return vecType.getElementType(); - if (auto vecType = dyn_cast(type)) - return vecType.getElementType(); return {}; } @@ -725,8 +701,6 @@ static std::optional getElementCountFromVectorLike(Type type) { return std::nullopt; return vecType.getShape().front(); } - if (auto vecType = dyn_cast(type)) - return vecType.getNumElements(); return std::nullopt; } @@ -1079,10 +1053,9 @@ static VcvtElemKind classifyVcvtElemType(Type type) { return VcvtElemKind::BF16; if (type.isF32()) return VcvtElemKind::F32; - if (type.isFloat8E4M3() || type.isFloat8E4M3FN() || - type.isFloat8E4M3FNUZ() || type.isFloat8E4M3B11FNUZ()) + if (pto::isPTOFloat8E4M3LikeType(type)) return VcvtElemKind::F8E4M3; - if (type.isFloat8E5M2() || type.isFloat8E5M2FNUZ()) + if (pto::isPTOFloat8E5M2LikeType(type)) return VcvtElemKind::F8E5M2; if (pto::isPTOHiFloat8Type(type)) return VcvtElemKind::HiF8; @@ -8286,7 +8259,7 @@ class LowerKeepOpPattern final : public OpConversionPattern { op.getLoc(), TypeRange{}, payloads, asmString, appendSimtKeepResumeClobbers( buildRepeatedInlineAsmConstraints("R", payloads.size()), clobbers), - true, false, + true, false, LLVM::tailcallkind::TailCallKind::None, LLVM::AsmDialectAttr::get(op.getContext(), LLVM::AsmDialect::AD_ATT), ArrayAttr{}); for (pto::KeepOp keep : llvm::reverse(keepOps)) @@ -8359,7 +8332,7 @@ class LowerResumeOpPattern final : public OpConversionPattern { op.getLoc(), TypeRange{asmResultType}, ValueRange{}, asmString, appendSimtKeepResumeClobbers( buildRepeatedInlineAsmConstraints("=R", resumeOps.size()), clobbers), - true, false, + true, false, LLVM::tailcallkind::TailCallKind::None, LLVM::AsmDialectAttr::get(op.getContext(), LLVM::AsmDialect::AD_ATT), ArrayAttr{}); @@ -10552,7 +10525,7 @@ static LogicalResult runPipeline(ModuleOp module, llvm::raw_ostream &diagOS, kernelModulePM.addPass( std::make_unique()); kernelModulePM.addPass(arith::createArithExpandOpsPass()); - kernelModulePM.addPass(createConvertSCFToCFPass()); + kernelModulePM.addPass(createSCFToControlFlowPass()); kernelModulePM.addPass(createArithToLLVMConversionPass()); kernelModulePM.addPass(createConvertIndexToLLVMPass()); kernelModulePM.addPass(createFinalizeMemRefToLLVMConversionPass()); diff --git a/lib/PTO/Transforms/VPTOExpandWrapperOps.cpp b/lib/PTO/Transforms/VPTOExpandWrapperOps.cpp index 60f4d18ce7..799cbc1b0b 100644 --- a/lib/PTO/Transforms/VPTOExpandWrapperOps.cpp +++ b/lib/PTO/Transforms/VPTOExpandWrapperOps.cpp @@ -1885,7 +1885,7 @@ struct VPTOExpandWrapperOpsPass ExpandMadSemanticPattern, ExpandMadSemanticPattern, ExpandMadSemanticPattern>(&getContext()); - if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) + if (failed(applyPatternsGreedily(func, std::move(patterns)))) signalPassFailure(); } }; diff --git a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp index 35f8cc51a3..3c2ff6267c 100644 --- a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp +++ b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp @@ -74,10 +74,9 @@ static Type getLowPrecisionLLVMType(Type type, MLIRContext *context) { return LLVM::LLVMFloat4E1M2x2Type::get(context); if (isa(type)) return LLVM::LLVMFloat4E2M1x2Type::get(context); - if (type.isFloat8E4M3() || type.isFloat8E4M3FN() || - type.isFloat8E4M3FNUZ() || type.isFloat8E4M3B11FNUZ()) + if (pto::isPTOFloat8E4M3LikeType(type)) return LLVM::LLVMFloat8E4M3Type::get(context); - if (type.isFloat8E5M2() || type.isFloat8E5M2FNUZ()) + if (pto::isPTOFloat8E5M2LikeType(type)) return LLVM::LLVMFloat8E5M2Type::get(context); return {}; } @@ -85,15 +84,13 @@ static Type getLowPrecisionLLVMType(Type type, MLIRContext *context) { static Type getLLVMCompatibleVectorType(ArrayRef shape, Type elementType, ArrayRef scalableDims = {}) { - if (shape.size() == 1 && !elementType.isIntOrIndexOrFloat()) - return LLVM::LLVMFixedVectorType::get(elementType, shape.front()); return VectorType::get(shape, elementType, scalableDims); } static Type normalizePayloadTypeForLLVMLowering(Type type, Builder &builder) { if (pto::isPTOHiFloat8x2Type(type)) - return LLVM::LLVMFixedVectorType::get( - LLVM::LLVMHiFloat8Type::get(builder.getContext()), 2); + return getLLVMCompatibleVectorType( + {2}, LLVM::LLVMHiFloat8Type::get(builder.getContext())); if (Type lowpType = getLowPrecisionLLVMType(type, builder.getContext())) return lowpType; @@ -136,16 +133,6 @@ static Type normalizeGEPElementTypeForLLVMLowering(Type type, vecType.getScalableDims()); } - if (auto vecType = dyn_cast(type)) { - Type normalizedElement = - normalizeGEPElementTypeForLLVMLowering(vecType.getElementType(), - builder); - if (normalizedElement == vecType.getElementType()) - return normalizePayloadTypeForLLVMLowering(type, builder); - return LLVM::LLVMFixedVectorType::get(normalizedElement, - vecType.getNumElements()); - } - return normalizePayloadTypeForLLVMLowering(type, builder); } @@ -178,12 +165,6 @@ static unsigned getNaturalByteAlignment(Type type) { elems *= dim; return elemAlign * static_cast(elems); } - if (auto vecType = dyn_cast(type)) { - unsigned elemAlign = getNaturalByteAlignment(vecType.getElementType()); - if (!elemAlign) - return 0; - return elemAlign * vecType.getNumElements(); - } if (auto intType = dyn_cast(type)) return llvm::divideCeil(unsigned(intType.getWidth()), 8u); if (pto::isPTOHiFloat8x2Type(type)) @@ -228,7 +209,6 @@ class VPTOTypeConverter final : public TypeConverter { }); addSourceMaterialization(materializeVPTOCast); addTargetMaterialization(materializeVPTOCast); - addArgumentMaterialization(materializeVPTOCast); } }; @@ -365,12 +345,11 @@ static std::string getMadRhsFragment(Type type) { } static bool isMadE4M3ElementType(Type type) { - return type.isFloat8E4M3() || type.isFloat8E4M3FN() || - type.isFloat8E4M3FNUZ() || type.isFloat8E4M3B11FNUZ(); + return pto::isPTOFloat8E4M3LikeType(type); } static bool isMadE5M2ElementType(Type type) { - return type.isFloat8E5M2() || type.isFloat8E5M2FNUZ(); + return pto::isPTOFloat8E5M2LikeType(type); } static std::string getMadDstFragment(Type type) { @@ -498,10 +477,9 @@ static std::string getLowPrecisionElementFragment(Type type) { return "f4e1m2x2"; if (isa(type)) return "f4e2m1x2"; - if (type.isFloat8E4M3() || type.isFloat8E4M3FN() || - type.isFloat8E4M3FNUZ() || type.isFloat8E4M3B11FNUZ()) + if (pto::isPTOFloat8E4M3LikeType(type)) return "f8e4m3"; - if (type.isFloat8E5M2() || type.isFloat8E5M2FNUZ()) + if (pto::isPTOFloat8E5M2LikeType(type)) return "f8e5m2"; return {}; } @@ -679,8 +657,6 @@ static Type getElementTypeFromVectorLike(Type type) { return vecType.getElementType(); if (auto vecType = dyn_cast(type)) return vecType.getElementType(); - if (auto vecType = dyn_cast(type)) - return vecType.getElementType(); return {}; } @@ -692,8 +668,6 @@ static std::optional getElementCountFromVectorLike(Type type) { return std::nullopt; return vecType.getShape().front(); } - if (auto vecType = dyn_cast(type)) - return vecType.getNumElements(); return std::nullopt; } @@ -1047,10 +1021,9 @@ static VcvtElemKind classifyVcvtElemType(Type type) { return VcvtElemKind::BF16; if (type.isF32()) return VcvtElemKind::F32; - if (type.isFloat8E4M3() || type.isFloat8E4M3FN() || - type.isFloat8E4M3FNUZ() || type.isFloat8E4M3B11FNUZ()) + if (pto::isPTOFloat8E4M3LikeType(type)) return VcvtElemKind::F8E4M3; - if (type.isFloat8E5M2() || type.isFloat8E5M2FNUZ()) + if (pto::isPTOFloat8E5M2LikeType(type)) return VcvtElemKind::F8E5M2; if (pto::isPTOHiFloat8Type(type)) return VcvtElemKind::HiF8; @@ -8228,7 +8201,7 @@ class LowerKeepOpPattern final : public OpConversionPattern { op.getLoc(), TypeRange{}, payloads, asmString, appendSimtKeepResumeClobbers( buildRepeatedInlineAsmConstraints("R", payloads.size()), clobbers), - true, false, + true, false, LLVM::tailcallkind::TailCallKind::None, LLVM::AsmDialectAttr::get(op.getContext(), LLVM::AsmDialect::AD_ATT), ArrayAttr{}); for (pto::KeepOp keep : llvm::reverse(keepOps)) @@ -8302,7 +8275,7 @@ class LowerResumeOpPattern final : public OpConversionPattern { op.getLoc(), TypeRange{asmResultType}, ValueRange{}, asmString, appendSimtKeepResumeClobbers( buildRepeatedInlineAsmConstraints("=R", resumeOps.size()), clobbers), - true, false, + true, false, LLVM::tailcallkind::TailCallKind::None, LLVM::AsmDialectAttr::get(op.getContext(), LLVM::AsmDialect::AD_ATT), ArrayAttr{}); @@ -10513,7 +10486,7 @@ static LogicalResult runPipeline(ModuleOp module, llvm::raw_ostream &diagOS, kernelModulePM.addPass( std::make_unique()); kernelModulePM.addPass(arith::createArithExpandOpsPass()); - kernelModulePM.addPass(createConvertSCFToCFPass()); + kernelModulePM.addPass(createSCFToControlFlowPass()); kernelModulePM.addPass(createArithToLLVMConversionPass()); kernelModulePM.addPass(createConvertIndexToLLVMPass()); kernelModulePM.addPass(createFinalizeMemRefToLLVMConversionPass()); diff --git a/lib/PTO/Transforms/VPTOPtrCastCleanup.cpp b/lib/PTO/Transforms/VPTOPtrCastCleanup.cpp index dbaa9a780d..fc238fac5f 100644 --- a/lib/PTO/Transforms/VPTOPtrCastCleanup.cpp +++ b/lib/PTO/Transforms/VPTOPtrCastCleanup.cpp @@ -69,7 +69,7 @@ struct VPTOPtrCastCleanupPass void runOnOperation() override { RewritePatternSet patterns(&getContext()); patterns.add(&getContext()); - if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) signalPassFailure(); } }; diff --git a/lib/PTO/Transforms/VPTOPtrNormalize.cpp b/lib/PTO/Transforms/VPTOPtrNormalize.cpp index 44dcdc2cdb..404c6634aa 100644 --- a/lib/PTO/Transforms/VPTOPtrNormalize.cpp +++ b/lib/PTO/Transforms/VPTOPtrNormalize.cpp @@ -10,6 +10,7 @@ #pragma GCC diagnostic ignored "-Woverloaded-virtual" #include "PTO/IR/PTO.h" +#include "PTO/IR/PTOTypeUtils.h" #include "PTO/Transforms/Passes.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -109,7 +110,8 @@ static LogicalResult computeSubviewElementOffset(memref::SubViewOp op, SmallVector strides; int64_t baseOffset = 0; - if (failed(getStridesAndOffset(sourceType, strides, baseOffset))) + if (failed(mlir::pto::getPTOMemRefStridesAndOffset(sourceType, strides, + baseOffset))) return failure(); // The SSA source already names the base address after bind_tile/pointer_cast // normalization. A dynamic memref layout offset here is metadata we can @@ -808,7 +810,6 @@ struct VPTOPtrNormalizePass [](Type type) { return convertSubviewResultType(type); }); typeConverter.addTargetMaterialization(materializeUnrealizedCast); typeConverter.addSourceMaterialization(materializeUnrealizedCast); - typeConverter.addArgumentMaterialization(materializeUnrealizedCast); ConversionTarget target(*context); target.addLegalDialect str: @@ -66,9 +66,9 @@ class DocBlockMetadata: @dataclass(frozen=True) class DocTestDirective: mode: str - symbol: str | None = None - compile_kwargs: dict[str, object] | None = None - fixture: str | None = None + symbol: Optional[str] = None + compile_kwargs: Optional[dict[str, object]] = None + fixture: Optional[str] = None @dataclass(frozen=True) @@ -85,12 +85,12 @@ def expect(condition: bool, message: str) -> None: raise AssertionError(message) -def format_doc_context(path: Path, start_line: int, symbol: str | None = None) -> str: +def format_doc_context(path: Path, start_line: int, symbol: Optional[str] = None) -> str: symbol_text = symbol if symbol is not None else "" return f"{path}:{start_line} [symbol={symbol_text}]" -def fail_doc(path: Path, start_line: int, message: str, symbol: str | None = None) -> None: +def fail_doc(path: Path, start_line: int, message: str, symbol: Optional[str] = None) -> None: raise AssertionError(f"{format_doc_context(path, start_line, symbol)}: {message}") @@ -98,7 +98,7 @@ def iter_markdown_files(root: Path) -> Iterable[Path]: yield from sorted(root.glob("*.md")) -def parse_metadata_line(path: Path, line: str, line_number: int) -> DocBlockMetadata | None: +def parse_metadata_line(path: Path, line: str, line_number: int) -> Optional[DocBlockMetadata]: match = META_RE.match(line) if match is None: return None @@ -116,7 +116,7 @@ def parse_metadata_line(path: Path, line: str, line_number: int) -> DocBlockMeta return DocBlockMetadata(kind=kind, body=body, line=line_number, raw=line.rstrip("\n")) -def find_block_metadata(path: Path, lines: list[str], fence_line: int) -> DocBlockMetadata | None: +def find_block_metadata(path: Path, lines: list[str], fence_line: int) -> Optional[DocBlockMetadata]: candidate = fence_line - 2 while candidate >= 0 and not lines[candidate].strip(): candidate -= 1 @@ -128,7 +128,7 @@ def find_block_metadata(path: Path, lines: list[str], fence_line: int) -> DocBlo return parse_metadata_line(path, line, candidate + 1) -def block_label(block: MarkdownCodeBlock, symbol: str | None = None) -> str: +def block_label(block: MarkdownCodeBlock, symbol: Optional[str] = None) -> str: return format_doc_context(block.path, block.start_line, symbol) @@ -313,9 +313,9 @@ def parse_test_directive(block: MarkdownCodeBlock) -> DocTestDirective: def execute_source( source: str, block: MarkdownCodeBlock, - symbol: str | None = None, + symbol: Optional[str] = None, *, - extra_namespace: dict[str, object] | None = None, + extra_namespace: Optional[dict[str, object]] = None, ) -> dict[str, object]: namespace: dict[str, object] = { "__builtins__": __builtins__, @@ -493,7 +493,7 @@ def scan_markdown_file(path: Path) -> MarkdownScanResult: block_language = "" block_start = 0 block_lines: list[str] = [] - metadata: DocBlockMetadata | None = None + metadata: Optional[DocBlockMetadata] = None for index, line in enumerate(lines, start=1): fence_match = FENCE_RE.match(line.rstrip("\n")) diff --git a/ptodsl/tests/test_jit_compile.py b/ptodsl/tests/test_jit_compile.py index 2157acdb1a..fd1b391644 100644 --- a/ptodsl/tests/test_jit_compile.py +++ b/ptodsl/tests/test_jit_compile.py @@ -12,6 +12,7 @@ import re import sys from tempfile import TemporaryDirectory +from typing import Optional from unittest import mock @@ -35,7 +36,7 @@ def expect(condition: bool, message: str) -> None: raise AssertionError(message) -def expect_raises(exc_type, func, message_substring: str | None = None) -> Exception: +def expect_raises(exc_type, func, message_substring: Optional[str] = None) -> Exception: try: func() except exc_type as exc: diff --git a/python/pto/dialects/pto.py b/python/pto/dialects/pto.py index b9db235f9c..0b5a4b1df3 100644 --- a/python/pto/dialects/pto.py +++ b/python/pto/dialects/pto.py @@ -11,10 +11,16 @@ import functools import os from pathlib import Path +from typing import Optional from mlir import ir as _ods_ir from . import _pto_ops_gen as _pto_ops_gen +from ._ods_common import ( + get_default_loc_context as _ods_get_default_loc_context, + get_op_result_or_value as _get_op_result_or_value, + get_op_results_or_values as _get_op_results_or_values, +) def _candidate_pto_ext_dirs(): @@ -82,7 +88,7 @@ def _export_generated_symbols(): def get_op_result_or_value(value): - return getattr(_pto_ops_gen, "_get_op_result_or_value")(value) + return _get_op_result_or_value(value) def _export_optional_cext_symbol(name): @@ -177,7 +183,6 @@ def _export_optional_cext_symbol(name): _ptr_type_get_impl = PtrType.get -_ods_get_default_loc_context = getattr(_pto_ops_gen, "_ods_get_default_loc_context") def _ptr_type_get_compat(cls, element_type, memory_space=None, context=None): @@ -865,7 +870,7 @@ def _init_inferred(self, source, offsets, sizes, args, loc, ip): sizes, *args = args if args: raise TypeError(f"too many positional arguments: {len(args)}") - source_value = _pto_ops_gen._get_op_result_or_value(source) + source_value = _get_op_result_or_value(source) source_type = source_value.type result = PartitionTensorViewType.get(source_type.rank, source_type.element_type) self._init_explicit(result, source_value, offsets, sizes, (), loc, ip) @@ -882,9 +887,9 @@ def _init_explicit(self, result, source, offsets, sizes, args, loc, ip): if args: raise TypeError(f"too many positional arguments: {len(args)}") operands = [ - _pto_ops_gen._get_op_result_or_value(source), - _pto_ops_gen._get_op_results_or_values(offsets), - _pto_ops_gen._get_op_results_or_values(sizes), + _get_op_result_or_value(source), + _get_op_results_or_values(offsets), + _get_op_results_or_values(sizes), ] op = self.build_generic( attributes={}, @@ -929,12 +934,12 @@ def _is_mask_pattern(value): def _value_type(value): try: - return _pto_ops_gen._get_op_result_or_value(value).type + return _get_op_result_or_value(value).type except Exception: return None def _matches_src_type(value): - src_value = _pto_ops_gen._get_op_result_or_value(src) + src_value = _get_op_result_or_value(src) value_type = _value_type(value) return value_type is not None and value_type == src_value.type @@ -1151,9 +1156,9 @@ class _VKernelCompileError(Exception): @_dataclass class _VKValue: - name: str | None = None - type: _VKernelType | None = None - literal: object | None = None + name: Optional[str] = None + type: Optional[_VKernelType] = None + literal: Optional[object] = None def render_type(self): if self.type is None: diff --git a/test/lit/pto/async_put_get_emitc.pto b/test/lit/pto/async_put_get_emitc.pto index b191dbd0bb..30998f7c3e 100644 --- a/test/lit/pto/async_put_get_emitc.pto +++ b/test/lit/pto/async_put_get_emitc.pto @@ -16,10 +16,13 @@ module { // A3-LABEL: AICORE void async_put_get( // A3: Tile [[SCRATCH:v[0-9]+]]; -// A3: TASSIGN([[SCRATCH]], [[SCRATCH_ADDR:v[0-9]+]]); +// A3: Tile [[SCRATCH_COPY:v[0-9]+]] = [[SCRATCH]]; +// A3: TASSIGN([[SCRATCH_COPY]], [[SCRATCH_ADDR:v[0-9]+]]); // A3: pto::comm::AsyncSession [[SESSION:v[0-9]+]]; +// A3: pto::comm::AsyncSession [[SESSION_COPY:v[0-9]+]] = [[SESSION]]; // A3: pto::comm::sdma::SdmaBaseConfig [[CFG:v[0-9]+]] = {32768ULL, 0ULL, 1u}; -// A3: pto::comm::BuildAsyncSession([[SCRATCH]], {{.*}}, [[SESSION]], {{.*}}, [[CFG]], {{.*}}); +// A3: pto::comm::sdma::SdmaBaseConfig [[CFG_COPY:v[0-9]+]] = [[CFG]]; +// A3: pto::comm::BuildAsyncSession([[SCRATCH_COPY]], {{.*}}, [[SESSION_COPY]], {{.*}}, [[CFG_COPY]], {{.*}}); // A3: using [[SHAPETY:.*]] = pto::Shape<1, 1, 1, 1, 128>; // A3: using [[STRIDETY:.*]] = pto::Stride<128, 128, 128, 128, 1>; // A3: constexpr pto::Layout [[LAYOUT:.*]] = pto::Layout::ND; @@ -32,5 +35,5 @@ module { // A3: [[GLTNSRTY]] [[GT1:v[0-9]+]] = [[GLTNSRTY]]({{.*}}, [[SHAPE1]], [[STRIDE1]]); // A3: pto::comm::AsyncEvent [[PUT_EVT:v[0-9]+]] = pto::comm::TPUT_ASYNC( // A3: pto::comm::AsyncEvent [[GET_EVT:v[0-9]+]] = pto::comm::TGET_ASYNC( -// A3: bool [[PUT_DONE:v[0-9]+]] = [[PUT_EVT]].Wait([[SESSION]]); -// A3: bool [[GET_DONE:v[0-9]+]] = [[GET_EVT]].Test([[SESSION]]); +// A3: bool [[PUT_DONE:v[0-9]+]] = [[PUT_EVT]].Wait([[SESSION_COPY]]); +// A3: bool [[GET_DONE:v[0-9]+]] = [[GET_EVT]].Test([[SESSION_COPY]]); diff --git a/test/lit/pto/backedge_nested_same_pipe_prune_regression.pto b/test/lit/pto/backedge_nested_same_pipe_prune_regression.pto index eccfe3596c..e54c8f9e18 100644 --- a/test/lit/pto/backedge_nested_same_pipe_prune_regression.pto +++ b/test/lit/pto/backedge_nested_same_pipe_prune_regression.pto @@ -5,7 +5,7 @@ // the wider same-pipe pair can be removed before event-id allocation. // // CHECK-LABEL: AICORE void backedge_nested_same_pipe_prune() -// CHECK: for (size_t {{v[0-9]+}} = +// CHECK: for (int64_t {{[ij][0-9]+}} = // CHECK-NEXT: wait_flag(PIPE_FIX, PIPE_MTE1, EVENT_ID[[INNER:[0-9]+]]); // CHECK-NEXT: TMOV({{v[0-9]+}}, {{v[0-9]+}}); // CHECK-NEXT: TMOV({{v[0-9]+}}, {{v[0-9]+}}); diff --git a/test/lit/pto/emitc_tile_data_sink_after_tassign.pto b/test/lit/pto/emitc_tile_data_sink_after_tassign.pto index 2c946df121..953c70ed44 100644 --- a/test/lit/pto/emitc_tile_data_sink_after_tassign.pto +++ b/test/lit/pto/emitc_tile_data_sink_after_tassign.pto @@ -25,11 +25,13 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind(v4); -// CHECK: TASSIGN(v3, v5); -// CHECK: __ubuf__ float* v6 = v3.data(); -// CHECK: sink_ptr(v6); +// CHECK: Tile [[ADDR_BASE:v[0-9]+]]; +// CHECK-NEXT: Tile [[ADDR_TILE:v[0-9]+]] = [[ADDR_BASE]]; +// CHECK-NEXT: TASSIGN([[ADDR_TILE]], v1); +// CHECK: Tile [[SINK_BASE:v[0-9]+]]; +// CHECK-NEXT: Tile [[SINK_TILE:v[0-9]+]] = [[SINK_BASE]]; +// CHECK-NEXT: __ubuf__ float* [[ADDR_PTR:v[0-9]+]] = [[ADDR_TILE]].data(); +// CHECK-NEXT: uint64_t [[ADDR_BITS:v[0-9]+]] = reinterpret_cast([[ADDR_PTR]]); +// CHECK-NEXT: TASSIGN([[SINK_TILE]], [[ADDR_BITS]]); +// CHECK-NEXT: __ubuf__ float* [[SINK_PTR:v[0-9]+]] = [[SINK_TILE]].data(); +// CHECK-NEXT: sink_ptr([[SINK_PTR]]); diff --git a/test/lit/pto/eventid_array_dyn_sync.pto b/test/lit/pto/eventid_array_dyn_sync.pto index b63b3e0d4b..e969b767f2 100644 --- a/test/lit/pto/eventid_array_dyn_sync.pto +++ b/test/lit/pto/eventid_array_dyn_sync.pto @@ -15,9 +15,11 @@ module { // CHECK-LABEL: AICORE void eventid_array_dyn_sync() { // CHECK: const int64_t {{v[0-9]+}} = 0; -// CHECK: PTOAS_EventIdArray<4> {{v[0-9]+}}; -// CHECK: {{v[0-9]+}}[{{v[0-9]+}}] = {{v[0-9]+}}; -// CHECK: event_t {{v[0-9]+}} = (event_t) {{v[0-9]+}}[{{v[0-9]+}}]; +// CHECK: PTOAS_EventIdArray<4> [[ARR:v[0-9]+]]; +// CHECK: PTOAS_EventIdArray<4> [[ARR_VAL:v[0-9]+]] = [[ARR]]; +// CHECK: [[ARR_VAL]][{{v[0-9]+}}] = {{v[0-9]+}}; +// CHECK: int64_t [[EID:v[0-9]+]] = [[ARR_VAL]][{{v[0-9]+}}]; +// CHECK: event_t {{v[0-9]+}} = (event_t) [[EID]]; // CHECK: set_flag(PIPE_MTE2, PIPE_MTE3, {{v[0-9]+}}); -// CHECK: event_t {{v[0-9]+}} = (event_t) {{v[0-9]+}}[{{v[0-9]+}}]; +// CHECK: event_t {{v[0-9]+}} = (event_t) [[EID]]; // CHECK: wait_flag(PIPE_MTE2, PIPE_MTE3, {{v[0-9]+}}); diff --git a/test/lit/pto/eventid_array_get_set_get.pto b/test/lit/pto/eventid_array_get_set_get.pto index 6de754b085..7578c7a403 100644 --- a/test/lit/pto/eventid_array_get_set_get.pto +++ b/test/lit/pto/eventid_array_get_set_get.pto @@ -18,9 +18,12 @@ module { // CHECK-LABEL: AICORE void eventid_array_get_set_get() { // CHECK: PTOAS_EventIdArray<4> [[ARR:v[0-9]+]]; -// CHECK: [[ARR]][{{v[0-9]+}}] = {{v[0-9]+}}; -// CHECK: [[ARR]][{{v[0-9]+}}] = {{v[0-9]+}}; -// CHECK: event_t [[FIRST:v[0-9]+]] = (event_t) [[ARR]][{{v[0-9]+}}]; +// CHECK: PTOAS_EventIdArray<4> [[ARR_VAL:v[0-9]+]] = [[ARR]]; +// CHECK: [[ARR_VAL]][{{v[0-9]+}}] = {{v[0-9]+}}; +// CHECK: int64_t [[FIRST_ID:v[0-9]+]] = [[ARR_VAL]][{{v[0-9]+}}]; +// CHECK: [[ARR_VAL]][{{v[0-9]+}}] = {{v[0-9]+}}; +// CHECK: int64_t [[SECOND_ID:v[0-9]+]] = [[ARR_VAL]][{{v[0-9]+}}]; +// CHECK: event_t [[FIRST:v[0-9]+]] = (event_t) [[FIRST_ID]]; // CHECK: set_flag(PIPE_MTE2, PIPE_MTE3, [[FIRST]]); -// CHECK: event_t [[SECOND:v[0-9]+]] = (event_t) [[ARR]][{{v[0-9]+}}]; +// CHECK: event_t [[SECOND:v[0-9]+]] = (event_t) [[SECOND_ID]]; // CHECK: wait_flag(PIPE_MTE2, PIPE_MTE3, [[SECOND]]); diff --git a/test/lit/pto/eventid_array_no_cse.pto b/test/lit/pto/eventid_array_no_cse.pto index a589b5da00..8d18b44176 100644 --- a/test/lit/pto/eventid_array_no_cse.pto +++ b/test/lit/pto/eventid_array_no_cse.pto @@ -19,6 +19,8 @@ module { // CHECK-LABEL: AICORE void eventid_array_no_cse() { // CHECK: PTOAS_EventIdArray<4> [[ARR0:v[0-9]+]]; +// CHECK: PTOAS_EventIdArray<4> [[ARR0_VAL:v[0-9]+]] = [[ARR0]]; // CHECK: PTOAS_EventIdArray<4> [[ARR1:v[0-9]+]]; -// CHECK: [[ARR0]][{{v[0-9]+}}] = {{v[0-9]+}}; -// CHECK: [[ARR1]][{{v[0-9]+}}] = {{v[0-9]+}}; +// CHECK: PTOAS_EventIdArray<4> [[ARR1_VAL:v[0-9]+]] = [[ARR1]]; +// CHECK: [[ARR0_VAL]][{{v[0-9]+}}] = {{v[0-9]+}}; +// CHECK: [[ARR1_VAL]][{{v[0-9]+}}] = {{v[0-9]+}}; diff --git a/test/lit/pto/issue428_cube_sync_regression.pto b/test/lit/pto/issue428_cube_sync_regression.pto index e5ed4f7917..335c2c9e26 100644 --- a/test/lit/pto/issue428_cube_sync_regression.pto +++ b/test/lit/pto/issue428_cube_sync_regression.pto @@ -5,7 +5,7 @@ // - Preserve the tail drain before the auto tail barrier helper. // // CHECK-LABEL: tri_inv_block2x2_fp16( -// CHECK: for (size_t [[LOOP0:v[0-9]+]] = +// CHECK: for (int64_t [[LOOP0:[ij][0-9]+]] = // CHECK: wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID2); // CHECK: TMOV({{v[0-9]+}}, {{v[0-9]+}}); // CHECK: TMOV({{v[0-9]+}}, {{v[0-9]+}}); @@ -37,7 +37,7 @@ // CHECK: set_flag(PIPE_M, PIPE_MTE1, EVENT_ID2); // CHECK: set_flag(PIPE_FIX, PIPE_M, EVENT_ID2); // CHECK: } -// CHECK: for (size_t [[LOOP1:v[0-9]+]] = +// CHECK: for (int64_t [[LOOP1:[ij][0-9]+]] = // CHECK: wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID7); // CHECK: TMOV({{v[0-9]+}}, {{v[0-9]+}}); // CHECK: TMOV({{v[0-9]+}}, {{v[0-9]+}}); diff --git a/test/lit/pto/issue454_loop_if_else_loop_carried_sync_regression.pto b/test/lit/pto/issue454_loop_if_else_loop_carried_sync_regression.pto index cd8c9528b0..7748f8cba8 100644 --- a/test/lit/pto/issue454_loop_if_else_loop_carried_sync_regression.pto +++ b/test/lit/pto/issue454_loop_if_else_loop_carried_sync_regression.pto @@ -16,7 +16,7 @@ // CHECK: TMOV({{v[0-9]+}}, {{v[0-9]+}}); // CHECK: set_flag(PIPE_FIX, PIPE_MTE1, EVENT_ID0); // CHECK: wait_flag(PIPE_FIX, PIPE_MTE1, EVENT_ID0); -// CHECK: for (size_t [[IV:v[0-9]+]] = +// CHECK: for (int64_t [[IV:[ij][0-9]+]] = // CHECK: wait_flag(PIPE_FIX, PIPE_MTE1, EVENT_ID1); // CHECK: TMOV({{v[0-9]+}}, {{v[0-9]+}}); // CHECK: set_flag(PIPE_MTE1, PIPE_M, EVENT_ID1); @@ -26,7 +26,7 @@ // CHECK: set_flag(PIPE_M, PIPE_FIX, EVENT_ID1); // CHECK: wait_flag(PIPE_M, PIPE_FIX, EVENT_ID1); // CHECK: wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); -// CHECK: if ((int64_t) [[IV]] == +// CHECK: if ([[IV]] == // CHECK: } else { // CHECK: set_flag(PIPE_FIX, PIPE_MTE1, EVENT_ID1); // CHECK: ptoas_auto_sync_tail(PTOAutoSyncTailMode::kBarrierAll); diff --git a/test/lit/pto/issue454_loop_if_else_loop_carried_sync_regression_gss.pto b/test/lit/pto/issue454_loop_if_else_loop_carried_sync_regression_gss.pto index 019a12654b..323d7de1b0 100644 --- a/test/lit/pto/issue454_loop_if_else_loop_carried_sync_regression_gss.pto +++ b/test/lit/pto/issue454_loop_if_else_loop_carried_sync_regression_gss.pto @@ -5,8 +5,8 @@ // // CHECK-DAG: pipe_barrier(PIPE_ALL) // CHECK-LABEL: AICORE void loop_if_else_loop_carried_sync() -// CHECK: for (size_t -// CHECK: if ((int64_t) +// CHECK: for (int64_t [[IV:[ij][0-9]+]] = +// CHECK: if ([[IV]] == // CHECK: } else { // CHECK: #endif // __DAV_CUBE__ diff --git a/test/lit/pto/issue454_nested_loop_same_pipe_pair_regression.pto b/test/lit/pto/issue454_nested_loop_same_pipe_pair_regression.pto index b673740f52..e7f74de69c 100644 --- a/test/lit/pto/issue454_nested_loop_same_pipe_pair_regression.pto +++ b/test/lit/pto/issue454_nested_loop_same_pipe_pair_regression.pto @@ -5,9 +5,9 @@ // - Inner-loop and outer-loop event chains must both keep their set/wait handshake. // // CHECK-LABEL: AICORE void nested_loop_same_pipe_pair() -// CHECK: for (size_t +// CHECK: for (int64_t // CHECK: wait_flag(PIPE_FIX, PIPE_MTE1, EVENT_ID[[OUT:[0-9]+]]); -// CHECK: for (size_t +// CHECK: for (int64_t // CHECK: wait_flag(PIPE_FIX, PIPE_MTE1, EVENT_ID[[IN:[0-9]+]]); // CHECK: set_flag(PIPE_FIX, PIPE_MTE1, EVENT_ID[[IN]]); // CHECK: set_flag(PIPE_FIX, PIPE_MTE1, EVENT_ID[[OUT]]); diff --git a/test/lit/pto/issue454_nested_loop_same_pipe_pair_regression_gss.pto b/test/lit/pto/issue454_nested_loop_same_pipe_pair_regression_gss.pto index c6aef83ce1..2139d89bf6 100644 --- a/test/lit/pto/issue454_nested_loop_same_pipe_pair_regression_gss.pto +++ b/test/lit/pto/issue454_nested_loop_same_pipe_pair_regression_gss.pto @@ -8,7 +8,7 @@ // - Inner-loop and outer-loop event chains must both keep their set/wait handshake. // // CHECK-LABEL: AICORE void nested_loop_same_pipe_pair() -// CHECK-COUNT-2: for (size_t +// CHECK-COUNT-2: for (int64_t // CHECK: TMATMUL( // CHECK: #endif // __DAV_CUBE__ diff --git a/test/lit/pto/issue533_loop_zero_trip_sync_regression.pto b/test/lit/pto/issue533_loop_zero_trip_sync_regression.pto index 7ad2387278..58e1c7b739 100644 --- a/test/lit/pto/issue533_loop_zero_trip_sync_regression.pto +++ b/test/lit/pto/issue533_loop_zero_trip_sync_regression.pto @@ -4,7 +4,7 @@ // a post-loop vector consumer must wait on MTE2->V before TROWEXPANDDIV. // // CHECK-LABEL: AICORE void qwen3_scope2_incore_5 -// CHECK: for (size_t +// CHECK: for (int64_t // CHECK: } // CHECK: wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID[[POST:[0-9]+]]); // CHECK: TROWEXPANDDIV diff --git a/test/lit/pto/issue533_loop_zero_trip_sync_regression_gss.pto b/test/lit/pto/issue533_loop_zero_trip_sync_regression_gss.pto index b8d73b19dc..56e75b410b 100644 --- a/test/lit/pto/issue533_loop_zero_trip_sync_regression_gss.pto +++ b/test/lit/pto/issue533_loop_zero_trip_sync_regression_gss.pto @@ -4,7 +4,7 @@ // a post-loop vector consumer must wait on MTE2->V before TROWEXPANDDIV. // // CHECK-LABEL: AICORE void qwen3_scope2_incore_5 -// CHECK: for (size_t +// CHECK: for (int64_t // CHECK: } // CHECK: wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID[[POST:[0-9]+]]); // CHECK: TROWEXPANDDIV diff --git a/test/lit/pto/issue556_tpop_live_values_no_alias.pto b/test/lit/pto/issue556_tpop_live_values_no_alias.pto index de1734a5c2..ddd179e75b 100644 --- a/test/lit/pto/issue556_tpop_live_values_no_alias.pto +++ b/test/lit/pto/issue556_tpop_live_values_no_alias.pto @@ -58,8 +58,10 @@ module { // CHECK-LABEL: AICORE void issue556_tpop_live_values_no_alias( // CHECK: Tile [[POP0:v[0-9]+]]; -// CHECK: TPOP, Tile, TileSplitAxis::TILE_UP_DOWN>({{v[0-9]+}}, [[POP0]]); +// CHECK: Tile [[POP0_COPY:v[0-9]+]] = [[POP0]]; +// CHECK: TPOP, Tile, TileSplitAxis::TILE_UP_DOWN>({{v[0-9]+}}, [[POP0_COPY]]); // CHECK: Tile [[POP1:v[0-9]+]]; -// CHECK: TPOP, Tile, TileSplitAxis::TILE_UP_DOWN>({{v[0-9]+}}, [[POP1]]); -// CHECK: TMOV({{v[0-9]+}}, [[POP0]]); -// CHECK: TMOV({{v[0-9]+}}, [[POP1]]); +// CHECK: Tile [[POP1_COPY:v[0-9]+]] = [[POP1]]; +// CHECK: TPOP, Tile, TileSplitAxis::TILE_UP_DOWN>({{v[0-9]+}}, [[POP1_COPY]]); +// CHECK: TMOV({{v[0-9]+}}, [[POP0_COPY]]); +// CHECK: TMOV({{v[0-9]+}}, [[POP1_COPY]]); diff --git a/test/lit/pto/issue564_k_loop_mte1_mte2_wait_regression.pto b/test/lit/pto/issue564_k_loop_mte1_mte2_wait_regression.pto index 9c5b00dd1f..3690ca6114 100644 --- a/test/lit/pto/issue564_k_loop_mte1_mte2_wait_regression.pto +++ b/test/lit/pto/issue564_k_loop_mte1_mte2_wait_regression.pto @@ -10,7 +10,7 @@ // CHECK-NEXT: set_flag(PIPE_M, PIPE_MTE1, EVENT_ID[[CARRY:[0-9]+]]); // CHECK-NEXT: wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID[[PRE0:[0-9]+]]); // CHECK-NEXT: wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID[[CARRY]]); -// CHECK-NEXT: for (size_t +// CHECK-NEXT: for (int64_t // CHECK: wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID[[LOAD0:[0-9]+]]); // CHECK-NEXT: TLOAD( // CHECK: wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID[[LOAD1:[0-9]+]]); diff --git a/test/lit/pto/issue686_a5_tmov_treshape_dynamic_valid_shape.pto b/test/lit/pto/issue686_a5_tmov_treshape_dynamic_valid_shape.pto index 3696e084b6..a14b12d9a0 100644 --- a/test/lit/pto/issue686_a5_tmov_treshape_dynamic_valid_shape.pto +++ b/test/lit/pto/issue686_a5_tmov_treshape_dynamic_valid_shape.pto @@ -28,9 +28,11 @@ module attributes {pto.target_arch = "a5"} { // CHECK: Tile [[SRC_ORIG:v[0-9]+]] = Tile // CHECK: Tile [[DST_ORIG:v[0-9]+]] = Tile -// CHECK: Tile [[SRC:v[0-9]+]]; +// CHECK: Tile [[SRC_STORAGE:v[0-9]+]]; +// CHECK-NEXT: Tile [[SRC:v[0-9]+]] = [[SRC_STORAGE]]; // CHECK-NEXT: TRESHAPE([[SRC]], [[SRC_ORIG]]); -// CHECK: Tile [[DST:v[0-9]+]]; +// CHECK: Tile [[DST_STORAGE:v[0-9]+]]; +// CHECK-NEXT: Tile [[DST:v[0-9]+]] = [[DST_STORAGE]]; // CHECK-NEXT: TRESHAPE([[DST]], [[DST_ORIG]]); // CHECK: int64_t [[SRC_ROW:v[0-9]+]] = [[SRC_ORIG]].GetValidRow(); // CHECK-NEXT: int64_t [[SRC_COL:v[0-9]+]] = [[SRC_ORIG]].GetValidCol(); diff --git a/test/lit/pto/issue686_a5_tmov_treshape_dynamic_valid_shape_level2.pto b/test/lit/pto/issue686_a5_tmov_treshape_dynamic_valid_shape_level2.pto index 140f06b665..e3c3864d05 100644 --- a/test/lit/pto/issue686_a5_tmov_treshape_dynamic_valid_shape_level2.pto +++ b/test/lit/pto/issue686_a5_tmov_treshape_dynamic_valid_shape_level2.pto @@ -27,9 +27,11 @@ module attributes {pto.target_arch = "a5"} { // CHECK-LABEL: AICORE void a5_tmov_treshape_dynamic_valid_shape_level2() // CHECK: Tile [[SRC_ORIG:v[0-9]+]] // CHECK: Tile [[DST_ORIG:v[0-9]+]] -// CHECK: Tile [[SRC:v[0-9]+]]; +// CHECK: Tile [[SRC_STORAGE:v[0-9]+]]; +// CHECK-NEXT: Tile [[SRC:v[0-9]+]] = [[SRC_STORAGE]]; // CHECK-NEXT: TRESHAPE([[SRC]], [[SRC_ORIG]]); -// CHECK: Tile [[DST:v[0-9]+]]; +// CHECK: Tile [[DST_STORAGE:v[0-9]+]]; +// CHECK-NEXT: Tile [[DST:v[0-9]+]] = [[DST_STORAGE]]; // CHECK-NEXT: TRESHAPE([[DST]], [[DST_ORIG]]); // CHECK: int64_t [[SRC_ROW:v[0-9]+]] = [[SRC_ORIG]].GetValidRow(); // CHECK-NEXT: int64_t [[SRC_COL:v[0-9]+]] = [[SRC_ORIG]].GetValidCol(); diff --git a/test/lit/pto/issue713_local_array_get_snapshot.pto b/test/lit/pto/issue713_local_array_get_snapshot.pto index 28ffc34b90..c549c233d7 100644 --- a/test/lit/pto/issue713_local_array_get_snapshot.pto +++ b/test/lit/pto/issue713_local_array_get_snapshot.pto @@ -37,11 +37,8 @@ module { // CPP: int32_t [[ARR:v[0-9]+]][2]; // CPP: [[ARR]][{{v[0-9]+}}] = {{v[0-9]+}}; // CPP: [[ARR]][{{v[0-9]+}}] = {{v[0-9]+}}; -// CPP: int32_t [[CUR:v[0-9]+]]; -// CPP-NEXT: [[CUR]] = [[ARR]][{{v[0-9]+}}]; -// CPP-NEXT: uint32_t [[NEW_U:v[0-9]+]] = (uint32_t) [[CUR]]; -// CPP-NEXT: [[ARR]][{{v[0-9]+}}] = {{.*}}[[NEW_U]]{{.*}}; +// CPP: int32_t [[CUR:v[0-9]+]] = [[ARR]][{{v[0-9]+}}]; +// CPP-NEXT: [[ARR]][{{v[0-9]+}}] = {{.*}}(uint32_t) [[CUR]]{{.*}}; // CPP-NOT: [[ARR]][ -// CPP: uint32_t [[SUM_U:v[0-9]+]] = (uint32_t) [[CUR]]; -// CPP-NEXT: int32_t [[SUM:v[0-9]+]] = {{.*}}[[SUM_U]]{{.*}}; +// CPP: int32_t [[SUM:v[0-9]+]] = {{.*}}(uint32_t) [[CUR]]{{.*}}; // CPP-NEXT: [[OUT]][{{v[0-9]+}}] = [[SUM]]; diff --git a/test/lit/pto/local_array_1d_emitc.pto b/test/lit/pto/local_array_1d_emitc.pto index d38c2ca437..c90943f247 100644 --- a/test/lit/pto/local_array_1d_emitc.pto +++ b/test/lit/pto/local_array_1d_emitc.pto @@ -25,6 +25,5 @@ module { // CHECK-LABEL: local_array_1d // CHECK: int32_t [[A:v[0-9]+]][16]; // CHECK: [[A]][{{v[0-9]+}}] = {{v[0-9]+}}; -// CHECK: int32_t [[R:v[0-9]+]]; -// CHECK: [[R]] = [[A]][{{v[0-9]+}}]; +// CHECK: int32_t [[R:v[0-9]+]] = [[A]][{{v[0-9]+}}]; // CHECK: [[A]][{{v[0-9]+}}] = [[R]]; diff --git a/test/lit/pto/local_array_2d_emitc.pto b/test/lit/pto/local_array_2d_emitc.pto index 222f2d71de..584ed6efa7 100644 --- a/test/lit/pto/local_array_2d_emitc.pto +++ b/test/lit/pto/local_array_2d_emitc.pto @@ -25,6 +25,5 @@ module { // CHECK-LABEL: local_array_2d // CHECK: float [[M:v[0-9]+]][8][8]; // CHECK: [[M]][{{v[0-9]+}}][{{v[0-9]+}}] = {{v[0-9]+}}; -// CHECK: float [[R:v[0-9]+]]; -// CHECK: [[R]] = [[M]][{{v[0-9]+}}][{{v[0-9]+}}]; +// CHECK: float [[R:v[0-9]+]] = [[M]][{{v[0-9]+}}][{{v[0-9]+}}]; // CHECK: [[M]][{{v[0-9]+}}][{{v[0-9]+}}] = [[R]]; diff --git a/test/lit/pto/local_array_get_rvalue_emitc.pto b/test/lit/pto/local_array_get_rvalue_emitc.pto index 91140eec83..132e14ca47 100644 --- a/test/lit/pto/local_array_get_rvalue_emitc.pto +++ b/test/lit/pto/local_array_get_rvalue_emitc.pto @@ -24,7 +24,5 @@ module { // CHECK-LABEL: local_array_get_rvalue // CHECK: int32_t [[A:v[0-9]+]][16]; -// CHECK: int32_t [[R:v[0-9]+]]; -// CHECK: [[R]] = [[A]][{{v[0-9]+}}]; -// CHECK: uint32_t [[R_U:v[0-9]+]] = (uint32_t) [[R]]; -// CHECK: [[A]][{{v[0-9]+}}] = {{.*}}[[R_U]]{{.*}}; +// CHECK: int32_t [[R:v[0-9]+]] = [[A]][{{v[0-9]+}}]; +// CHECK: [[A]][{{v[0-9]+}}] = {{.*}}(uint32_t) [[R]]{{.*}}; diff --git a/test/lit/pto/mgather_mscatter_a5_attrs_emitc.pto b/test/lit/pto/mgather_mscatter_a5_attrs_emitc.pto index d35962be36..e15e4c82b0 100644 --- a/test/lit/pto/mgather_mscatter_a5_attrs_emitc.pto +++ b/test/lit/pto/mgather_mscatter_a5_attrs_emitc.pto @@ -86,12 +86,14 @@ module attributes {pto.target_arch = "a5"} { } } +// CHECK-LABEL: AICORE void mgather_emitc // CHECK: TLOAD({{v[0-9]+}}, {{v[0-9]+}}); -// CHECK: __ubuf__ int32_t* [[MGATHER_IDX:v[0-9]+]] = {{.*}}; -// CHECK: __ubuf__ float* [[MGATHER_DST:v[0-9]+]] = {{.*}}; +// CHECK-DAG: __ubuf__ int32_t* [[MGATHER_IDX:v[0-9]+]] = {{.*}}; +// CHECK-DAG: __ubuf__ float* [[MGATHER_DST:v[0-9]+]] = {{.*}}; // CHECK: MGATHER([[MGATHER_DST]], {{v[0-9]+}}, [[MGATHER_IDX]]); +// CHECK-LABEL: AICORE void mscatter_emitc // CHECK: TLOAD({{v[0-9]+}}, {{v[0-9]+}}); // CHECK: TLOAD({{v[0-9]+}}, {{v[0-9]+}}); -// CHECK: __ubuf__ float* [[MSCATTER_SRC:v[0-9]+]] = {{.*}}; -// CHECK: __ubuf__ int32_t* [[MSCATTER_IDX:v[0-9]+]] = {{.*}}; +// CHECK-DAG: __ubuf__ float* [[MSCATTER_SRC:v[0-9]+]] = {{.*}}; +// CHECK-DAG: __ubuf__ int32_t* [[MSCATTER_IDX:v[0-9]+]] = {{.*}}; // CHECK: MSCATTER({{v[0-9]+}}, [[MSCATTER_SRC]], [[MSCATTER_IDX]]); diff --git a/test/lit/pto/multi_tile_const_preload_dyn_loop_select.pto b/test/lit/pto/multi_tile_const_preload_dyn_loop_select.pto index 31f8d27073..eab2597534 100644 --- a/test/lit/pto/multi_tile_const_preload_dyn_loop_select.pto +++ b/test/lit/pto/multi_tile_const_preload_dyn_loop_select.pto @@ -75,7 +75,8 @@ module { // EMITC: AICORE void const_preload_dyn_loop_select // EMITC: for ( -// EMITC-NEXT: Tile [[TILE:v[0-9]+]]; +// EMITC-NEXT: Tile [[TILE_STORAGE:v[0-9]+]]; +// EMITC-NEXT: Tile [[TILE:v[0-9]+]] = [[TILE_STORAGE]]; // EMITC-NEXT: uint64_t [[ADDR:v[0-9]+]] = (uint64_t) // EMITC-NEXT: TASSIGN([[TILE]], [[ADDR]]); // EMITC: TLOAD([[TILE]], diff --git a/test/lit/pto/syncfinder_zero_loop_if_probe.pto b/test/lit/pto/syncfinder_zero_loop_if_probe.pto index 9f04615a22..e68216fa35 100644 --- a/test/lit/pto/syncfinder_zero_loop_if_probe.pto +++ b/test/lit/pto/syncfinder_zero_loop_if_probe.pto @@ -8,8 +8,8 @@ // CHECK: TLOAD // CHECK: set_flag(PIPE_MTE2, PIPE_V, EVENT_ID[[LOOP:[0-9]+]]); // CHECK-NEXT: wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID[[LOOP]]); -// CHECK-NEXT: for (size_t -// CHECK: if ((int64_t) +// CHECK-NEXT: for (int64_t [[IV:[ij][0-9]+]] = +// CHECK: if ([[IV]] == // CHECK: TMOV // CHECK: } // CHECK: } diff --git a/test/lit/pto/syncfinder_zero_loop_if_probe_gss.pto b/test/lit/pto/syncfinder_zero_loop_if_probe_gss.pto index 9d5fe9dd91..f174d48893 100644 --- a/test/lit/pto/syncfinder_zero_loop_if_probe_gss.pto +++ b/test/lit/pto/syncfinder_zero_loop_if_probe_gss.pto @@ -8,8 +8,8 @@ // CHECK: TLOAD // CHECK: set_flag(PIPE_MTE2, PIPE_V, EVENT_ID[[LOOP:[0-9]+]]); // CHECK-NEXT: wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID[[LOOP]]); -// CHECK-NEXT: for (size_t -// CHECK: if ((int64_t) +// CHECK-NEXT: for (int64_t [[IV:[ij][0-9]+]] = +// CHECK: if ([[IV]] == // CHECK: TMOV // CHECK: } // CHECK: } diff --git a/test/lit/pto/tassign_level3_loop_rebind.pto b/test/lit/pto/tassign_level3_loop_rebind.pto index a8d4a732b7..4246386de7 100644 --- a/test/lit/pto/tassign_level3_loop_rebind.pto +++ b/test/lit/pto/tassign_level3_loop_rebind.pto @@ -35,7 +35,8 @@ module { } // CHECK-LABEL: __global__ AICORE void tassign_loop_rebind() { -// CHECK: Tile [[T:v[0-9]+]]; +// CHECK: Tile [[T_STORAGE:v[0-9]+]]; +// CHECK: Tile [[T:v[0-9]+]] = [[T_STORAGE]]; // CHECK: for ( // CHECK: TASSIGN([[T]], // CHECK: TPRINT{{(<.*>)?}}([[T]]); diff --git a/test/lit/pto/tassign_level3_loop_rebind_gss.pto b/test/lit/pto/tassign_level3_loop_rebind_gss.pto index bc4d2c9480..7a5f5affa1 100644 --- a/test/lit/pto/tassign_level3_loop_rebind_gss.pto +++ b/test/lit/pto/tassign_level3_loop_rebind_gss.pto @@ -35,7 +35,8 @@ module { } // CHECK-LABEL: __global__ AICORE void tassign_loop_rebind() { -// CHECK: Tile [[T:v[0-9]+]]; +// CHECK: Tile [[T_STORAGE:v[0-9]+]]; +// CHECK: Tile [[T:v[0-9]+]] = [[T_STORAGE]]; // CHECK: for ( // CHECK: TASSIGN([[T]], // CHECK: TPRINT{{(<.*>)?}}([[T]]); diff --git a/test/lit/pto/tci_i16_emitc.pto b/test/lit/pto/tci_i16_emitc.pto index 4e255b7ad2..b4b072427c 100644 --- a/test/lit/pto/tci_i16_emitc.pto +++ b/test/lit/pto/tci_i16_emitc.pto @@ -4,9 +4,7 @@ module { func.func @tci_i16_kernel(%dst: memref<16xi16, #pto.address_space>) { %c0_i16 = arith.constant 0 : i16 %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c16 = arith.constant 16 : index - %src = memref.reinterpret_cast %dst to offset: [%c0], sizes: [%c1, %c16], strides: [%c16, %c1] {layout = #pto.layout} : memref<16xi16, #pto.address_space> to memref<1x16xi16, strided<[16, 1], offset: ?>, #pto.address_space> + %src = memref.reinterpret_cast %dst to offset: [0], sizes: [1, 16], strides: [16, 1] {layout = #pto.layout} : memref<16xi16, #pto.address_space> to memref<1x16xi16, strided<[16, 1], offset: ?>, #pto.address_space> %tile = pto.alloc_tile : !pto.tile_buf pto.tci ins(%c0_i16 : i16) outs(%tile : !pto.tile_buf) pto.tstore ins(%tile : !pto.tile_buf) outs(%src : memref<1x16xi16, strided<[16, 1], offset: ?>, #pto.address_space>) {layout = #pto.layout, pto.inferred_layout = true} diff --git a/test/lit/pto/tci_ui16_emitc.pto b/test/lit/pto/tci_ui16_emitc.pto index e5ee180e0b..17c7eec7f1 100644 --- a/test/lit/pto/tci_ui16_emitc.pto +++ b/test/lit/pto/tci_ui16_emitc.pto @@ -5,9 +5,7 @@ module { %c7_i16 = arith.constant 7 : i16 %c7_ui16 = builtin.unrealized_conversion_cast %c7_i16 : i16 to ui16 %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c16 = arith.constant 16 : index - %src = memref.reinterpret_cast %dst to offset: [%c0], sizes: [%c1, %c16], strides: [%c16, %c1] {layout = #pto.layout} : memref<16xui16, #pto.address_space> to memref<1x16xui16, strided<[16, 1], offset: ?>, #pto.address_space> + %src = memref.reinterpret_cast %dst to offset: [0], sizes: [1, 16], strides: [16, 1] {layout = #pto.layout} : memref<16xui16, #pto.address_space> to memref<1x16xui16, strided<[16, 1], offset: ?>, #pto.address_space> %tile = pto.alloc_tile : !pto.tile_buf pto.tci ins(%c7_ui16 : ui16) outs(%tile : !pto.tile_buf) pto.tstore ins(%tile : !pto.tile_buf) outs(%src : memref<1x16xui16, strided<[16, 1], offset: ?>, #pto.address_space>) {layout = #pto.layout, pto.inferred_layout = true} diff --git a/test/lit/pto/tci_ui32_emitc.pto b/test/lit/pto/tci_ui32_emitc.pto index b08c4c1181..e36175500f 100644 --- a/test/lit/pto/tci_ui32_emitc.pto +++ b/test/lit/pto/tci_ui32_emitc.pto @@ -5,9 +5,7 @@ module { %c5_i32 = arith.constant 5 : i32 %c5_ui32 = builtin.unrealized_conversion_cast %c5_i32 : i32 to ui32 %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c32 = arith.constant 32 : index - %src = memref.reinterpret_cast %dst to offset: [%c0], sizes: [%c1, %c32], strides: [%c32, %c1] {layout = #pto.layout} : memref<32xui32, #pto.address_space> to memref<1x32xui32, strided<[32, 1], offset: ?>, #pto.address_space> + %src = memref.reinterpret_cast %dst to offset: [0], sizes: [1, 32], strides: [32, 1] {layout = #pto.layout} : memref<32xui32, #pto.address_space> to memref<1x32xui32, strided<[32, 1], offset: ?>, #pto.address_space> %tile = pto.alloc_tile : !pto.tile_buf pto.tci ins(%c5_ui32 : ui32) outs(%tile : !pto.tile_buf) pto.tstore ins(%tile : !pto.tile_buf) outs(%src : memref<1x32xui32, strided<[32, 1], offset: ?>, #pto.address_space>) {layout = #pto.layout, pto.inferred_layout = true} diff --git a/test/lit/pto/tprint_alloc_tile_no_rebind.pto b/test/lit/pto/tprint_alloc_tile_no_rebind.pto index c7ec3b2742..0ebf9332ed 100644 --- a/test/lit/pto/tprint_alloc_tile_no_rebind.pto +++ b/test/lit/pto/tprint_alloc_tile_no_rebind.pto @@ -14,8 +14,9 @@ module { // CHECK-LABEL: __global__ AICORE void print_kernel() { // CHECK: Tile [[TILE:v[0-9]+]]; -// CHECK: TASSIGN([[TILE]], [[ADDR:v[0-9]+]]); +// CHECK: Tile [[TILE_COPY:v[0-9]+]] = [[TILE]]; +// CHECK: TASSIGN([[TILE_COPY]], [[ADDR:v[0-9]+]]); // CHECK-NOT: TASSIGN( // CHECK-NOT: .data() // CHECK-NOT: reinterpret_cast -// CHECK: TPRINT{{(<.*>)?}}([[TILE]]); +// CHECK: TPRINT{{(<.*>)?}}([[TILE_COPY]]); diff --git a/test/lit/pto/tpush_tpop_globaltensor_frontend_a3.pto b/test/lit/pto/tpush_tpop_globaltensor_frontend_a3.pto index 9ad19ec8f5..cb315e14f8 100644 --- a/test/lit/pto/tpush_tpop_globaltensor_frontend_a3.pto +++ b/test/lit/pto/tpush_tpop_globaltensor_frontend_a3.pto @@ -67,13 +67,15 @@ module { // CHECK-SAME: (__gm__ float* [[CUBE_GM:v[0-9]+]], // CHECK: TPipe<0, Direction::DIR_C2V, 1024, 8, 8, true>([[CUBE_GM]], {{.*}}, {{.*}}); // CHECK: GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND> [[CUBE_ENTRY:v[0-9]+]](nullptr); -// CHECK: TALLOC, GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>({{.*}}, [[CUBE_ENTRY]]); +// CHECK: GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND> [[CUBE_ENTRY_VAL:v[0-9]+]] = [[CUBE_ENTRY]]; +// CHECK: TALLOC, GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>({{.*}}, [[CUBE_ENTRY_VAL]]); // CHECK: TSTORE -// CHECK: TPUSH, GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>({{.*}}, [[CUBE_ENTRY]]); +// CHECK: TPUSH, GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>({{.*}}, [[CUBE_ENTRY_VAL]]); // CHECK-LABEL: AICORE void vector_kernel // CHECK-SAME: (__gm__ float* [[VEC_GM:v[0-9]+]], // CHECK: TPipe<0, Direction::DIR_C2V, 1024, 8, 8, true>([[VEC_GM]], {{.*}}, {{.*}}); // CHECK: GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND> [[VEC_ENTRY:v[0-9]+]](nullptr); +// CHECK: GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND> [[VEC_ENTRY_VAL:v[0-9]+]] = [[VEC_ENTRY]]; // CHECK: TPOP, GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>( // CHECK: TLOAD // CHECK: TFREE, GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>( diff --git a/test/lit/pto/tpush_tpop_globaltensor_frontend_a5.pto b/test/lit/pto/tpush_tpop_globaltensor_frontend_a5.pto index 8562c12070..b4ec073938 100644 --- a/test/lit/pto/tpush_tpop_globaltensor_frontend_a5.pto +++ b/test/lit/pto/tpush_tpop_globaltensor_frontend_a5.pto @@ -84,23 +84,27 @@ module { // CHECK-LABEL: AICORE void cube_c2v_kernel(__gm__ float* // CHECK: TPipe<0, Direction::DIR_C2V_GM, 1024, 8, 8, true>({{.*}}, {{.*}}, {{.*}}); // CHECK: GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND> [[CUBE_C2V_ENTRY:v[0-9]+]](nullptr); -// CHECK: TALLOC, GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>({{.*}}, [[CUBE_C2V_ENTRY]]); -// CHECK: TPUSH, GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>({{.*}}, [[CUBE_C2V_ENTRY]]); +// CHECK: GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND> [[CUBE_C2V_ENTRY_VAL:v[0-9]+]] = [[CUBE_C2V_ENTRY]]; +// CHECK: TALLOC, GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>({{.*}}, [[CUBE_C2V_ENTRY_VAL]]); +// CHECK: TPUSH, GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>({{.*}}, [[CUBE_C2V_ENTRY_VAL]]); // CHECK-LABEL: AICORE void vector_c2v_kernel(__gm__ float* // CHECK: TPipe<0, Direction::DIR_C2V_GM, 1024, 8, 8, true>({{.*}}, {{.*}}, {{.*}}); // CHECK: GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND> [[VEC_C2V_ENTRY:v[0-9]+]](nullptr); +// CHECK: GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND> [[VEC_C2V_ENTRY_VAL:v[0-9]+]] = [[VEC_C2V_ENTRY]]; // CHECK: TPOP, GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>( // CHECK: TFREE, GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>( // CHECK-LABEL: AICORE void vector_v2c_kernel(__gm__ float* // CHECK: TPipe<0, Direction::DIR_V2C_GM, 1024, 8, 8, true>({{.*}}, {{.*}}, {{.*}}); // CHECK: GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND> [[VEC_V2C_ENTRY:v[0-9]+]](nullptr); -// CHECK: TALLOC, GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>({{.*}}, [[VEC_V2C_ENTRY]]); -// CHECK: TPUSH, GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>({{.*}}, [[VEC_V2C_ENTRY]]); +// CHECK: GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND> [[VEC_V2C_ENTRY_VAL:v[0-9]+]] = [[VEC_V2C_ENTRY]]; +// CHECK: TALLOC, GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>({{.*}}, [[VEC_V2C_ENTRY_VAL]]); +// CHECK: TPUSH, GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>({{.*}}, [[VEC_V2C_ENTRY_VAL]]); // CHECK-LABEL: AICORE void cube_v2c_kernel(__gm__ float* // CHECK: TPipe<0, Direction::DIR_V2C_GM, 1024, 8, 8, true>({{.*}}, {{.*}}, {{.*}}); // CHECK: GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND> [[CUBE_V2C_ENTRY:v[0-9]+]](nullptr); +// CHECK: GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND> [[CUBE_V2C_ENTRY_VAL:v[0-9]+]] = [[CUBE_V2C_ENTRY]]; // CHECK: TPOP, GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>( // CHECK: TFREE, GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>( diff --git a/test/lit/pto/treshape_static_valid_shape_emitc.pto b/test/lit/pto/treshape_static_valid_shape_emitc.pto index 9673ec0c17..701874e5f0 100644 --- a/test/lit/pto/treshape_static_valid_shape_emitc.pto +++ b/test/lit/pto/treshape_static_valid_shape_emitc.pto @@ -32,9 +32,11 @@ module attributes {pto.target_arch = "a5"} { } // CHECK-LABEL: AICORE void treshape_static_valid_shape_emitc() -// CHECK: Tile [[SRC:v[0-9]+]]; +// CHECK: Tile [[SRC_STORAGE:v[0-9]+]]; +// CHECK: Tile [[SRC:v[0-9]+]] = [[SRC_STORAGE]]; // CHECK: TASSIGN([[SRC]], -// CHECK: Tile [[RESHAPED:v[0-9]+]]; +// CHECK: Tile [[RESHAPED_STORAGE:v[0-9]+]]; +// CHECK-NEXT: Tile [[RESHAPED:v[0-9]+]] = [[RESHAPED_STORAGE]]; // CHECK-NEXT: TRESHAPE([[RESHAPED]], [[SRC]]); // CHECK-NOT: SetValidShape // CHECK-NOT: TRESHAPE diff --git a/test/lit/tile_fusion/mark_last_use_slot_mask_level2.pto b/test/lit/tile_fusion/mark_last_use_slot_mask_level2.pto index 7f28f9f3da..1a47f47ff5 100644 --- a/test/lit/tile_fusion/mark_last_use_slot_mask_level2.pto +++ b/test/lit/tile_fusion/mark_last_use_slot_mask_level2.pto @@ -62,11 +62,11 @@ module { // CHECK-NEXT: pto.tsub ins(%{{.*}}, %{{.*}} : !pto.tile_buf, !pto.tile_buf) outs(%{{.*}} : !pto.tile_buf) {pto.fusion.group_id = 0 : i64, pto.fusion.order = 3 : i64, pto.last_use = array} // CHECK-NEXT: pto.tsubs ins(%{{.*}}, %{{.*}} : !pto.tile_buf, f32) outs(%{{.*}} : !pto.tile_buf) {pto.fusion.group_id = 0 : i64, pto.fusion.order = 4 : i64, pto.last_use = array} -// MARKER: emitc.call_opaque "PTOAS__LAST_USE__TEXPANDS__0" -// MARKER: emitc.call_opaque "PTOAS__LAST_USE__TADDS__0__1" -// MARKER: emitc.call_opaque "PTOAS__LAST_USE__TADD__0__1__0" -// MARKER: emitc.call_opaque "PTOAS__LAST_USE__TSUB__0__1__1" -// MARKER: emitc.call_opaque "PTOAS__LAST_USE__TSUBS__0__1" +// MARKER: call_opaque "PTOAS__LAST_USE__TEXPANDS__0" +// MARKER: call_opaque "PTOAS__LAST_USE__TADDS__0__1" +// MARKER: call_opaque "PTOAS__LAST_USE__TADD__0__1__0" +// MARKER: call_opaque "PTOAS__LAST_USE__TSUB__0__1__1" +// MARKER: call_opaque "PTOAS__LAST_USE__TSUBS__0__1" // MARKER-NOT: pto::last_use // CPP-LABEL: AICORE void mark_last_use_slot_mask_level2( diff --git a/test/tilelang_st/npu/a5/src/st/smoke/testcase/tmrgsort/tmrgsort.pto b/test/tilelang_st/npu/a5/src/st/smoke/testcase/tmrgsort/tmrgsort.pto index cf27d05033..ed90162b9c 100644 --- a/test/tilelang_st/npu/a5/src/st/smoke/testcase/tmrgsort/tmrgsort.pto +++ b/test/tilelang_st/npu/a5/src/st/smoke/testcase/tmrgsort/tmrgsort.pto @@ -9,7 +9,7 @@ // TileLang ST kernels for pto.tmrgsort Format1: single list internal block sorting. // Input is divided into 4 blocks, each block sorted, then merged. // Output: interleaved (sorted_value, original_index) pairs. -// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --vpto-emit-hivm-llvm +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand // to produce LLVM IR. module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmrgsort/tmrgsort.pto b/test/tilelang_st/npu/a5/src/st/testcase/tmrgsort/tmrgsort.pto index 76eeaaf182..b9bb701bff 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tmrgsort/tmrgsort.pto +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmrgsort/tmrgsort.pto @@ -9,7 +9,7 @@ // TileLang ST kernels for pto.tmrgsort Format1: single list internal block sorting. // Input is divided into 4 blocks, each block sorted, then merged. // Output: interleaved (sorted_value, original_index) pairs. -// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --vpto-emit-hivm-llvm +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand // to produce LLVM IR. module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { diff --git a/tilelang-dsl/python/tilelang_dsl/kernel.py b/tilelang-dsl/python/tilelang_dsl/kernel.py index 09a599ad7f..b56a03fdb9 100644 --- a/tilelang-dsl/python/tilelang_dsl/kernel.py +++ b/tilelang-dsl/python/tilelang_dsl/kernel.py @@ -735,7 +735,26 @@ def _load_function_source_info(py_fn: Callable[..., Any]) -> _FunctionSourceInfo return None source = textwrap.dedent("".join(source_lines)) - module = ast.parse(source) + try: + module = ast.parse(source) + except SyntaxError: + try: + module = ast.parse(Path(path).read_text(encoding="utf-8")) + except (OSError, IOError, SyntaxError): + return None + code_line = getattr(py_fn, "__code__", None) + code_first_line = code_line.co_firstlineno if code_line is not None else start_line + for node in ast.walk(module): + if not isinstance(node, ast.FunctionDef) or node.name != py_fn.__name__: + continue + first_line = min( + [node.lineno] + [decorator.lineno for decorator in node.decorator_list] + ) + last_line = getattr(node, "end_lineno", node.lineno) + if first_line <= code_first_line <= last_line: + return _FunctionSourceInfo(path=path, start_line=1, function_def=node) + return None + for node in module.body: if isinstance(node, ast.FunctionDef) and node.name == py_fn.__name__: return _FunctionSourceInfo(path=path, start_line=start_line, function_def=node) diff --git a/tools/ptoas/ptoas.cpp b/tools/ptoas/ptoas.cpp index 8e73de48e7..d5ba2b1319 100644 --- a/tools/ptoas/ptoas.cpp +++ b/tools/ptoas/ptoas.cpp @@ -587,6 +587,11 @@ static LogicalResult emitSharedPreBackendSeamIR(ModuleOp module, return success(); } +static void printSharedPreBackendSeamIR(ModuleOp module) { + module->print(llvm::errs()); + llvm::errs() << "\n"; +} + static bool hasUnexpandedTileOps(ModuleOp module) { bool found = false; module.walk([&](Operation *op) { @@ -830,9 +835,199 @@ static void dropEmptyEmitCExpressions(Operation *rootOp) { expr.erase(); } +static void appendEmitCIntegerAttrLiteral(std::string &storage, + const APInt &value, bool isUnsigned) { + if (value.getBitWidth() == 0) { + storage.append("0"); + return; + } + if (value.getBitWidth() == 1) { + storage.append(value.getBoolValue() ? "true" : "false"); + return; + } + + SmallString<128> strValue; + value.toString(strValue, 10, !isUnsigned, false); + storage.append(strValue.data(), strValue.size()); +} + +static bool shouldPrintEmitCIntegerAttrAsUnsigned(IntegerAttr attr) { + auto intTy = dyn_cast(attr.getType()); + return intTy && intTy.getSignedness() == IntegerType::Unsigned; +} + +static std::string getEmitCIntegerAttrLiteral(IntegerAttr attr) { + std::string literal; + appendEmitCIntegerAttrLiteral(literal, attr.getValue(), + shouldPrintEmitCIntegerAttrAsUnsigned(attr)); + return literal; +} + +static std::optional +getEmitCDenseIntElementsAttrLiteral(DenseIntElementsAttr attr) { + auto tensorTy = dyn_cast(attr.getType()); + if (!tensorTy) + return std::nullopt; + + Type elementType = tensorTy.getElementType(); + bool isUnsigned = false; + if (auto intTy = dyn_cast(elementType)) { + isUnsigned = intTy.getSignedness() == IntegerType::Unsigned; + } else if (!isa(elementType)) { + return std::nullopt; + } + + std::string literal; + literal.push_back('{'); + bool first = true; + for (const APInt &value : attr) { + if (!first) + literal.append(", "); + first = false; + appendEmitCIntegerAttrLiteral(literal, value, isUnsigned); + } + literal.push_back('}'); + return literal; +} + +static Attribute normalizeEmitCPrintedAttrForCppEmission(MLIRContext *ctx, + Attribute attr) { + if (auto intAttr = dyn_cast(attr)) + return emitc::OpaqueAttr::get(ctx, getEmitCIntegerAttrLiteral(intAttr)); + + if (auto denseAttr = dyn_cast(attr)) { + if (std::optional literal = + getEmitCDenseIntElementsAttrLiteral(denseAttr)) + return emitc::OpaqueAttr::get(ctx, *literal); + } + + if (auto arrayAttr = dyn_cast(attr)) { + SmallVector normalized; + normalized.reserve(arrayAttr.size()); + bool changed = false; + for (Attribute element : arrayAttr) { + Attribute normalizedElement = + normalizeEmitCPrintedAttrForCppEmission(ctx, element); + changed |= normalizedElement != element; + normalized.push_back(normalizedElement); + } + if (changed) + return ArrayAttr::get(ctx, normalized); + } + + return attr; +} + +static IntegerAttr normalizeEmitCIndexPlaceholderAttr(MLIRContext *ctx, + IntegerAttr attr) { + const APInt &value = attr.getValue(); + int64_t index = value.getBitWidth() == 0 ? 0 : value.getSExtValue(); + return IntegerAttr::get(IndexType::get(ctx), APInt(64, index)); +} + +static ArrayAttr normalizeEmitCCallArgsForCppEmission(MLIRContext *ctx, + ArrayAttr args) { + SmallVector normalized; + normalized.reserve(args.size()); + bool changed = false; + + for (Attribute attr : args) { + if (auto intAttr = dyn_cast(attr)) { + if (isa(intAttr.getType())) { + Attribute normalizedAttr = + normalizeEmitCIndexPlaceholderAttr(ctx, intAttr); + changed |= normalizedAttr != attr; + normalized.push_back(normalizedAttr); + continue; + } + + Attribute normalizedAttr = + normalizeEmitCPrintedAttrForCppEmission(ctx, attr); + changed |= normalizedAttr != attr; + normalized.push_back(normalizedAttr); + continue; + } + + Attribute normalizedAttr = + normalizeEmitCPrintedAttrForCppEmission(ctx, attr); + changed |= normalizedAttr != attr; + normalized.push_back(normalizedAttr); + } + + return changed ? ArrayAttr::get(ctx, normalized) : args; +} + +static ArrayAttr normalizeEmitCTemplateArgsForCppEmission(MLIRContext *ctx, + ArrayAttr args) { + SmallVector normalized; + normalized.reserve(args.size()); + bool changed = false; + + for (Attribute attr : args) { + Attribute normalizedAttr = + normalizeEmitCPrintedAttrForCppEmission(ctx, attr); + changed |= normalizedAttr != attr; + normalized.push_back(normalizedAttr); + } + + return changed ? ArrayAttr::get(ctx, normalized) : args; +} + +static void normalizeEmitCIntegerAttrsForCppEmission(Operation *rootOp) { + MLIRContext *ctx = rootOp->getContext(); + rootOp->walk([&](Operation *op) { + if (auto constant = dyn_cast(op)) { + Attribute value = constant.getValue(); + Attribute normalized = + normalizeEmitCPrintedAttrForCppEmission(ctx, value); + if (normalized != value) + constant.getProperties().setValue(normalized); + return; + } + + if (auto variable = dyn_cast(op)) { + Attribute value = variable.getValue(); + Attribute normalized = + normalizeEmitCPrintedAttrForCppEmission(ctx, value); + if (normalized != value) + variable.getProperties().setValue(normalized); + return; + } + + if (auto global = dyn_cast(op)) { + std::optional initialValue = global.getInitialValue(); + if (!initialValue) + return; + Attribute normalized = + normalizeEmitCPrintedAttrForCppEmission(ctx, *initialValue); + if (normalized != *initialValue) + global.getProperties().setInitialValue(normalized); + return; + } + + if (auto call = dyn_cast(op)) { + if (std::optional args = call.getArgs()) { + ArrayAttr normalized = normalizeEmitCCallArgsForCppEmission(ctx, *args); + if (normalized != *args) + call.getProperties().setArgs(normalized); + } + if (std::optional templateArgs = call.getTemplateArgs()) { + ArrayAttr normalized = + normalizeEmitCTemplateArgsForCppEmission(ctx, *templateArgs); + if (normalized != *templateArgs) + call.getProperties().setTemplateArgs(normalized); + } + return; + } + }); +} + static Attribute getDefaultEmitCVariableInitAttr(OpBuilder &builder, Type type) { - if (auto intTy = dyn_cast(type)) + if (auto intTy = dyn_cast(type)) { + if (intTy.getWidth() == 0) + return emitc::OpaqueAttr::get(builder.getContext(), "0"); return builder.getIntegerAttr(intTy, 0); + } if (isa(type)) return builder.getIndexAttr(0); if (auto floatTy = dyn_cast(type)) @@ -842,6 +1037,12 @@ static Attribute getDefaultEmitCVariableInitAttr(OpBuilder &builder, Type type) return Attribute{}; } +static Type getEmitCVariableStorageType(Type valueType) { + if (isa(valueType)) + return valueType; + return emitc::LValueType::get(valueType); +} + // FormExpressions may inline conditions into emitc.expression, but the C++ // emitter prints cf.br/cf.cond_br operands by variable name rather than by // recursively emitting an expression. Materialize such operands so CFG-based @@ -867,12 +1068,21 @@ static void materializeControlFlowOperands(Operation *rootOp) { if (!initAttr) continue; - Value tmp = - builder.create(op->getLoc(), value.getType(), - initAttr) - .getResult(); + Value tmp = builder + .create( + op->getLoc(), getEmitCVariableStorageType(value.getType()), + initAttr) + .getResult(); builder.create(op->getLoc(), tmp, value); - operand.set(tmp); + if (auto lvalueTy = dyn_cast(tmp.getType())) { + Value loaded = builder + .create(op->getLoc(), + lvalueTy.getValueType(), tmp) + .getResult(); + operand.set(loaded); + } else { + operand.set(tmp); + } } } } @@ -1958,10 +2168,8 @@ int mlir::pto::compilePTOASModule( return 1; } - if (ptoPrintSeamIR) { - module->print(llvm::errs()); - llvm::errs() << "\n"; - } + if (ptoPrintSeamIR) + printSharedPreBackendSeamIR(*module); if (failed(emitSharedPreBackendSeamIR(*module, ptoSeamIRFile))) return 1; @@ -1971,21 +2179,37 @@ int mlir::pto::compilePTOASModule( context.getCANNVersionOrDefault()); } + if (failed(pm.run(*module))) { + llvm::errs() << "Error: Pass execution failed.\n"; + return 1; + } + + if (ptoPrintSeamIR) + printSharedPreBackendSeamIR(*module); + if (failed(emitSharedPreBackendSeamIR(*module, ptoSeamIRFile))) + return 1; + + PassManager emitcPM(module->getContext()); + emitcPM.enableVerifier(); if (arch == "a3") { - pm.addPass(pto::createEmitPTOManualPass(pto::PTOArch::A3)); + emitcPM.addPass(pto::createEmitPTOManualPass(pto::PTOArch::A3)); } else { - pm.addPass(pto::createEmitPTOManualPass(pto::PTOArch::A5)); + emitcPM.addPass(pto::createEmitPTOManualPass(pto::PTOArch::A5)); } - pm.addPass(emitc::createFormExpressionsPass()); - pm.addPass(mlir::createCSEPass()); + emitcPM.addPass(emitc::createFormExpressionsPass()); + emitcPM.addPass(mlir::createCSEPass()); + if (failed(applyConfiguredPassManagerCLOptions( + emitcPM, "EmitC backend pipeline"))) + return 1; - if (failed(pm.run(*module))) { + if (failed(emitcPM.run(*module))) { llvm::errs() << "Error: Pass execution failed.\n"; return 1; } dropEmptyEmitCExpressions(module.get()); materializeControlFlowOperands(module.get()); + normalizeEmitCIntegerAttrsForCppEmission(module.get()); if (failed(reorderEmitCFunctions(module.get()))) { llvm::errs() << "Error: Failed to order emitted functions for C++ emission.\n"; return 1; diff --git a/tools/ptobc/src/mlir_encode.cpp b/tools/ptobc/src/mlir_encode.cpp index bfab87be73..a4cca9e356 100644 --- a/tools/ptobc/src/mlir_encode.cpp +++ b/tools/ptobc/src/mlir_encode.cpp @@ -121,6 +121,36 @@ static mlir::DictionaryAttr stripKnownImmediateAttrs( } } +static bool isFrontendPipeOpWithDefaultId(mlir::Operation &op) { + llvm::StringRef name = op.getName().getStringRef(); + return name == "pto.aic_initialize_pipe" || + name == "pto.aiv_initialize_pipe" || + name == "pto.talloc_to_aiv" || name == "pto.talloc_to_aic" || + name == "pto.tpush_to_aiv" || name == "pto.tpush_to_aic" || + name == "pto.tpop_from_aic" || name == "pto.tpop_from_aiv" || + name == "pto.tfree_from_aic" || name == "pto.tfree_from_aiv"; +} + +static mlir::DictionaryAttr +dropDefaultZeroPipeIdForV0Encoding(mlir::Operation &op, + mlir::DictionaryAttr dict) { + if (!dict || dict.empty() || !isFrontendPipeOpWithDefaultId(op)) + return dict; + + auto idAttr = dict.getAs("id"); + if (!idAttr || idAttr.getInt() != 0) + return dict; + + llvm::SmallVector keep; + keep.reserve(dict.size()); + for (auto attr : dict) { + if (attr.getName() == "id") + continue; + keep.push_back(attr); + } + return mlir::DictionaryAttr::get(op.getContext(), keep); +} + static uint64_t internAttr(PTOBCFile& f, mlir::DictionaryAttr dict) { if (!dict || dict.empty()) return 0; std::string s = printAttrDict(dict); @@ -468,6 +498,8 @@ void Encoder::encodeKnownOpImmediates( mask |= 0x1; if (at.getValidCol()) mask |= 0x2; + if (at.getAddr()) + mask |= 0x4; out.appendU8(mask); imms.push_back(mask); return; @@ -540,7 +572,8 @@ void Encoder::encodeKnownOpOperands( if (imms.empty()) throw std::runtime_error("optmask operands missing immediate"); uint8_t mask = uint8_t(imms.front()); - emitOperands(((mask & 0x1) ? 1 : 0) + ((mask & 0x2) ? 1 : 0)); + emitOperands(((mask & 0x1) ? 1 : 0) + ((mask & 0x2) ? 1 : 0) + + ((mask & 0x4) ? 1 : 0)); return; } default: @@ -557,6 +590,9 @@ void Encoder::encodeKnownOp(mlir::Operation &op, Buffer &out, out.appendU16LE(variantInfo.opcode); mlir::DictionaryAttr dict = op.getAttrDictionary(); dict = stripKnownImmediateAttrs(op.getContext(), dict, info); + // The assembly printer omits default pipe ids on transfer/free ops. Keep v0 + // byte roundtrips stable by using the same representation while encoding. + dict = dropDefaultZeroPipeIdForV0Encoding(op, dict); writeULEB128(internAttr(file, dict), out.bytes); if (info.has_variant_u8) diff --git a/tools/ptobc/src/ptobc_decode_print.cpp b/tools/ptobc/src/ptobc_decode_print.cpp index 3c13ae31b9..cdb0280902 100644 --- a/tools/ptobc/src/ptobc_decode_print.cpp +++ b/tools/ptobc/src/ptobc_decode_print.cpp @@ -487,7 +487,8 @@ readKnownOperandIds(BuildCtx &bc, Reader &r, uint16_t opcode, uint8_t variant, case 0x04: return reorderLegacyIndexedTscatter( readValueIds(r, ((imms.optMask & 0x1) ? 1 : 0) + - ((imms.optMask & 0x2) ? 1 : 0))); + ((imms.optMask & 0x2) ? 1 : 0) + + ((imms.optMask & 0x4) ? 1 : 0))); default: (void)bc; throw std::runtime_error("unknown operand_mode"); diff --git a/tools/ptobc/testdata/recent_mx_ops_v0_roundtrip.pto b/tools/ptobc/testdata/recent_mx_ops_v0_roundtrip.pto new file mode 100644 index 0000000000..b54687d0f4 --- /dev/null +++ b/tools/ptobc/testdata/recent_mx_ops_v0_roundtrip.pto @@ -0,0 +1,26 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5"} { + func.func @recent_mx_ops_v0() { + %a = pto.alloc_tile : !pto.tile_buf + %a_scale = pto.alloc_tile : !pto.tile_buf + %b = pto.alloc_tile : !pto.tile_buf + %b_scale = pto.alloc_tile : !pto.tile_buf + %c_in = pto.alloc_tile : !pto.tile_buf + %bias = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + %dst_acc = pto.alloc_tile : !pto.tile_buf + %dst_bias = pto.alloc_tile : !pto.tile_buf + + pto.tgemv.mx ins(%a, %a_scale, %b, %b_scale : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) + pto.tgemv.mx.acc ins(%c_in, %a, %a_scale, %b, %b_scale : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst_acc : !pto.tile_buf) + pto.tgemv.mx.bias ins(%a, %a_scale, %b, %b_scale, %bias : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst_bias : !pto.tile_buf) + return + } +} diff --git a/tools/ptobc/testdata/recent_ops_v0_roundtrip.pto b/tools/ptobc/testdata/recent_ops_v0_roundtrip.pto index ecbf78fa89..af4254b1ff 100644 --- a/tools/ptobc/testdata/recent_ops_v0_roundtrip.pto +++ b/tools/ptobc/testdata/recent_ops_v0_roundtrip.pto @@ -42,18 +42,6 @@ module { pto.trsqrt ins(%src : !pto.tile_buf) outs(%rs_dst0 : !pto.tile_buf) pto.trsqrt ins(%src, %rs_tmp : !pto.tile_buf, !pto.tile_buf) outs(%rs_dst1 : !pto.tile_buf) pto.tpartmul ins(%part0, %part1 : !pto.tile_buf, !pto.tile_buf) outs(%partdst : !pto.tile_buf) - %a = pto.alloc_tile : !pto.tile_buf - %a_scale = pto.alloc_tile : !pto.tile_buf - %b = pto.alloc_tile : !pto.tile_buf - %b_scale = pto.alloc_tile : !pto.tile_buf - %c_in = pto.alloc_tile : !pto.tile_buf - %bias_mx = pto.alloc_tile : !pto.tile_buf - %dst_mx = pto.alloc_tile : !pto.tile_buf - %dst_mx_acc = pto.alloc_tile : !pto.tile_buf - %dst_mx_bias = pto.alloc_tile : !pto.tile_buf - pto.tgemv.mx ins(%a, %a_scale, %b, %b_scale : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst_mx : !pto.tile_buf) - pto.tgemv.mx.acc ins(%c_in, %a, %a_scale, %b, %b_scale : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst_mx_acc : !pto.tile_buf) - pto.tgemv.mx.bias ins(%a, %a_scale, %b, %b_scale, %bias_mx : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst_mx_bias : !pto.tile_buf) return } } diff --git a/tools/ptobc/tests/CMakeLists.txt b/tools/ptobc/tests/CMakeLists.txt index 2a2f78e3a4..0489351c08 100644 --- a/tools/ptobc/tests/CMakeLists.txt +++ b/tools/ptobc/tests/CMakeLists.txt @@ -117,6 +117,7 @@ add_test(NAME ptobc_tscatter_maskpattern_v0_encode COMMAND ${CMAKE_COMMAND} -E env PTOBC_BIN=$ TESTDATA_DIR=${PTObc_TESTDATA_DIR} + PYTHON_EXECUTABLE=${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/tscatter_maskpattern_v0_encode.sh ) diff --git a/tools/ptobc/tests/recent_ops_v0_encode.sh b/tools/ptobc/tests/recent_ops_v0_encode.sh index b8c302bf53..8517b4174b 100755 --- a/tools/ptobc/tests/recent_ops_v0_encode.sh +++ b/tools/ptobc/tests/recent_ops_v0_encode.sh @@ -22,14 +22,19 @@ if [[ -z "${TESTDATA_DIR}" ]]; then fi IN="${TESTDATA_DIR}/recent_ops_v0_roundtrip.pto" +MX_IN="${TESTDATA_DIR}/recent_mx_ops_v0_roundtrip.pto" OUT_DIR=${OUT_DIR:-"${PWD}/ptobc_recent_ops_out"} mkdir -p "${OUT_DIR}" BC="${OUT_DIR}/recent_ops_v0_roundtrip.ptobc" ROUNDTRIP="${OUT_DIR}/recent_ops_v0_roundtrip.roundtrip.pto" +MX_BC="${OUT_DIR}/recent_mx_ops_v0_roundtrip.ptobc" +MX_ROUNDTRIP="${OUT_DIR}/recent_mx_ops_v0_roundtrip.roundtrip.pto" "${PTOBC_BIN}" encode "${IN}" -o "${BC}" "${PTOBC_BIN}" decode "${BC}" -o "${ROUNDTRIP}" +"${PTOBC_BIN}" encode "${MX_IN}" -o "${MX_BC}" +"${PTOBC_BIN}" decode "${MX_BC}" -o "${MX_ROUNDTRIP}" grep -F "pto.subview " "${ROUNDTRIP}" >/dev/null grep -F "pto.tprint ins(" "${ROUNDTRIP}" >/dev/null @@ -41,6 +46,6 @@ grep -F "pto.trowexpandmul ins(" "${ROUNDTRIP}" >/dev/null grep -F "pto.trsqrt ins(" "${ROUNDTRIP}" >/dev/null grep -E "pto\\.trsqrt ins\\(%[^,]+, %[^:]+ :" "${ROUNDTRIP}" >/dev/null grep -F "pto.tpartmul ins(" "${ROUNDTRIP}" >/dev/null -grep -F "pto.tgemv.mx ins(" "${ROUNDTRIP}" >/dev/null -grep -F "pto.tgemv.mx.acc ins(" "${ROUNDTRIP}" >/dev/null -grep -F "pto.tgemv.mx.bias ins(" "${ROUNDTRIP}" >/dev/null +grep -F "pto.tgemv.mx ins(" "${MX_ROUNDTRIP}" >/dev/null +grep -F "pto.tgemv.mx.acc ins(" "${MX_ROUNDTRIP}" >/dev/null +grep -F "pto.tgemv.mx.bias ins(" "${MX_ROUNDTRIP}" >/dev/null diff --git a/tools/ptobc/tests/tscatter_maskpattern_v0_encode.sh b/tools/ptobc/tests/tscatter_maskpattern_v0_encode.sh index 482cd74eef..c85b1dadee 100755 --- a/tools/ptobc/tests/tscatter_maskpattern_v0_encode.sh +++ b/tools/ptobc/tests/tscatter_maskpattern_v0_encode.sh @@ -35,7 +35,19 @@ grep -F "pto.tscatter ins(" "${ROUNDTRIP}" >/dev/null grep -F "{maskPattern = #pto.mask_pattern}" "${ROUNDTRIP}" >/dev/null grep -E "pto\\.tscatter ins\\(%[^,]+, %[^:]+ :" "${ROUNDTRIP}" >/dev/null -python - <<'PY' "${BC}" +PYTHON_EXECUTABLE=${PYTHON_EXECUTABLE:-} +if [[ -z "${PYTHON_EXECUTABLE}" ]]; then + if command -v python3 >/dev/null 2>&1; then + PYTHON_EXECUTABLE=python3 + elif command -v python >/dev/null 2>&1; then + PYTHON_EXECUTABLE=python + else + echo "error: neither PYTHON_EXECUTABLE nor python3/python is available" >&2 + exit 2 + fi +fi + +"${PYTHON_EXECUTABLE}" - <<'PY' "${BC}" from pathlib import Path import sys