diff --git a/.github/workflows/ci_sim.yml b/.github/workflows/ci_sim.yml index 20862b519..659ae5bfa 100644 --- a/.github/workflows/ci_sim.yml +++ b/.github/workflows/ci_sim.yml @@ -125,37 +125,33 @@ jobs: fetch-depth: 1 persist-credentials: false - - name: Resolve LLVM directories + - name: Resolve simulator environment shell: bash run: | set -euo pipefail - echo "LLVM_ROOT=${RUNNER_TOOL_CACHE}/llvm-project" >> "${GITHUB_ENV}" - echo "LLVM_DIR=${RUNNER_TOOL_CACHE}/llvm-project/llvm/build-assert" >> "${GITHUB_ENV}" - echo "MLIR_PYTHONPATH=${RUNNER_TOOL_CACHE}/llvm-project/llvm/build-assert/tools/mlir/python_packages/mlir_core" >> "${GITHUB_ENV}" - - name: Resolve LLVM cache key - id: llvm-cache-key - shell: bash - run: | - set -euo pipefail - # Resolve to a Git object that ls-remote can handle: either a tag - # (LLVM_TAG) or a branch head (LLVM_REF). Only one is expected. - ref="${LLVM_TAG:-${LLVM_REF}}" - sha="$(git ls-remote "${LLVM_REPO}" "${ref}" | awk '{print $1}')" - if [[ -z "${sha}" ]]; then - echo "ERROR: failed to resolve ${LLVM_REPO} ${ref}" >&2 + detect_ascend_home() { + for d in \ + "${ASCEND_HOME_PATH:-}" \ + /usr/local/Ascend/cann \ + /usr/local/Ascend/cann-* \ + /usr/local/Ascend/ascend-toolkit/latest + do + [[ -n "${d}" && -d "${d}" ]] || continue + printf '%s\n' "${d}" + return 0 + done + return 1 + } + + ASCEND_HOME_PATH_DETECTED="$(detect_ascend_home || true)" + if [[ -z "${ASCEND_HOME_PATH_DETECTED}" ]]; then + echo "ERROR: failed to detect ASCEND_HOME_PATH on self-hosted runner" >&2 exit 1 fi - echo "sha=${sha}" >> "${GITHUB_OUTPUT}" - echo "key=llvm-build-${sha}-assert-v1" >> "${GITHUB_OUTPUT}" - - name: Restore LLVM build cache - id: llvm-cache - continue-on-error: true - uses: actions/cache/restore@v4 - with: - path: ${{ env.LLVM_DIR }} - key: ${{ steps.llvm-cache-key.outputs.key }} + echo "ASCEND_HOME_PATH=${ASCEND_HOME_PATH_DETECTED}" >> "${GITHUB_ENV}" + echo "PTOAS_BIN=${GITHUB_WORKSPACE}/build/tools/ptoas/ptoas" >> "${GITHUB_ENV}" - name: Ensure runner dependencies shell: bash @@ -201,12 +197,49 @@ jobs: python3 -m pip install setuptools wheel 'pybind11<3' nanobind numpy ml-dtypes fi + - name: Resolve LLVM directories + shell: bash + run: | + set -euo pipefail + echo "LLVM_ROOT=${RUNNER_TOOL_CACHE}/llvm-project" >> "${GITHUB_ENV}" + echo "LLVM_DIR=${RUNNER_TOOL_CACHE}/llvm-project/llvm/build-assert" >> "${GITHUB_ENV}" + echo "MLIR_PYTHONPATH=${RUNNER_TOOL_CACHE}/llvm-project/llvm/build-assert/tools/mlir/python_packages/mlir_core" >> "${GITHUB_ENV}" + echo "PTO_DSL_ST_BUILD_DIR=${GITHUB_WORKSPACE}/build-ptodsl" >> "${GITHUB_ENV}" + echo "PTO_DSL_ST_INSTALL_DIR=${GITHUB_WORKSPACE}/install-ptodsl" >> "${GITHUB_ENV}" + echo "PTO_DSL_ST_PTOAS_BIN=${GITHUB_WORKSPACE}/build-ptodsl/tools/ptoas/ptoas" >> "${GITHUB_ENV}" + + - name: Resolve LLVM cache key + id: llvm-cache-key + shell: bash + run: | + set -euo pipefail + # Resolve to a Git object that ls-remote can handle: either a tag + # (LLVM_TAG) or a branch head (LLVM_REF). Only one is expected. + ref="${LLVM_TAG:-${LLVM_REF}}" + sha="$(git ls-remote "${LLVM_REPO}" "${ref}" | awk '{print $1}')" + if [[ -z "${sha}" ]]; then + echo "ERROR: failed to resolve ${LLVM_REPO} ${ref}" >&2 + exit 1 + fi + echo "sha=${sha}" >> "${GITHUB_OUTPUT}" + echo "key=llvm-build-${sha}-assert-v1" >> "${GITHUB_OUTPUT}" + + - name: Restore LLVM build cache + id: llvm-cache + continue-on-error: true + uses: actions/cache/restore@v4 + with: + path: ${{ env.LLVM_DIR }} + key: ${{ steps.llvm-cache-key.outputs.key }} + - name: Clean CI work dirs shell: bash run: | set -euo pipefail rm -rf "${GITHUB_WORKSPACE}/build" + rm -rf "${GITHUB_WORKSPACE}/build-ptodsl" rm -rf "${PTO_INSTALL_DIR}" + rm -rf "${PTO_DSL_ST_INSTALL_DIR}" rm -rf "${VPTO_SIM_WORKSPACE}" rm -rf "${TILELANG_DSL_WORKSPACE}" rm -rf "${PYPTO_WORKSPACE}" @@ -245,8 +278,8 @@ jobs: # Clean the build directory so that stale generated files from a # previous run (e.g. _smt_ops_gen.py left behind when the ref # changed) do not leak into the fresh build. - rm -rf llvm/build-assert - cmake -G Ninja -S llvm -B llvm/build-assert \ + rm -rf "${LLVM_DIR}" + cmake -G Ninja -S llvm -B "${LLVM_DIR}" \ -DLLVM_ENABLE_PROJECTS="mlir;clang" \ -DBUILD_SHARED_LIBS=ON \ -DLLVM_ENABLE_ASSERTIONS=ON \ @@ -258,7 +291,7 @@ jobs: -DCMAKE_BUILD_TYPE=Release \ -DLLVM_TARGETS_TO_BUILD="host" - ninja -C llvm/build-assert + ninja -C "${LLVM_DIR}" - name: Save LLVM build cache if: steps.llvm-cache.outputs.cache-hit != 'true' @@ -281,34 +314,6 @@ jobs: PTO_INSTALL_DIR="${PTO_INSTALL_DIR}" \ python3 -m pip install . --no-build-isolation --no-deps --ignore-installed --prefix "${PTO_INSTALL_DIR}" - - name: Resolve simulator environment - shell: bash - run: | - set -euo pipefail - - detect_ascend_home() { - for d in \ - "${ASCEND_HOME_PATH:-}" \ - /usr/local/Ascend/cann \ - /usr/local/Ascend/cann-* \ - /usr/local/Ascend/ascend-toolkit/latest - do - [[ -n "${d}" && -d "${d}" ]] || continue - printf '%s\n' "${d}" - return 0 - done - return 1 - } - - ASCEND_HOME_PATH_DETECTED="$(detect_ascend_home || true)" - if [[ -z "${ASCEND_HOME_PATH_DETECTED}" ]]; then - echo "ERROR: failed to detect ASCEND_HOME_PATH on self-hosted runner" >&2 - exit 1 - fi - - echo "ASCEND_HOME_PATH=${ASCEND_HOME_PATH_DETECTED}" >> "${GITHUB_ENV}" - echo "PTOAS_BIN=${GITHUB_WORKSPACE}/build/tools/ptoas/ptoas" >> "${GITHUB_ENV}" - - name: Checkout PyPTO uses: actions/checkout@v4 with: @@ -532,18 +537,277 @@ jobs: 2>&1 | tee "${TILELANG_DSL_WORKSPACE}/run_ci.log" fi + - name: Resolve PTODSL Python + shell: bash + run: | + set -euo pipefail + source "${ASCEND_HOME_PATH}/bin/setenv.bash" + + add_python_candidate() { + local candidate="$1" + local resolved + [[ -n "${candidate}" ]] || return 0 + if [[ "${candidate}" != */* ]]; then + resolved="$(command -v "${candidate}" 2>/dev/null || true)" + else + resolved="${candidate}" + fi + [[ -n "${resolved}" && -x "${resolved}" ]] || return 0 + resolved="$(readlink -f "${resolved}")" + case ":${PYTHON_CANDIDATES[*]}:" in + *":${resolved}:"*) return 0 ;; + esac + PYTHON_CANDIDATES+=("${resolved}") + } + + has_torch_npu_packages() { + "$1" - <<'PY' + import importlib.util + + missing = [ + name for name in ("torch", "torch_npu") + if importlib.util.find_spec(name) is None + ] + raise SystemExit(1 if missing else 0) + PY + } + + probe_ptodsl_runtime_python() { + TORCH_DEVICE_BACKEND_AUTOLOAD=0 "$1" - <<'PY' + import sys + import numpy + import pybind11 + import yaml + import torch + import torch_npu # noqa: F401 + + print(sys.executable) + print(f"python {sys.version_info.major}.{sys.version_info.minor}") + print("torch", torch.__version__) + print("torch_npu", getattr(torch_npu, "__version__", "unknown")) + print("numpy", numpy.__version__) + print("pybind11", pybind11.__version__) + PY + } + + missing_ptodsl_python_deps() { + "$1" - <<'PY' + import importlib.util + + deps = [ + ("setuptools", "setuptools"), + ("wheel", "wheel"), + ("numpy", "numpy"), + ("ml_dtypes", "ml-dtypes"), + ("yaml", "PyYAML"), + ] + + missing = [ + requirement for module_name, requirement in deps + if importlib.util.find_spec(module_name) is None + ] + + try: + import pybind11 + version = tuple(int(part) for part in pybind11.__version__.split(".")[:1]) + if version >= (3,): + missing.append("pybind11<3") + except Exception: + missing.append("pybind11<3") + + print(" ".join(missing)) + PY + } + + PYTHON_CANDIDATES=() + if [[ -n "${PTO_DSL_ST_PYTHON_BIN:-}" ]]; then + add_python_candidate "${PTO_DSL_ST_PYTHON_BIN}" + else + add_python_candidate python3 + add_python_candidate /home/mouliangyu/miniconda3/bin/python3 + add_python_candidate /home/mouliangyu/miniconda3/bin/python + add_python_candidate /home/zhoujiaming/miniconda3/bin/python3 + add_python_candidate /home/zhoujiaming/miniconda3/bin/python + shopt -s nullglob + for candidate in \ + /home/*/miniconda3/envs/*/bin/python \ + /home/*/anaconda3/envs/*/bin/python \ + /opt/conda/envs/*/bin/python + do + add_python_candidate "${candidate}" + done + shopt -u nullglob + fi + + SELECTED_BASE_PYTHON="" + for candidate in "${PYTHON_CANDIDATES[@]}"; do + echo "Probing PTODSL base Python package presence: ${candidate}" + if has_torch_npu_packages "${candidate}"; then + SELECTED_BASE_PYTHON="${candidate}" + break + fi + done + + if [[ -z "${SELECTED_BASE_PYTHON}" ]]; then + echo "ERROR: PTODSL DSL ST requires an existing Python runtime with torch and torch_npu." >&2 + echo "ERROR: this workflow intentionally does not install torch or torch_npu on every run." >&2 + echo "ERROR: set PTO_DSL_ST_PYTHON_BIN to a compatible pre-installed interpreter." >&2 + exit 1 + fi + + PYTHON_ABI_TAG="$("${SELECTED_BASE_PYTHON}" - <<'PY' + import sys + print(f"py{sys.version_info.major}{sys.version_info.minor}") + PY + )" + BASE_HASH="$(printf '%s' "${SELECTED_BASE_PYTHON}" | sha256sum | cut -c1-12)" + PYTHON_TAG="${PYTHON_ABI_TAG}-${BASE_HASH}" + PTODSL_PYTHON_ROOT="${RUNNER_TOOL_CACHE}/ptodsl-python/${PYTHON_TAG}" + if [[ ! -x "${PTODSL_PYTHON_ROOT}/bin/python" ]]; then + rm -rf "${PTODSL_PYTHON_ROOT}" + "${SELECTED_BASE_PYTHON}" -m venv --system-site-packages "${PTODSL_PYTHON_ROOT}" + fi + + PTODSL_PYTHON="${PTODSL_PYTHON_ROOT}/bin/python" + if ! "${PTODSL_PYTHON}" -m pip --version >/dev/null 2>&1; then + "${PTODSL_PYTHON}" -m ensurepip --upgrade + fi + missing_deps="$(missing_ptodsl_python_deps "${PTODSL_PYTHON}")" + if [[ -n "${missing_deps}" ]]; then + "${PTODSL_PYTHON}" -m pip install ${missing_deps} + fi + + probe_ptodsl_runtime_python "${PTODSL_PYTHON}" + + PTO_DSL_ST_LLVM_DIR="${RUNNER_TOOL_CACHE}/llvm-project/llvm/build-assert-${PYTHON_TAG}" + PTO_DSL_ST_MLIR_PYTHONPATH="${PTO_DSL_ST_LLVM_DIR}/tools/mlir/python_packages/mlir_core" + PTO_DSL_ST_PYTHON_SITE="$( + PTO_INSTALL_DIR="${PTO_DSL_ST_INSTALL_DIR}" "${PTODSL_PYTHON}" - <<'PY' + import os + import sysconfig + + prefix = os.environ["PTO_INSTALL_DIR"] + print(sysconfig.get_path("purelib", vars={"base": prefix, "platbase": prefix})) + PY + )" + + echo "PTO_DSL_ST_BASE_PYTHON=${SELECTED_BASE_PYTHON}" >> "${GITHUB_ENV}" + echo "PTO_DSL_ST_PYTHON_BIN=${PTODSL_PYTHON}" >> "${GITHUB_ENV}" + echo "PTO_DSL_ST_PYTHON_TAG=${PYTHON_TAG}" >> "${GITHUB_ENV}" + echo "PTO_DSL_ST_LLVM_DIR=${PTO_DSL_ST_LLVM_DIR}" >> "${GITHUB_ENV}" + echo "PTO_DSL_ST_MLIR_PYTHONPATH=${PTO_DSL_ST_MLIR_PYTHONPATH}" >> "${GITHUB_ENV}" + echo "PTO_DSL_ST_PYTHON_SITE=${PTO_DSL_ST_PYTHON_SITE}" >> "${GITHUB_ENV}" + + - name: Restore PTODSL LLVM build cache + id: ptodsl-llvm-cache + continue-on-error: true + uses: actions/cache/restore@v4 + with: + path: ${{ env.PTO_DSL_ST_LLVM_DIR }} + key: llvm-build-${{ steps.llvm-cache-key.outputs.sha }}-assert-${{ env.PTO_DSL_ST_PYTHON_TAG }}-v1 + + - name: Build PTODSL LLVM/MLIR + if: steps.ptodsl-llvm-cache.outputs.cache-hit != 'true' + shell: bash + run: | + set -euo pipefail + cd "${LLVM_ROOT}" + export CC=gcc + export CXX=g++ + rm -rf "${PTO_DSL_ST_LLVM_DIR}" + cmake -G Ninja -S llvm -B "${PTO_DSL_ST_LLVM_DIR}" \ + -DLLVM_ENABLE_PROJECTS="mlir;clang" \ + -DBUILD_SHARED_LIBS=ON \ + -DLLVM_ENABLE_ASSERTIONS=ON \ + -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ + -DPython3_EXECUTABLE="${PTO_DSL_ST_PYTHON_BIN}" \ + -DPython_EXECUTABLE="${PTO_DSL_ST_PYTHON_BIN}" \ + -DPython3_ROOT_DIR="$(dirname "$(dirname "${PTO_DSL_ST_PYTHON_BIN}")")" \ + -DPython_ROOT_DIR="$(dirname "$(dirname "${PTO_DSL_ST_PYTHON_BIN}")")" \ + -DPython3_FIND_STRATEGY=LOCATION \ + -DPython_FIND_STRATEGY=LOCATION \ + -DPython3_FIND_VIRTUALENV=ONLY \ + -DPython_FIND_VIRTUALENV=ONLY \ + -Dpybind11_DIR="$("${PTO_DSL_ST_PYTHON_BIN}" -m pybind11 --cmakedir)" \ + -Dnanobind_DIR="$("${PTO_DSL_ST_PYTHON_BIN}" -m nanobind --cmake_dir)" \ + -DCMAKE_BUILD_TYPE=Release \ + -DLLVM_TARGETS_TO_BUILD="host" + + ninja -C "${PTO_DSL_ST_LLVM_DIR}" + + - name: Save PTODSL LLVM build cache + if: steps.ptodsl-llvm-cache.outputs.cache-hit != 'true' + continue-on-error: true + uses: actions/cache/save@v4 + with: + path: ${{ env.PTO_DSL_ST_LLVM_DIR }} + key: llvm-build-${{ steps.llvm-cache-key.outputs.sha }}-assert-${{ env.PTO_DSL_ST_PYTHON_TAG }}-v1 + + - name: Build PTODSL PTOAS + shell: bash + run: | + set -euo pipefail + rm -rf "${PTO_DSL_ST_BUILD_DIR}" "${PTO_DSL_ST_INSTALL_DIR}" + export CC=gcc + export CXX=g++ + LLVM_BUILD_DIR="${PTO_DSL_ST_LLVM_DIR}" \ + PTO_BUILD_DIR="${PTO_DSL_ST_BUILD_DIR}" \ + PTO_INSTALL_DIR="${PTO_DSL_ST_INSTALL_DIR}" \ + "${PTO_DSL_ST_PYTHON_BIN}" -m pip install . --no-build-isolation --no-deps --ignore-installed --prefix "${PTO_DSL_ST_INSTALL_DIR}" + - name: Run PTODSL DSL ST CI shell: bash run: | set -euo pipefail mkdir -p "${TILELANG_DSL_WORKSPACE}" - export LLVM_BUILD_DIR="${LLVM_DIR}" - export PYTHON_BIN="python3" + export LLVM_BUILD_DIR="${PTO_DSL_ST_LLVM_DIR}" + export MLIR_PYTHON_ROOT="${PTO_DSL_ST_MLIR_PYTHONPATH}" + export PTO_INSTALL_DIR="${PTO_DSL_ST_INSTALL_DIR}" + export PTO_PYTHON_BUILD_ROOT="${PTO_DSL_ST_BUILD_DIR}/python" + export PYTHON_BIN="${PTO_DSL_ST_PYTHON_BIN}" + export PTO_PYTHON_BIN="${PTO_DSL_ST_PYTHON_BIN}" + export PTOAS_PYTHON_SITE="${PTO_DSL_ST_PYTHON_SITE}" + export TORCH_DEVICE_BACKEND_AUTOLOAD=0 + source "${ASCEND_HOME_PATH}/bin/setenv.bash" + + probe_ptodsl_python() { + PYTHONPATH="${GITHUB_WORKSPACE}/ptodsl:${PTO_DSL_ST_INSTALL_DIR}:${PTOAS_PYTHON_SITE}:${PTO_DSL_ST_MLIR_PYTHONPATH}:${PTO_DSL_ST_BUILD_DIR}/python:${PYTHONPATH:-}" \ + "${PTO_DSL_ST_PYTHON_BIN}" - <<'PY' + import sys + import numpy + import torch + import torch_npu # noqa: F401 + from ptodsl import pto # noqa: F401 + from mlir.dialects import pto as _pto # noqa: F401 + + print(sys.executable) + print("torch", torch.__version__) + print("torch_npu", getattr(torch_npu, "__version__", "unknown")) + print("numpy", numpy.__version__) + PY + } + + PROBE_LOG="${TILELANG_DSL_WORKSPACE}/ptodsl-python-probe.log" + : > "${PROBE_LOG}" + echo "Probing PTODSL DSL ST Python: ${PTO_DSL_ST_PYTHON_BIN}" | tee -a "${PROBE_LOG}" + if ! probe_ptodsl_python >> "${PROBE_LOG}" 2>&1; then + cat "${PROBE_LOG}" >&2 + echo "ERROR: selected PTODSL Python cannot import torch, torch_npu, ptodsl, and MLIR PTO bindings." >&2 + exit 1 + fi + + cat "${PROBE_LOG}" + ASCEND_HOME_PATH="${ASCEND_HOME_PATH}" \ - PTOAS_BIN="${PTOAS_BIN}" \ + PTOAS_BIN="${PTO_DSL_ST_PTOAS_BIN}" \ scripts/sim_dsl.sh test/dsl-st \ 2>&1 | tee "${TILELANG_DSL_WORKSPACE}/ptodsl-dsl-st.log" + ASCEND_HOME_PATH="${ASCEND_HOME_PATH}" \ + PTOAS_BIN="${PTO_DSL_ST_PTOAS_BIN}" \ + scripts/sim_dsl.sh test/dsl-st/npu_a5 \ + 2>&1 | tee "${TILELANG_DSL_WORKSPACE}/ptodsl-dsl-st-npu-a5.log" + - name: Upload TileLang DSL logs if: always() uses: actions/upload-artifact@v4 @@ -552,6 +816,8 @@ jobs: path: | ${{ env.TILELANG_DSL_WORKSPACE }}/run_ci.log ${{ env.TILELANG_DSL_WORKSPACE }}/ptodsl-dsl-st.log + ${{ env.TILELANG_DSL_WORKSPACE }}/ptodsl-dsl-st-npu-a5.log + ${{ env.TILELANG_DSL_WORKSPACE }}/ptodsl-python-probe.log if-no-files-found: warn - name: Run TileLang DSL unit tests diff --git a/include/PTO/Transforms/Passes.h b/include/PTO/Transforms/Passes.h index 3de31a89b..0a69cfaa1 100644 --- a/include/PTO/Transforms/Passes.h +++ b/include/PTO/Transforms/Passes.h @@ -73,6 +73,8 @@ createPlanMemoryPass(const PlanMemoryOptions &planMemoryOption = {}); std::unique_ptr createPTORemoveRedundantBarrierPass(); std::unique_ptr createPTOViewToMemrefPass(); +std::unique_ptr +createPTOViewToMemrefPass(const PTOViewToMemrefOptions &options); std::unique_ptr createPTOValidateIntToPtrUsesPass(); std::unique_ptr createPTOMaterializeTileHandlesPass(); std::unique_ptr createPTOResolveBufferSelectPass(); diff --git a/include/PTO/Transforms/Passes.td b/include/PTO/Transforms/Passes.td index 63b06b6db..aae679b82 100644 --- a/include/PTO/Transforms/Passes.td +++ b/include/PTO/Transforms/Passes.td @@ -507,13 +507,13 @@ def FoldTileBufIntrinsics : Pass<"pto-fold-tile-buf-intrinsics", "mlir::func::Fu - pto.tile_valid_cols → same as above for v_col tensor_view family: - - pto.tensor_view_addr → traces through unrealized_conversion_cast → - subview → reinterpret_cast, then folds to the base memref or to - pto.castptr/pto.addptr on the base memref - - pto.get_tensor_view_dim → folded to arith.constant for static subview - sizes, or to the subview size SSA operand for dynamic dims - - pto.get_tensor_view_stride → folded to the reinterpret_cast stride - operand, multiplied by the subview stride when needed + - pto.tensor_view_addr → traces through + unrealized_conversion_cast → subview → reinterpret_cast, then folds to + the base memref or to pto.castptr/pto.addptr on the base pointer + - pto.get_tensor_view_dim → folded to arith.constant for static view sizes, + or to the source size SSA operand for dynamic dims + - pto.get_tensor_view_stride → folded to the lowered reinterpret_cast + stride, multiplied by the subview stride when needed Dead unrealized_conversion_cast, memref.subview, and memref.reinterpret_cast ops exposed by folding are cleaned up after the @@ -646,6 +646,10 @@ def PTOViewToMemref : Pass<"pto-view-to-memref", "ModuleOp"> { }]; let constructor = "mlir::pto::createPTOViewToMemrefPass()"; + let options = [ + Option<"viewOnly", "view-only", "bool", /*default=*/"false", + "Only rerun structured tensor_view lowering without rewriting tile or compute surfaces"> + ]; let dependentDialects = [ "mlir::pto::PTODialect", diff --git a/lib/PTO/Transforms/ExpandTileOp.cpp b/lib/PTO/Transforms/ExpandTileOp.cpp index 087d8d62c..06a49b043 100644 --- a/lib/PTO/Transforms/ExpandTileOp.cpp +++ b/lib/PTO/Transforms/ExpandTileOp.cpp @@ -82,10 +82,10 @@ namespace { // Four kinds of operands: // Tile — from TileBufType. dtype + shape + memorySpace + config // all participate in the specialization key (SpecKey). -// View — from MemRefType (lowered PartitionTensorViewType). The element -// dtype and optional explicit layout participate in SpecKey; -// shape/strides/memorySpace remain JSON-only metadata for Python -// constraint checking and must not perturb C++ codegen caching. +// View — from MemRefType (lowered TensorView/PartitionTensorView). +// dtype, shape, strides, memorySpace, and optional explicit layout +// participate in SpecKey because they affect template selection and +// generated DMA parameters for tload/tstore. // Vector — from builtin VectorType. The element dtype and vector shape // participate in SpecKey so helper-side schema filtering can // distinguish auxiliary vector operands such as tmrgsort's @@ -107,7 +107,7 @@ struct OperandTypeInfo { int32_t fractal = 0; uint64_t pad = 0; - // --- View-only (MemRefType) — for JSON / constraint checking only --- + // --- View-only --- SmallVector viewShape; SmallVector viewStrides; std::string viewMemorySpace; // "gm" or "ub" @@ -133,8 +133,8 @@ struct OperandTypeInfo { return vectorShape == rhs.vectorShape; if (kind == OperandKind::Scalar) return scalarValue == rhs.scalarValue; - // View: dtype + explicit layout are sufficient for template caching. - return viewLayout == rhs.viewLayout; + return viewShape == rhs.viewShape && viewStrides == rhs.viewStrides && + viewMemorySpace == rhs.viewMemorySpace && viewLayout == rhs.viewLayout; } }; @@ -178,7 +178,11 @@ struct SpecKeyInfo : public llvm::DenseMapInfo { h = llvm::hash_combine(h, *op.scalarValue); } if (op.kind == OperandKind::View) { - h = llvm::hash_combine(h, op.viewLayout.has_value()); + h = llvm::hash_combine(h, op.viewMemorySpace, op.viewLayout.has_value()); + for (int64_t d : op.viewShape) + h = llvm::hash_combine(h, d); + for (int64_t d : op.viewStrides) + h = llvm::hash_combine(h, d); if (op.viewLayout) h = llvm::hash_combine(h, static_cast(*op.viewLayout)); } @@ -633,7 +637,7 @@ static std::optional buildOperandTypeInfo(Value value) { return info; } - // View operand — from MemRefType (lowered PartitionTensorViewType). + // View operand — from MemRefType (lowered TensorView / PartitionTensorView). if (auto mrTy = dyn_cast(ty)) { OperandTypeInfo info; info.kind = OperandKind::View; @@ -843,6 +847,11 @@ static std::string buildUniqueFunctionBaseName(const SpecKey &key) { uniqueName += "_fr" + std::to_string(op.fractal); uniqueName += "_pd" + llvm::utohexstr(op.pad, /*LowerCase=*/false); } else if (op.kind == OperandKind::View) { + for (int64_t d : op.viewShape) + uniqueName += "_s" + std::to_string(d); + for (int64_t d : op.viewStrides) + uniqueName += "_st" + std::to_string(d); + uniqueName += "_ms_" + op.viewMemorySpace; if (op.viewLayout) uniqueName += "_vl_" + stringifyLayout(*op.viewLayout).str(); } else if (op.kind == OperandKind::Vector) { @@ -873,6 +882,39 @@ static std::string buildContextAttrsJson(const SpecKey &key) { return json; } +static bool isViewLikeType(Type type) { + return isa(type); +} + +static void specializeTemplateEntryArgumentTypes(func::FuncOp fn, + Operation *tileOp) { + if (!fn || fn.isExternal()) + return; + + FunctionType fnTy = fn.getFunctionType(); + SmallVector inputs(fnTy.getInputs().begin(), fnTy.getInputs().end()); + bool changed = false; + unsigned operandCount = std::min(tileOp->getNumOperands(), + inputs.size()); + for (unsigned i = 0; i < operandCount; ++i) { + Type callerTy = tileOp->getOperand(i).getType(); + Type calleeTy = inputs[i]; + if (callerTy == calleeTy) + continue; + if (!isViewLikeType(callerTy) || !isViewLikeType(calleeTy)) + continue; + inputs[i] = callerTy; + fn.getArgument(i).setType(callerTy); + changed = true; + } + + if (!changed) + return; + + fn.setFunctionType(FunctionType::get(fn.getContext(), inputs, + fnTy.getResults())); +} + // ============================================================================ // Invoke Python DSL daemon RPC to generate a specialized template function. // ============================================================================ @@ -1030,6 +1072,7 @@ func::FuncOp ExpandState::invokeTilelangDaemon(const SpecKey &key, } auto cloned = clonedFuncs.front(); + specializeTemplateEntryArgumentTypes(cloned, tileOp); if (!cloned->hasAttr("pto.tilelang.instance")) { llvm::errs() << "ExpandTileOp: warning: daemon output function @" << cloned.getSymName() @@ -1229,6 +1272,7 @@ func::FuncOp ExpandState::invokeTilelangDSL(const SpecKey &key, } auto cloned = clonedFuncs.front(); + specializeTemplateEntryArgumentTypes(cloned, tileOp); // The pto.tilelang.instance attribute should already be set by the // TileLang DSL frontend in the generated MLIR. Verify it exists. if (!cloned->hasAttr("pto.tilelang.instance")) { diff --git a/lib/PTO/Transforms/FoldTileBufIntrinsics.cpp b/lib/PTO/Transforms/FoldTileBufIntrinsics.cpp index 37cdc785d..d6ae21050 100644 --- a/lib/PTO/Transforms/FoldTileBufIntrinsics.cpp +++ b/lib/PTO/Transforms/FoldTileBufIntrinsics.cpp @@ -25,11 +25,10 @@ // For tile_buf intrinsics, the active VPTO path folds against materialized tile // handles produced by the shared tile-handle bridge (`pto.alloc_tile` or // `pto.materialize_tile`). -// For tensor_view intrinsics, the pass traces through the full -// unrealized_conversion_cast → memref.subview → memref.reinterpret_cast -// chain to fold directly to constants or SSA operands from the -// reinterpret_cast, without generating intermediate memref.dim / -// memref.extract_strided_metadata ops. +// For tensor_view intrinsics, the pass traces through the lowered +// unrealized_conversion_cast → memref.subview → memref.reinterpret_cast chain +// to fold directly to constants or SSA operands, without generating +// intermediate memref.dim / memref.extract_strided_metadata ops. // //===----------------------------------------------------------------------===// @@ -42,6 +41,7 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/IR/SymbolTable.h" #include "mlir/Pass/Pass.h" using namespace mlir; @@ -90,6 +90,19 @@ static void eraseDeadAllocTileOps(func::FuncOp func) { alloc.erase(); } +static bool isDeadPTODSLSubkernelHelper(func::FuncOp func) { + if (!func->hasAttr("pto.ptodsl.subkernel_helper")) + return false; + + auto module = func->getParentOfType(); + if (!module) + return false; + + SymbolTable symbolTable(module); + auto uses = symbolTable.getSymbolUses(func, module); + return uses && uses->empty(); +} + struct TileHandleInfo { Value sourceMemref; Value addr; @@ -210,7 +223,6 @@ static MemRefType getCanonicalMemRefTypeForTileBuf(pto::TileBufType tileTy) { } struct ViewChain { - UnrealizedConversionCastOp cast; memref::SubViewOp subview; memref::ReinterpretCastOp reinterpretCast; Value baseMemref; @@ -218,28 +230,22 @@ struct ViewChain { static std::optional traceViewChain(Value tensorView, Operation *user) { - Value memrefVal; - UnrealizedConversionCastOp castOp; - - if (isa(tensorView.getType())) { - memrefVal = tensorView; - } else { - castOp = tensorView.getDefiningOp(); - if (!castOp || castOp.getNumOperands() != 1) { - user->emitError( - "FoldTileBufIntrinsics: expected tensor_view to be defined by a " - "single-operand builtin.unrealized_conversion_cast"); - return std::nullopt; - } - memrefVal = castOp.getOperand(0); - if (!isa(memrefVal.getType())) { - user->emitError( - "FoldTileBufIntrinsics: expected cast operand to be a memref, got ") - << memrefVal.getType(); - return std::nullopt; - } + Value view = tensorView; + + if (auto cast = view.getDefiningOp()) { + if (cast.getNumOperands() == 1 && cast.getNumResults() == 1) + view = cast.getOperand(0); + } + + if (!isa(view.getType())) { + user->emitError("FoldTileBufIntrinsics: expected tensor_view to be lowered " + "to a memref.subview chain before folding, got ") + << (view.getDefiningOp() ? view.getDefiningOp()->getName().getStringRef() + : StringRef("block argument")); + return std::nullopt; } + Value memrefVal = view; auto subviewOp = memrefVal.getDefiningOp(); if (!subviewOp) { user->emitError("FoldTileBufIntrinsics: expected memref to be defined by " @@ -261,7 +267,11 @@ static std::optional traceViewChain(Value tensorView, return std::nullopt; } - return ViewChain{castOp, subviewOp, rcOp, rcOp.getSource()}; + ViewChain chain; + chain.subview = subviewOp; + chain.reinterpretCast = rcOp; + chain.baseMemref = rcOp.getSource(); + return chain; } static bool getConstIndexValue(Value v, int64_t &out) { @@ -380,12 +390,13 @@ struct FoldTileBufIntrinsicsPass return signalPassFailure(); } - // Leftover TileLang template instances (private, uncalled after - // PTOInlineLibCall) still contain pto.tile_buf_addr / tile_valid_* - // ops on tile_buf function arguments — they have no materialized tile - // handle anchor to fold against and will be removed by later DCE. Skip - // them. - if (func->hasAttr("pto.tilelang.instance")) + // Leftover TileLang template instances and already-inlined PTODSL + // subkernel helpers may still contain structured-view intrinsics on + // function arguments. Those formal arguments have no materialized + // call-site handle to fold against; the live caller body has already been + // inlined and folded separately. + if (func->hasAttr("pto.tilelang.instance") || + isDeadPTODSLSubkernelHelper(func)) return; SmallVector addrOps; @@ -667,14 +678,16 @@ struct FoldTileBufIntrinsicsPass return signalPassFailure(); } - Value linearOffset = + Value linearOffset; + Value basePtr; + linearOffset = computeLinearOffset(builder, addrOp.getLoc(), chain->reinterpretCast.getMixedOffsets(), chain->subview.getMixedOffsets(), chain->reinterpretCast.getMixedStrides()); - - Value basePtr = builder.create( + basePtr = builder.create( addrOp.getLoc(), resultPtrType, chain->baseMemref); + Value replacement = linearOffset ? builder.create(addrOp.getLoc(), resultPtrType, diff --git a/lib/PTO/Transforms/PTOInstantiateAndInlineOpLib.cpp b/lib/PTO/Transforms/PTOInstantiateAndInlineOpLib.cpp index c278bd196..6ab5c4f60 100644 --- a/lib/PTO/Transforms/PTOInstantiateAndInlineOpLib.cpp +++ b/lib/PTO/Transforms/PTOInstantiateAndInlineOpLib.cpp @@ -347,6 +347,22 @@ static void eraseDeadMatchingPrivateFuncs(ModuleOp module, } } +static void eraseDeadPTODSLSubkernelHelpers(ModuleOp module) { + for (ModuleOp funcModule : collectFuncModules(module)) { + SymbolTable symbolTable(funcModule); + SmallVector deadFuncs; + for (func::FuncOp func : funcModule.getOps()) { + if (!isPTODSLSubkernelHelperFunc(func)) + continue; + auto uses = symbolTable.getSymbolUses(func, funcModule); + if (uses && uses->empty()) + deadFuncs.push_back(func); + } + for (func::FuncOp func : deadFuncs) + func.erase(); + } +} + struct PTOInlineBackendHelpersPass : public pto::impl::PTOInlineBackendHelpersBase< PTOInlineBackendHelpersPass> { @@ -371,7 +387,7 @@ struct PTOInlineBackendHelpersPass << " call(s)\n"; } - eraseDeadMatchingPrivateFuncs(module, isInlineableBackendHelperFunc); + eraseDeadPTODSLSubkernelHelpers(module); } }; diff --git a/lib/PTO/Transforms/PTOViewToMemref.cpp b/lib/PTO/Transforms/PTOViewToMemref.cpp index e86a0a3ad..e7999eaa8 100644 --- a/lib/PTO/Transforms/PTOViewToMemref.cpp +++ b/lib/PTO/Transforms/PTOViewToMemref.cpp @@ -1599,12 +1599,18 @@ static LogicalResult lowerTileBufViewLikeOps(func::FuncOp func, MLIRContext *ctx struct PTOViewToMemrefPass : public mlir::pto::impl::PTOViewToMemrefBase { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PTOViewToMemrefPass) + using mlir::pto::impl::PTOViewToMemrefBase< + PTOViewToMemrefPass>::PTOViewToMemrefBase; void runOnOperation() override { ModuleOp mod = getOperation(); MLIRContext *ctx = &getContext(); for (auto func : mod.getOps()) { + if (viewOnly) { + if (func.isExternal()) + continue; + } else { // ------------------------------------------------------------------ // Stage 0: ensure inttoptr values remain scalar-load/store only. // ------------------------------------------------------------------ @@ -2097,6 +2103,7 @@ struct PTOViewToMemrefPass signalPassFailure(); return; } + } // ------------------------------------------------------------------ // Stage 1: Lower pto.make_tensor_view -> memref.reinterpret_cast @@ -2171,25 +2178,27 @@ struct PTOViewToMemrefPass rewriter.replaceOp(op, rc.getResult()); } - // ------------------------------------------------------------------ - // Stage 1.25: Lower pto.get_tensor_view_dim -> memref.dim - // ------------------------------------------------------------------ - DefaultInlineVector tvDims; - func.walk([&](mlir::pto::GetTensorViewDimOp op) { tvDims.push_back(op); }); + if (!viewOnly) { + // ------------------------------------------------------------------ + // Stage 1.25: Lower pto.get_tensor_view_dim -> memref.dim + // ------------------------------------------------------------------ + DefaultInlineVector tvDims; + func.walk([&](mlir::pto::GetTensorViewDimOp op) { tvDims.push_back(op); }); - for (auto op : tvDims) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - Location loc = op.getLoc(); + for (auto op : tvDims) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + Location loc = op.getLoc(); - Value view = op.getTensorView(); - auto mrTy = dyn_cast(view.getType()); - if (!mrTy) - continue; // leave it to later passes if it hasn't been lowered yet + Value view = op.getTensorView(); + auto mrTy = dyn_cast(view.getType()); + if (!mrTy) + continue; // leave it to later passes if it hasn't been lowered yet - Value dimIdx = op.getDimIndex(); - Value dim = rewriter.create(loc, view, dimIdx); - rewriter.replaceOp(op, dim); + Value dimIdx = op.getDimIndex(); + Value dim = rewriter.create(loc, view, dimIdx); + rewriter.replaceOp(op, dim); + } } // ------------------------------------------------------------------ @@ -2220,6 +2229,9 @@ struct PTOViewToMemrefPass return; } + if (viewOnly) + continue; + // ------------------------------------------------------------------ // Stage 1.5: Lower pto.get_tensor_view_stride -> strided memref metadata // ------------------------------------------------------------------ @@ -2378,7 +2390,11 @@ struct PTOViewToMemrefPass } } - // Clean up: addptr should be folded into make_tensor_view. + // Clean up dead addptr after folding the view/scalar patterns above. + // Live addptr users are legal low-level pointer arithmetic on the VPTO + // path (for example helper-local DMA pointer bumps that appear after + // ExpandTileOp + inline). Leave them in place so a second + // PTOViewToMemref run stays idempotent over already-lowered helper IR. DefaultInlineVector addPtrs; func.walk([&](mlir::pto::AddPtrOp op) { addPtrs.push_back(op.getOperation()); }); bool changed = true; @@ -2394,13 +2410,9 @@ struct PTOViewToMemrefPass } } } - for (auto *op : addPtrs) { - if (!op) - continue; - op->emitError("addptr must feed make_tensor_view, initialize_l2g2l_pipe(gm_addr) or load/store_scalar for lowering"); - signalPassFailure(); - return; - } + + if (viewOnly) + continue; // ------------------------------------------------------------------ // Stage 3: Rewrite Compute Ops @@ -4307,5 +4319,10 @@ std::unique_ptr createPTOViewToMemrefPass() { return std::make_unique(); } +std::unique_ptr +createPTOViewToMemrefPass(const PTOViewToMemrefOptions &options) { + return std::make_unique(options); +} + } // namespace pto } // namespace mlir diff --git a/ptodsl/docs/user_guide/01-introduction.md b/ptodsl/docs/user_guide/01-introduction.md index 12eba5ce5..7435052cd 100644 --- a/ptodsl/docs/user_guide/01-introduction.md +++ b/ptodsl/docs/user_guide/01-introduction.md @@ -208,8 +208,10 @@ def my_kernel( barriers by hand, and work with raw pointers — useful when you need to hand-tune instruction schedules or overlap DMA with compute. -`mode` only affects what you can write inside the function body. It doesn't -change how you compile or launch the kernel. +For native launch builds, `mode` also selects the default PTOAS build policy: +`mode="auto"` keeps the PTOAS default build level and enables sync insertion, +while `mode="explicit"` uses `--pto-level=level3` and leaves synchronization +under user control by default. #### `backend`: VPTO vs EmitC diff --git a/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md b/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md index 7bb02647c..3c956e841 100644 --- a/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md +++ b/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md @@ -50,8 +50,10 @@ The **`backend`** parameter selects the compilation target: rejected at decoration time with an actionable diagnostic. The **`mode`** parameter selects the programming model within the kernel body -(see Section 3.4). `mode` only affects what you can write inside the function — -it doesn't change how you compile or launch the kernel. +(see Section 3.4). For native launch builds, `mode="auto"` keeps PTOAS default +build level and enables sync insertion by default, while `mode="explicit"` uses +`--pto-level=level3` and disables sync insertion by default. This matches the +manual-address, user-managed staging contract of explicit kernels. `@pto.jit` owns compilation (tracing + lowering), caching, and — for `entry=True` — runtime launch binding. The compute-unit decorators diff --git a/ptodsl/ptodsl/_diagnostics.py b/ptodsl/ptodsl/_diagnostics.py index de1f303d5..6ce0fff1e 100644 --- a/ptodsl/ptodsl/_diagnostics.py +++ b/ptodsl/ptodsl/_diagnostics.py @@ -272,6 +272,15 @@ def illegal_inline_subkernel_placement_error(role: str, outer_role: str | None) ) +def subkernel_kernel_kind_mismatch_error(role: str, kernel_kind: str) -> RuntimeError: + """Return one diagnostic for mixing explicit @pto.jit kernel kind with the opposite subkernel kind.""" + return RuntimeError( + f"@pto.{role} cannot be lowered inside an explicit @pto.jit(kernel_kind={kernel_kind!r}) " + "module. Remove the explicit kernel_kind so PTOAS can split cube/vector sections, " + "or keep subkernel scopes in the same physical kind." + ) + + def inline_subkernel_value_escape_error(role: str, type_text: str) -> RuntimeError: """Return one diagnostic for outlined inline-scope values escaping their helper boundary.""" return RuntimeError( @@ -428,6 +437,7 @@ def unsupported_public_surface_error(name: str) -> AttributeError: "subkernel_host_tensor_boundary_error", "subkernel_illegal_annotation_error", "subkernel_illegal_parameter_kind_error", + "subkernel_kernel_kind_mismatch_error", "subkernel_missing_annotation_error", "subkernel_signature_boundary_error", "tile_row_alignment_error", diff --git a/ptodsl/ptodsl/_jit.py b/ptodsl/ptodsl/_jit.py index 5c23bc97f..1d052fd20 100644 --- a/ptodsl/ptodsl/_jit.py +++ b/ptodsl/ptodsl/_jit.py @@ -32,6 +32,15 @@ _MODULE_ATTRS = ("pto.target_arch",) _SUPPORTED_FRONTEND_OPTION_KEYS = {"ast_rewrite", "rewrite_part", "dump_rewritten_source"} _SUPPORTED_REWRITE_PARTS = {"control_flow"} +_DEFAULT_KERNEL_KIND = "vector" + + +class _DefaultKernelKindSentinel: + def __repr__(self) -> str: + return repr(_DEFAULT_KERNEL_KIND) + + +_DEFAULT_KERNEL_KIND_SENTINEL = _DefaultKernelKindSentinel() def _normalize_mode(mode: str, *, fn=None) -> str: @@ -162,7 +171,7 @@ def jit( name=None, *, target: str = "a5", - kernel_kind: str = "vector", + kernel_kind: str = _DEFAULT_KERNEL_KIND_SENTINEL, backend: str = "vpto", entry: bool = True, mode: str = "auto", @@ -177,10 +186,10 @@ def jit( ---------- name: IR function name (defaults to the Python function name). target: Target architecture string, e.g. ``"a5"``. - kernel_kind: authored default physical kind, used for native build selection - and VPTO authoring intent. PTODSL now expresses physical regions - through ``pto.section.vector/cube`` instead of child-module - ``pto.kernel_kind`` attributes. + kernel_kind: optional authored physical kind, used for native build selection + and explicit single-kind VPTO authoring intent. When omitted, + PTODSL keeps the historical vector default while allowing + subkernel sections to express mixed cube/vector regions. backend: ``"vpto"`` or ``"emitc"`` – records the intended backend. entry: ``True`` for launchable kernel entries, ``False`` for helpers. mode: ``"auto"`` or ``"explicit"`` – feeds child compile policy. @@ -219,12 +228,15 @@ def decorator(fn): source_file = inspect.getsourcefile(fn) or inspect.getfile(fn) except (OSError, TypeError): source_file = None + kernel_kind_explicit = kernel_kind is not _DEFAULT_KERNEL_KIND_SENTINEL + effective_kernel_kind = kernel_kind if kernel_kind_explicit else _DEFAULT_KERNEL_KIND compiler = KernelCompiler( fn.__name__, KernelModuleSpec( function_name=fn_name, target_arch=target, - kernel_kind=kernel_kind, + kernel_kind=effective_kernel_kind, + kernel_kind_explicit=kernel_kind_explicit, backend=normalized_backend, entry=entry, mode=normalized_mode, @@ -287,6 +299,10 @@ def __ptodsl_cache_signature__(self): self._compiler._kernel_identity, module_spec.function_name, module_spec.entry, + module_spec.backend, + module_spec.mode, + module_spec.kernel_kind, + module_spec.kernel_kind_explicit, ) def _build_default_module(self): diff --git a/ptodsl/ptodsl/_runtime/cache.py b/ptodsl/ptodsl/_runtime/cache.py index 6231c16b0..955239940 100644 --- a/ptodsl/ptodsl/_runtime/cache.py +++ b/ptodsl/ptodsl/_runtime/cache.py @@ -67,6 +67,7 @@ def write_manifest( launch_symbol: str, mlir_digest: str, launch_cpp_digest: str, + compile_config_digest: str, link_config_digest: str, ) -> None: artifacts.cache_dir.mkdir(parents=True, exist_ok=True) @@ -76,6 +77,7 @@ def write_manifest( "shared_library": str(artifacts.shared_library), "mlir_digest": mlir_digest, "launch_cpp_digest": launch_cpp_digest, + "compile_config_digest": compile_config_digest, "link_config_digest": link_config_digest, } artifacts.manifest_path.write_text(json.dumps(manifest, indent=2) + "\n", encoding="utf-8") @@ -90,6 +92,7 @@ def is_native_build_current( *, mlir_text: str, launch_cpp_text: str, + compile_config_text: str, link_config_text: str, ) -> bool: required = ( @@ -110,6 +113,7 @@ def is_native_build_current( return ( manifest.get("mlir_digest") == _content_digest(mlir_text) and manifest.get("launch_cpp_digest") == _content_digest(launch_cpp_text) + and manifest.get("compile_config_digest") == _content_digest(compile_config_text) and manifest.get("link_config_digest") == _content_digest(link_config_text) ) diff --git a/ptodsl/ptodsl/_runtime/native_build.py b/ptodsl/ptodsl/_runtime/native_build.py index 777fa5420..fe1a51741 100644 --- a/ptodsl/ptodsl/_runtime/native_build.py +++ b/ptodsl/ptodsl/_runtime/native_build.py @@ -45,12 +45,15 @@ def _run_ptoas( *, target_arch: str, insert_sync: bool | None = None, + pto_level: str | None = None, ) -> None: ptoas = resolve_ptoas_binary() cmd = [ str(ptoas), f"--pto-arch={target_arch}", ] + if pto_level is not None: + cmd.append(f"--pto-level={pto_level}") if insert_sync is True: cmd.append("--enable-insert-sync") cmd.extend([ @@ -70,6 +73,23 @@ def _effective_insert_sync(*, mode: str, insert_sync: bool | None) -> bool: return mode != "explicit" +def _effective_pto_level(*, mode: str) -> str | None: + return "level3" if mode == "explicit" else None + + +def _compile_config_text(*, module_spec, effective_insert_sync: bool, effective_pto_level: str | None) -> str: + return "\n".join( + [ + f"target_arch={module_spec.target_arch}", + f"kernel_kind={module_spec.kernel_kind}", + f"mode={module_spec.mode}", + f"insert_sync={effective_insert_sync}", + f"pto_level={effective_pto_level}", + "enable_tile_op_expand=True", + ] + ) + + def _host_compile_flags() -> list[str]: return common_include_flags() + [ "-std=gnu++17", @@ -176,6 +196,16 @@ def build_native_library( ir_function_name=ir_function_name, kernel_signature=kernel_signature, ) + effective_insert_sync = _effective_insert_sync( + mode=module_spec.mode, + insert_sync=module_spec.insert_sync, + ) + effective_pto_level = _effective_pto_level(mode=module_spec.mode) + compile_config_text = _compile_config_text( + module_spec=module_spec, + effective_insert_sync=effective_insert_sync, + effective_pto_level=effective_pto_level, + ) sim_mode = bool(os.environ.get("MSPROF_SIMULATOR_MODE")) link_config_text = "\n".join(runtime_library_flags(sim_mode=sim_mode)) @@ -183,6 +213,7 @@ def build_native_library( artifacts, mlir_text=mlir_text, launch_cpp_text=launch_cpp_text, + compile_config_text=compile_config_text, link_config_text=link_config_text, ): return artifacts.shared_library, launch_symbol @@ -195,10 +226,8 @@ def build_native_library( artifacts.mlir_path, artifacts.kernel_object, target_arch=module_spec.target_arch, - insert_sync=_effective_insert_sync( - mode=module_spec.mode, - insert_sync=module_spec.insert_sync, - ), + insert_sync=effective_insert_sync, + pto_level=effective_pto_level, ) launch_object = artifacts.cache_dir / "launch.o" @@ -221,6 +250,7 @@ def build_native_library( launch_symbol=launch_symbol, mlir_digest=_content_digest(mlir_text), launch_cpp_digest=_content_digest(launch_cpp_text), + compile_config_digest=_content_digest(compile_config_text), link_config_digest=_content_digest(link_config_text), ) return artifacts.shared_library, launch_symbol diff --git a/ptodsl/ptodsl/_tracing/module_builder.py b/ptodsl/ptodsl/_tracing/module_builder.py index f15f62e39..6320fd371 100644 --- a/ptodsl/ptodsl/_tracing/module_builder.py +++ b/ptodsl/ptodsl/_tracing/module_builder.py @@ -31,6 +31,7 @@ class KernelModuleSpec: function_name: str target_arch: str kernel_kind: str + kernel_kind_explicit: bool = False backend: str = "vpto" entry: bool = True mode: str = "auto" diff --git a/ptodsl/ptodsl/_tracing/session.py b/ptodsl/ptodsl/_tracing/session.py index c2390f1d6..8f8183699 100644 --- a/ptodsl/ptodsl/_tracing/session.py +++ b/ptodsl/ptodsl/_tracing/session.py @@ -13,7 +13,7 @@ from dataclasses import dataclass import hashlib -from .._diagnostics import inline_subkernel_value_escape_error +from .._diagnostics import inline_subkernel_value_escape_error, subkernel_kernel_kind_mismatch_error from .._kernel_signature import RuntimeScalarParameterSpec from .._ops import const from .._surface_values import unwrap_surface_value, wrap_like_surface_value @@ -239,16 +239,50 @@ def _create_subkernel_section_op(self, role: str): return None def _create_inline_subkernel_wrapper(self, role: str): - wrapper_op = self._create_subkernel_section_op(role) + wrapper_op = None + if self._subkernel_section_policy(role) != "function_kind": + wrapper_op = self._create_subkernel_section_op(role) if wrapper_op is None: wrapper_op = _pto.VecScopeOp() body_block = wrapper_op.body.blocks.append() return wrapper_op, body_block + def _subkernel_role_kernel_kind(self, role: str) -> str | None: + if role == "simd": + return "vector" + if role == "cube": + return "cube" + return None + + def _current_explicit_kernel_kind(self) -> str | None: + module_spec = self.current_function_module_spec + if not getattr(module_spec, "kernel_kind_explicit", False): + return None + kind = getattr(module_spec, "kernel_kind", None) + return kind if kind in {"cube", "vector"} else None + + def _subkernel_section_policy(self, role: str) -> str: + role_kind = self._subkernel_role_kernel_kind(role) + explicit_kind = self._current_explicit_kernel_kind() + if role_kind is None or explicit_kind is None: + return "section" + if explicit_kind != role_kind: + raise subkernel_kernel_kind_mismatch_error(role, explicit_kind) + return "function_kind" + def _subkernel_helper_attributes(self, role: str) -> tuple[tuple[str, object], ...]: attrs: list[tuple[str, object]] = [] if role in {"simd", "cube"}: attrs.append(("pto.ptodsl.subkernel_helper", StringAttr.get(role))) + if self._subkernel_section_policy(role) == "function_kind": + attrs.append( + ( + "pto.kernel_kind", + Attribute.parse( + f"#pto.kernel_kind<{self._subkernel_role_kernel_kind(role)}>" + ), + ) + ) if role == "simt": attrs.append(("pto.simt_entry", UnitAttr.get())) return tuple(attrs) @@ -275,6 +309,10 @@ def enter_subkernel_body(self, role: str, symbol_name: str, target: str): ) self._subkernel_stack.append(frame) try: + if self._subkernel_section_policy(role) == "function_kind": + yield frame + return + section_op = self._create_subkernel_section_op(role) if section_op is None: yield frame @@ -391,7 +429,8 @@ def _remap_captured_operands(self, root_ops, capture_mapping) -> None: def _outline_inline_subkernel(self, outline_frame: InlineSubkernelOutlineFrame) -> None: role = outline_frame.trace_frame.role - if role in {"simd", "cube"}: + section_policy = self._subkernel_section_policy(role) + if role in {"simd", "cube"} and section_policy != "function_kind": root_ops = (outline_frame.wrapper_op,) else: root_ops = tuple(outline_frame.body_block.operations) @@ -422,7 +461,7 @@ def _outline_inline_subkernel(self, outline_frame: InlineSubkernelOutlineFrame) terminator = func.ReturnOp([]) return_anchor = terminator.operation.opview - if role in {"simd", "cube"}: + if role in {"simd", "cube"} and section_policy != "function_kind": outline_frame.wrapper_op.move_before(return_anchor) outlined_roots = (outline_frame.wrapper_op,) else: diff --git a/ptodsl/tests/test_jit_compile.py b/ptodsl/tests/test_jit_compile.py index fd1b39164..27619ccc9 100644 --- a/ptodsl/tests/test_jit_compile.py +++ b/ptodsl/tests/test_jit_compile.py @@ -150,6 +150,25 @@ def host_vec_copy_explicit( pto.tile.store(o_tile, out) +@pto.jit(target="a5", mode="explicit") +def host_vec_copy_explicit_addr( + A_ptr: pto.ptr(pto.f32, "gm"), + O_ptr: pto.ptr(pto.f32, "gm"), + rows: pto.i32, + cols: pto.i32, + *, + BLOCK: pto.const_expr = 128, +): + a_view = pto.make_tensor_view(A_ptr, shape=[rows, cols], strides=[cols, 1]) + o_view = pto.make_tensor_view(O_ptr, shape=[rows, cols], strides=[cols, 1]) + a_tile = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32, addr=0) + o_tile = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32, addr=4096) + part = pto.partition_view(a_view, offsets=[0, 0], sizes=[rows, cols]) + out = pto.partition_view(o_view, offsets=[0, 0], sizes=[rows, cols]) + pto.tile.load(part, a_tile) + pto.tile.store(o_tile, out) + + @pto.jit(target="a5", backend="emitc") def host_vec_copy_emitc( A_ptr: pto.ptr(pto.f32, "gm"), @@ -669,6 +688,16 @@ def top_level_simd_probe(): SUBKERNEL_OBSERVATIONS.append((frame.role, frame.symbol_name, session.subkernel_stack_depth)) +@pto.simd +def explicit_vector_simd_probe(): + pto.pipe_barrier(pto.Pipe.ALL) + + +@pto.cube +def explicit_vector_cube_probe(): + pto.pipe_barrier(pto.Pipe.ALL) + + @pto.jit(target="a5") def shared_subkernel_lowering_probe(*, TRACE_TOKEN: pto.const_expr = 0): top_level_cube_probe() @@ -676,6 +705,22 @@ def shared_subkernel_lowering_probe(*, TRACE_TOKEN: pto.const_expr = 0): nested_simd_probe() +@pto.jit(target="a5", kernel_kind="vector") +def explicit_vector_calls_simd_probe(*, TRACE_TOKEN: pto.const_expr = 0): + explicit_vector_simd_probe() + + +@pto.jit(target="a5", kernel_kind="vector") +def explicit_vector_calls_cube_probe(*, TRACE_TOKEN: pto.const_expr = 0): + explicit_vector_cube_probe() + + +@pto.jit(target="a5", kernel_kind="vector") +def explicit_vector_inline_simd_probe(*, TRACE_TOKEN: pto.const_expr = 0): + with pto.simd(): + pto.pipe_barrier(pto.Pipe.ALL) + + @pto.jit(target="a5", mode="explicit") def inline_subkernel_scope_probe(*, TRACE_TOKEN: pto.const_expr = 0): session = current_session() @@ -3281,6 +3326,10 @@ def main() -> None: and helper_cache_signature[4] is False, "@pto.jit(entry=False) handles should expose an explicit, stable cache-signature protocol", ) + expect( + helper_cache_signature[7] == "vector" and helper_cache_signature[8] is False, + "default @pto.jit handles should keep vector as the effective kernel kind while recording that it was not explicit", + ) expect_raises( RuntimeError, kernel_module_return_probe.compile, @@ -3480,6 +3529,7 @@ def main() -> None: ) native_build_variants = ( ("pure-container", host_vec_copy.compile()), + ("explicit-level3-container", host_vec_copy_explicit_addr.compile()), ("same-backend-multi-child-container", kernel_module_compiled), ("mixed-backend-container", emitc_entry_calls_vpto_kernel_module_probe.compile()), ) @@ -3499,13 +3549,14 @@ def fake_artifacts(py_name, ir_function_name, specialization_key): manifest_path=cache_dir / "manifest.json", ) - def fake_run_ptoas(mlir_path, kernel_object, *, target_arch, insert_sync=None): + def fake_run_ptoas(mlir_path, kernel_object, *, target_arch, insert_sync=None, pto_level=None): native_build_observations.append( { "mlir_path": mlir_path, "kernel_object": kernel_object, "target_arch": target_arch, "insert_sync": insert_sync, + "pto_level": pto_level, "mlir_text": mlir_path.read_text(encoding="utf-8"), } ) @@ -3562,6 +3613,11 @@ def fake_link_shared_library(launch_object, kernel_object, shared_library, *, ke observation["insert_sync"] == expected_insert_sync, f"{label} native build should forward the effective insert_sync policy to ptoas", ) + expected_pto_level = "level3" if compiled._module_spec.mode == "explicit" else None + expect( + observation["pto_level"] == expected_pto_level, + f"{label} native build should derive the PTOAS level from the authored mode", + ) expect( observation["mlir_text"] == compiled.mlir_text(), f"{label} native build should hand the backend-partitioned container MLIR to ptoas unchanged", @@ -3601,7 +3657,7 @@ def fake_run_ptoas_cmd(cmd, *, cwd=None): ) expect( "--pto-level=level3" not in ptoas_cmd, - "native build should no longer reconstruct explicit mode through a global pto-level flag", + "native build should not pass a global pto-level flag by default", ) expect( "--enable-insert-sync" not in ptoas_cmd, @@ -3626,6 +3682,21 @@ def fake_run_ptoas_cmd(cmd, *, cwd=None): "--enable-insert-sync" in ptoas_cmds[0], "native build should pass --enable-insert-sync when the compiled module explicitly requests it", ) + ptoas_cmds.clear() + with mock.patch.object(native_build_runtime, "resolve_ptoas_binary", return_value=Path("/tmp/fake-ptoas")), mock.patch.object( + native_build_runtime, "_run", side_effect=fake_run_ptoas_cmd + ): + native_build_runtime._run_ptoas( + mlir_path, + kernel_object, + target_arch="a5", + pto_level=native_build_runtime._effective_pto_level(mode="explicit"), + ) + expect(len(ptoas_cmds) == 1, "native build should issue exactly one ptoas command for explicit-mode PTOAS policy") + expect( + "--pto-level=level3" in ptoas_cmds[0], + 'native build should pass --pto-level=level3 for mode="explicit"', + ) expect("valid=?" not in default_text, "default alloc_tile() should keep full static valid-shape when valid_shape= is omitted") auto_mode_violation = expect_raises( RuntimeError, @@ -3859,6 +3930,34 @@ def fake_run_ptoas_cmd(cmd, *, cwd=None): "outlined decorated helper bodies should still preserve their PTO unit sections", ) + explicit_vector_simd_text = explicit_vector_calls_simd_probe.compile(TRACE_TOKEN=1).mlir_text() + expect_parse_roundtrip_and_verify( + explicit_vector_simd_text, + "explicit vector jit calling simd subkernel specialization", + ) + expect( + "pto.kernel_kind = #pto.kernel_kind" in explicit_vector_simd_text + and "pto.section.vector {" not in explicit_vector_simd_text, + "same-kind @pto.simd helpers inside explicit vector kernels should use function/kernel kind metadata without redundant sections", + ) + expect_raises( + RuntimeError, + lambda: explicit_vector_calls_cube_probe.compile(TRACE_TOKEN=1).mlir_text(), + "@pto.cube cannot be lowered inside an explicit @pto.jit(kernel_kind='vector')", + ) + explicit_vector_inline_simd_text = explicit_vector_inline_simd_probe.compile( + TRACE_TOKEN=1 + ).mlir_text() + expect_parse_roundtrip_and_verify( + explicit_vector_inline_simd_text, + "explicit vector jit calling inline simd specialization", + ) + expect( + "pto.kernel_kind = #pto.kernel_kind" in explicit_vector_inline_simd_text + and "pto.section.vector {" not in explicit_vector_inline_simd_text, + "same-kind inline pto.simd() scopes inside explicit vector kernels should avoid redundant sections", + ) + INLINE_SUBKERNEL_SCOPE_OBSERVATIONS.clear() inline_subkernel_scope_text = inline_subkernel_scope_probe.compile(TRACE_TOKEN=1).mlir_text() expect_parse_roundtrip_and_verify(inline_subkernel_scope_text, "inline subkernel scope specialization") diff --git a/ptodsl/tests/test_ptoas_frontend_verify.py b/ptodsl/tests/test_ptoas_frontend_verify.py index 608bb9a0e..7b8a02df0 100644 --- a/ptodsl/tests/test_ptoas_frontend_verify.py +++ b/ptodsl/tests/test_ptoas_frontend_verify.py @@ -151,14 +151,24 @@ def run_ptoas_frontend_verify(ptoas_bin: Path, mlir_text: str, label: str) -> li return frontend_texts -def run_ptoas_frontend_verify_whole(ptoas_bin: Path, mlir_text: str, label: str) -> str: +def run_ptoas_frontend_verify_whole( + ptoas_bin: Path, + mlir_text: str, + label: str, + *, + extra_ptoas_args: list[str] | None = None, +) -> str: with tempfile.NamedTemporaryFile("w", suffix=".mlir", delete=False, encoding="utf-8") as handle: handle.write(mlir_text) input_path = Path(handle.name) try: + cmd = [str(ptoas_bin), str(input_path)] + if extra_ptoas_args: + cmd.extend(extra_ptoas_args) + cmd.extend(["--emit-pto-ir", "-o", "-"]) result = subprocess.run( - [str(ptoas_bin), str(input_path), "--emit-pto-ir", "-o", "-"], + cmd, capture_output=True, text=True, check=False, @@ -174,6 +184,19 @@ def run_ptoas_frontend_verify_whole(ptoas_bin: Path, mlir_text: str, label: str) return result.stdout +def expect_no_raw_partition_tensor_view(frontend_text: str, label: str) -> None: + expect( + "!pto.partition_tensor_view" not in frontend_text, + f"{label} should not leak raw !pto.partition_tensor_view into PTOAS frontend output.\n" + f"frontend output:\n{frontend_text}", + ) + expect( + "memref.subview" in frontend_text or "memref.reinterpret_cast" in frontend_text, + f"{label} should materialize memref-backed view lowering in PTOAS frontend output.\n" + f"frontend output:\n{frontend_text}", + ) + + def run_ptoas_frontend_expect_failure( ptoas_bin: Path, mlir_text: str, @@ -223,6 +246,21 @@ def host_vec_copy( pto.tile.store(o_tile, out) +@pto.jit(target="a5", mode="explicit", insert_sync=False) +def explicit_addr_vec_copy( + A_ptr: pto.ptr(pto.f32, "gm"), + O_ptr: pto.ptr(pto.f32, "gm"), +): + a_view = pto.make_tensor_view(A_ptr, shape=[1, 1, 1, 1, 64], strides=[64, 64, 64, 64, 1]) + o_view = pto.make_tensor_view(O_ptr, shape=[1, 1, 1, 1, 64], strides=[64, 64, 64, 64, 1]) + a_tile = pto.alloc_tile(shape=[1, 64], dtype=pto.f32, addr=0, valid_shape=[1, 64], blayout="RowMajor") + o_tile = pto.alloc_tile(shape=[1, 64], dtype=pto.f32, addr=2048, valid_shape=[1, 64], blayout="RowMajor") + part = pto.partition_view(a_view, offsets=[0, 0, 0, 0, 0], sizes=[1, 1, 1, 1, 64]) + out = pto.partition_view(o_view, offsets=[0, 0, 0, 0, 0], sizes=[1, 1, 1, 1, 64]) + pto.tile.load(part, a_tile) + pto.tile.store(o_tile, out) + + @pto.simt def simt_gm_memory_core_body(gm: pto.ptr(pto.i32, "gm")): tx = pto.get_tid_x() @@ -339,6 +377,43 @@ def main() -> None: "pto.tload" in simple_frontend_text and "pto.tstore" in simple_frontend_text, "host_vec_copy frontend verification output should keep the tile IO contract visible", ) + simple_whole_frontend_text = run_ptoas_frontend_verify_whole( + ptoas_bin, + simple_text, + "host_vec_copy PTODSL whole-container artifact", + ) + expect( + "func.func @host_vec_copy" in simple_whole_frontend_text, + "host_vec_copy whole-container frontend verification should preserve the kernel symbol", + ) + expect( + "pto.tload" in simple_whole_frontend_text and "pto.tstore" in simple_whole_frontend_text, + "host_vec_copy whole-container frontend verification should keep the tile IO contract visible", + ) + expect_no_raw_partition_tensor_view( + simple_whole_frontend_text, + "host_vec_copy whole-container frontend verification", + ) + + explicit_whole_text = explicit_addr_vec_copy.compile().mlir_text() + explicit_whole_frontend_text = run_ptoas_frontend_verify_whole( + ptoas_bin, + explicit_whole_text, + "explicit_addr_vec_copy PTODSL whole-container artifact", + extra_ptoas_args=["--pto-level=level3"], + ) + expect( + "func.func @explicit_addr_vec_copy" in explicit_whole_frontend_text, + "explicit_addr_vec_copy whole-container frontend verification should preserve the kernel symbol", + ) + expect( + "pto.tload" in explicit_whole_frontend_text and "pto.tstore" in explicit_whole_frontend_text, + "explicit_addr_vec_copy whole-container frontend verification should keep the tile IO contract visible", + ) + expect_no_raw_partition_tensor_view( + explicit_whole_frontend_text, + "explicit_addr_vec_copy whole-container frontend verification", + ) simt_gm_memory_text = simt_gm_memory_core_kernel.compile().mlir_text() simt_frontend_texts = run_ptoas_frontend_verify( diff --git a/scripts/sim_dsl.sh b/scripts/sim_dsl.sh index cd0bccbdf..5632fb2e9 100755 --- a/scripts/sim_dsl.sh +++ b/scripts/sim_dsl.sh @@ -35,6 +35,8 @@ Environment: Keep the private staging directory after a successful sync. PTOAS_MSPROF_LOG_MODE=quiet|verbose Override the default simulator log rendering mode. + PYTHON_BIN Python executable used for the PTODSL example. + Defaults to python3. Examples: scripts/sim_dsl.sh ptodsl/examples/jit/tadd_launch.py @@ -177,12 +179,32 @@ ensure_private_dir "${PRIVATE_ROOT}" RUNTIME_OUTPUT_DIR="$(mktemp -d "${PRIVATE_ROOT}/${EXAMPLE_STEM}.XXXXXX")" chmod 700 "${RUNTIME_OUTPUT_DIR}" MSPROF_STDIO_LOG="${RUNTIME_OUTPUT_DIR}/msprof.stdout.log" +EXAMPLE_EXIT_CODE_FILE="${RUNTIME_OUTPUT_DIR}/example.exitcode" +EXAMPLE_LAUNCHER="${RUNTIME_OUTPUT_DIR}/run_example.sh" +PYTHON_BIN="${PTO_PYTHON_BIN:-${PYTHON_BIN:-python3}}" source "${ASCEND_HOME_PATH}/bin/setenv.bash" source "${REPO_ROOT}/scripts/ptoas_env.sh" export LD_LIBRARY_PATH="${SIM_LIB_DIR}:${LD_LIBRARY_PATH:-}" ulimit -n 65535 +if ! command -v "${PYTHON_BIN}" >/dev/null 2>&1; then + die "PYTHON_BIN is not executable or not found on PATH: ${PYTHON_BIN}" +fi + +cat > "${EXAMPLE_LAUNCHER}" <<'EOF' +#!/usr/bin/env bash +set +e +"${PTOAS_SIM_DSL_PYTHON_BIN}" "${PTOAS_SIM_DSL_EXAMPLE_PATH}" "$@" +status=$? +printf '%s\n' "${status}" > "${PTOAS_SIM_DSL_EXIT_CODE_FILE}" +exit "${status}" +EOF +chmod 700 "${EXAMPLE_LAUNCHER}" +export PTOAS_SIM_DSL_PYTHON_BIN="${PYTHON_BIN}" +export PTOAS_SIM_DSL_EXAMPLE_PATH="${EXAMPLE_PATH}" +export PTOAS_SIM_DSL_EXIT_CODE_FILE="${EXAMPLE_EXIT_CODE_FILE}" + # msprof rejects group/other-writable working directories, so always launch # from a private directory and use an absolute path for the example script. cd "${HOME}" @@ -196,11 +218,30 @@ set +e msprof op simulator \ --soc-version="${SOC_VERSION}" \ --output="${RUNTIME_OUTPUT_DIR}" \ - python3 "${EXAMPLE_PATH}" "${EXAMPLE_ARGS[@]}" \ + "${EXAMPLE_LAUNCHER}" "${EXAMPLE_ARGS[@]}" \ > "${MSPROF_STDIO_LOG}" 2>&1 -STATUS=$? +MSPROF_STATUS=$? set -e +EXAMPLE_STATUS=0 +if [[ -f "${EXAMPLE_EXIT_CODE_FILE}" ]]; then + EXAMPLE_STATUS="$(< "${EXAMPLE_EXIT_CODE_FILE}")" + if [[ ! "${EXAMPLE_STATUS}" =~ ^[0-9]+$ ]]; then + log "invalid example exit code recorded in ${EXAMPLE_EXIT_CODE_FILE}: ${EXAMPLE_STATUS}" + EXAMPLE_STATUS=1 + fi +else + log "example exit code file was not produced: ${EXAMPLE_EXIT_CODE_FILE}" + EXAMPLE_STATUS=1 +fi + +STATUS=0 +if [[ ${MSPROF_STATUS} -ne 0 ]]; then + STATUS=${MSPROF_STATUS} +elif [[ ${EXAMPLE_STATUS} -ne 0 ]]; then + STATUS=${EXAMPLE_STATUS} +fi + print_msprof_log "${MSPROF_STDIO_LOG}" "${MSPROF_LOG_MODE}" "${STATUS}" SYNC_STATUS=0 diff --git a/test/dsl-st/cube_matrix_pipeline.py b/test/dsl-st/cube_matrix_pipeline.py index 420d235e7..0138fae35 100644 --- a/test/dsl-st/cube_matrix_pipeline.py +++ b/test/dsl-st/cube_matrix_pipeline.py @@ -21,50 +21,19 @@ M = 16 K = 32 -N = 48 +N = 64 +ELEM_BYTES = 4 L1_A_ADDR = 0 L1_B_ADDR = 4096 -UB_O_ADDR = 0 L0A_ADDR = 0 L0B_ADDR = 0 L0C_ADDR = 0 -@pto.cube -def cube_gemm_tile(a_mat, b_mat, o_tile, a_l0a, b_l0b, o_acc): - m = a_mat.valid_shape[0] - k = a_mat.valid_shape[1] - n = b_mat.valid_shape[1] - - pto.mte_l1_l0a(a_mat.as_ptr(), a_l0a.as_ptr(), m, k) - pto.mte_l1_l0b(b_mat.as_ptr(), b_l0b.as_ptr(), k, n) - pto.set_flag(pto.Pipe.MTE1, pto.Pipe.M, event_id=0) - pto.wait_flag(pto.Pipe.MTE1, pto.Pipe.M, event_id=0) - pto.mad( - a_l0a.as_ptr(), - b_l0b.as_ptr(), - o_acc.as_ptr(), - m, - n, - k, - unit_flag=pto.MadUnitFlagMode.CHECK_ONLY, - sat=pto.SatMode.OFF, - ) - pto.set_flag(pto.Pipe.M, pto.Pipe.FIX, event_id=1) - pto.wait_flag(pto.Pipe.M, pto.Pipe.FIX, event_id=1) - pto.mte_l0c_ub( - o_acc.as_ptr(), - o_tile.as_ptr(), - m, - n, - n, - n, - ) - - @pto.jit( name="cube_matrix_pipeline_kernel", + kernel_kind="cube", target="a5", mode="explicit", insert_sync=False, @@ -74,20 +43,14 @@ def cube_matrix_pipeline_kernel( b_ptr: pto.ptr(pto.f32, "gm"), o_ptr: pto.ptr(pto.f32, "gm"), ): - a_view = pto.make_tensor_view(a_ptr, shape=[M, K], strides=[K, 1]) - b_view = pto.make_tensor_view(b_ptr, shape=[K, N], strides=[N, 1]) - o_view = pto.make_tensor_view(o_ptr, shape=[M, N], strides=[N, 1]) - - a_part = pto.partition_view(a_view, offsets=[0, 0], sizes=[M, K]) - b_part = pto.partition_view(b_view, offsets=[0, 0], sizes=[K, N]) - o_part = pto.partition_view(o_view, offsets=[0, 0], sizes=[M, N]) - a_mat = pto.alloc_tile( shape=[M, K], dtype=pto.f32, memory_space=pto.MemorySpace.MAT, addr=L1_A_ADDR, valid_shape=[M, K], + blayout="ColMajor", + slayout="RowMajor", ) b_mat = pto.alloc_tile( shape=[K, N], @@ -95,12 +58,8 @@ def cube_matrix_pipeline_kernel( memory_space=pto.MemorySpace.MAT, addr=L1_B_ADDR, valid_shape=[K, N], - ) - o_tile = pto.alloc_tile( - shape=[M, N], - dtype=pto.f32, - addr=UB_O_ADDR, - valid_shape=[M, N], + blayout="ColMajor", + slayout="RowMajor", ) a_l0a = pto.alloc_tile( shape=[M, K], @@ -108,6 +67,8 @@ def cube_matrix_pipeline_kernel( memory_space=pto.MemorySpace.LEFT, addr=L0A_ADDR, valid_shape=[M, K], + blayout="ColMajor", + slayout="RowMajor", ) b_l0b = pto.alloc_tile( shape=[K, N], @@ -115,6 +76,8 @@ def cube_matrix_pipeline_kernel( memory_space=pto.MemorySpace.RIGHT, addr=L0B_ADDR, valid_shape=[K, N], + blayout="RowMajor", + slayout="ColMajor", ) o_acc = pto.alloc_tile( shape=[M, N], @@ -122,16 +85,58 @@ def cube_matrix_pipeline_kernel( memory_space=pto.MemorySpace.ACC, addr=L0C_ADDR, valid_shape=[M, N], + blayout="ColMajor", + slayout="RowMajor", + fractal_size=1024, ) - pto.tile.load(a_part, a_mat) - pto.tile.load(b_part, b_mat) + a_l1_ptr = pto.castptr(pto.ui64(L1_A_ADDR), pto.ptr(pto.f32, "mat")) + b_l1_ptr = pto.castptr(pto.ui64(L1_B_ADDR), pto.ptr(pto.f32, "mat")) + + pto.mte_gm_l1_frac( + a_ptr, + a_l1_ptr, + pto.FractalMode.ND2NZ, + shape=(M, K), + src_layout=(K * ELEM_BYTES,), + dst_group=(1, 1, M, 0), + ctrl=(0, False), + ) pto.set_flag(pto.Pipe.MTE2, pto.Pipe.MTE1, event_id=0) pto.wait_flag(pto.Pipe.MTE2, pto.Pipe.MTE1, event_id=0) - cube_gemm_tile(a_mat, b_mat, o_tile, a_l0a, b_l0b, o_acc) - pto.set_flag(pto.Pipe.FIX, pto.Pipe.MTE3, event_id=2) - pto.wait_flag(pto.Pipe.FIX, pto.Pipe.MTE3, event_id=2) - pto.tile.store(o_tile, o_part) + pto.mte_l1_l0a(a_l1_ptr, a_l0a.as_ptr(), M, K) + + pto.mte_gm_l1_frac( + b_ptr, + b_l1_ptr, + pto.FractalMode.ND2NZ, + shape=(K, N), + src_layout=(N * ELEM_BYTES,), + dst_group=(1, 1, K, 0), + ctrl=(0, False), + ) + pto.set_flag(pto.Pipe.MTE2, pto.Pipe.MTE1, event_id=1) + pto.wait_flag(pto.Pipe.MTE2, pto.Pipe.MTE1, event_id=1) + pto.mte_l1_l0b(b_l1_ptr, b_l0b.as_ptr(), K, N, transpose=True) + + pto.set_flag(pto.Pipe.MTE1, pto.Pipe.M, event_id=0) + pto.wait_flag(pto.Pipe.MTE1, pto.Pipe.M, event_id=0) + pto.tile.matmul(a_l0a, b_l0b, o_acc) + + pto.set_flag(pto.Pipe.M, pto.Pipe.FIX, event_id=1) + pto.wait_flag(pto.Pipe.M, pto.Pipe.FIX, event_id=1) + pto.mte_l0c_gm( + o_acc.as_ptr(), + o_ptr, + M, + N, + M, + N, + 0, + 0, + layout="nz2nd", + ) + pto.pipe_barrier(pto.Pipe.ALL) def make_inputs(): diff --git a/test/dsl-st/gemv_mx_pipeline.py b/test/dsl-st/gemv_mx_pipeline.py index 89b1d80f7..3d8b0f5f5 100644 --- a/test/dsl-st/gemv_mx_pipeline.py +++ b/test/dsl-st/gemv_mx_pipeline.py @@ -17,8 +17,6 @@ from common import assert_close, auto_main from ptodsl import pto -from ptodsl._surface_values import unwrap_surface_value -from mlir.dialects import pto as _pto M = 1 @@ -177,17 +175,6 @@ def _alloc_common_tiles(): return lhs_tile, lhs_scale_tile, rhs_tile, rhs_scale_tile, dst_tile -def _bind_mx_scale_tiles(lhs_tile, lhs_scale_tile, rhs_tile, rhs_scale_tile): - _pto.TGetScaleAddrOp( - unwrap_surface_value(lhs_tile), - unwrap_surface_value(lhs_scale_tile), - ) - _pto.TGetScaleAddrOp( - unwrap_surface_value(rhs_tile), - unwrap_surface_value(rhs_scale_tile), - ) - - def _alloc_bias_tile(): return pto.alloc_tile( shape=[M, N_STORAGE], @@ -271,7 +258,6 @@ def gemv_mx_fp8_pipeline_kernel( ): lhs_tile, lhs_scale_tile, rhs_tile, rhs_scale_tile, dst_tile = _alloc_common_tiles() _stage_fp8_tiles(a_ptr, b_ptr, a_scale_ptr, b_scale_ptr, lhs_tile, rhs_tile) - _bind_mx_scale_tiles(lhs_tile, lhs_scale_tile, rhs_tile, rhs_scale_tile) pto.tile.gemv_mx(lhs_tile, lhs_scale_tile, rhs_tile, rhs_scale_tile, dst_tile) _writeback_output(dst_tile, out_ptr) @@ -292,7 +278,6 @@ def gemv_mx_acc_fp8_pipeline_kernel( ): lhs_tile, lhs_scale_tile, rhs_tile, rhs_scale_tile, dst_tile = _alloc_common_tiles() _stage_fp8_tiles(a_ptr, b_ptr, a_scale_ptr, b_scale_ptr, lhs_tile, rhs_tile) - _bind_mx_scale_tiles(lhs_tile, lhs_scale_tile, rhs_tile, rhs_scale_tile) pto.tile.gemv_mx(lhs_tile, lhs_scale_tile, rhs_tile, rhs_scale_tile, dst_tile) pto.tile.gemv_mx_acc(dst_tile, lhs_tile, lhs_scale_tile, rhs_tile, rhs_scale_tile, dst_tile) _writeback_output(dst_tile, out_ptr) @@ -325,7 +310,6 @@ def gemv_mx_bias_fp8_pipeline_kernel( bias_ptr=bias_ptr, bias_tile=bias_tile, ) - _bind_mx_scale_tiles(lhs_tile, lhs_scale_tile, rhs_tile, rhs_scale_tile) pto.tile.gemv_mx_bias(lhs_tile, lhs_scale_tile, rhs_tile, rhs_scale_tile, bias_tile, dst_tile) _writeback_output(dst_tile, out_ptr) diff --git a/test/dsl-st/npu_a5/__main__.py b/test/dsl-st/npu_a5/__main__.py new file mode 100644 index 000000000..03ec12f0a --- /dev/null +++ b/test/dsl-st/npu_a5/__main__.py @@ -0,0 +1,32 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# Directory runner for the npu_a5 operator-level ST cases. +# +# Each operator (e.g. tadd, tmatmul, ...) is a single *.py file that authors its +# kernel with PTODSL and builds its CASES list through the helpers in +# ``test/dsl-st/common.py``. Running this directory discovers every *.py module +# and executes the cases against the torch_npu / simulator runtime. +# +# See test/dsl-st/README.md for the authoring conventions shared with the rest +# of the dsl-st suite. + +from pathlib import Path +import sys + + +if __package__ in {None, ""}: + # common.py lives one level up, in test/dsl-st/. + sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) + +from common import run_discovered_cases + + +if __name__ == "__main__": + raise SystemExit(run_discovered_cases(Path(__file__).resolve().parent)) diff --git a/test/dsl-st/npu_a5/tadd.py b/test/dsl-st/npu_a5/tadd.py new file mode 100644 index 000000000..7df2bfaea --- /dev/null +++ b/test/dsl-st/npu_a5/tadd.py @@ -0,0 +1,197 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# PTODSL rewrite of test/tilelang_st/npu/a5/src/st/testcase/tadd. +# +# Keep the kernel body close to the original semantic tile-op authoring: +# tload(a) + tload(b) + tadd(a,b)->c + tstore(c) +# +# The case uses explicit UB addresses to match the hand-authored ST contract. +# PTODSL native build derives PTOAS level3 from mode="explicit". + +from pathlib import Path +import sys + +import numpy as np +from mlir.ir import Attribute, InsertionPoint, Location, Module, StringAttr + +if __package__ in {None, ""}: + sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) + +from common import auto_main, golden_output_case +from ptodsl import pto +from ptodsl._kernel_compilation import KernelCompiler +from ptodsl._kernel_signature import parse_jit_kernel_signature +from ptodsl._tracing import KernelModuleSpec, ModuleStyle + + +# Each case is (name, shape). Both use fully-valid f32 tiles, matching the +# original tadd cases "f32_16x64" and "f32_32x32". +CASE_SHAPES = [ + ("f32_16x64", (16, 64)), + ("f32_32x32", (32, 32)), +] + +A_TILE_ADDR = 0 +B_TILE_ADDR = 4096 +C_TILE_ADDR = 8192 + + +class _FlatKernelHandle: + """Small test-local wrapper that compiles through a flat PTODSL container.""" + + def __init__(self, compiler): + self._compiler = compiler + + def compile(self, **constexpr_bindings): + compiled = self._compiler.compile(**constexpr_bindings) + _attach_flat_vpto_attrs(compiled.build(), self._compiler._module_spec) + return compiled + + +def _attach_flat_vpto_attrs(module, spec): + """Test-local flat containers must carry PTOAS-facing VPTO metadata.""" + with module.context: + module.operation.attributes["pto.backend"] = StringAttr.get(spec.backend) + if spec.backend == "vpto" and spec.kernel_kind in {"cube", "vector"}: + module.operation.attributes["pto.kernel_kind"] = Attribute.parse( + f"#pto.kernel_kind<{spec.kernel_kind}>" + ) + + +def _flat_jit(*, name, target="a5", kernel_kind="vector"): + """Mirror @pto.jit for this testcase, but force a flat module container.""" + + def decorator(fn): + compiler = KernelCompiler( + fn.__name__, + KernelModuleSpec( + function_name=name, + target_arch=target, + kernel_kind=kernel_kind, + backend="vpto", + entry=True, + mode="explicit", + insert_sync=None, + module_style=ModuleStyle.FLAT_AICORE, + source_file=__file__, + source_line=fn.__code__.co_firstlineno, + ), + parse_jit_kernel_signature(fn, entry=True), + fn, + ast_rewrite=True, + ) + return _FlatKernelHandle(compiler) + + return decorator + + +def _merge_flat_modules(*compiled_kernels): + first = compiled_kernels[0].build() + with first.context, Location.unknown(): + merged = Module.create() + for named_attr in first.operation.attributes: + merged.operation.attributes[named_attr.name] = named_attr.attr + with InsertionPoint(merged.body): + for compiled in compiled_kernels: + module = compiled.build() + for op in module.body.operations: + op.operation.clone() + merged.operation.verify() + return merged + + +def _tadd_body(a_ptr, b_ptr, c_ptr, *, rows, cols): + """Shared kernel body for the two tadd cases.""" + + # Keep the original 5D tilelang-style partition schema here. It matches the + # hand-authored tadd.pto layout and is already known-good for vec tile-op + # ST cases in this repository. + total = rows * cols + a_view = pto.make_tensor_view(a_ptr, shape=[1, 1, 1, rows, cols], strides=[total, total, total, cols, 1]) + b_view = pto.make_tensor_view(b_ptr, shape=[1, 1, 1, rows, cols], strides=[total, total, total, cols, 1]) + c_view = pto.make_tensor_view(c_ptr, shape=[1, 1, 1, rows, cols], strides=[total, total, total, cols, 1]) + + a_part = pto.partition_view(a_view, offsets=[0, 0, 0, 0, 0], sizes=[1, 1, 1, rows, cols]) + b_part = pto.partition_view(b_view, offsets=[0, 0, 0, 0, 0], sizes=[1, 1, 1, rows, cols]) + c_part = pto.partition_view(c_view, offsets=[0, 0, 0, 0, 0], sizes=[1, 1, 1, rows, cols]) + + # Use explicit UB addresses so direct level3 VPTO lowering has no memory + # planning dependency. + a_tile = pto.alloc_tile(shape=[rows, cols], dtype=pto.f32, addr=A_TILE_ADDR) + b_tile = pto.alloc_tile(shape=[rows, cols], dtype=pto.f32, addr=B_TILE_ADDR) + c_tile = pto.alloc_tile(shape=[rows, cols], dtype=pto.f32, addr=C_TILE_ADDR) + + pto.tile.load(a_part, a_tile) + pto.tile.load(b_part, b_tile) + pto.set_flag(pto.Pipe.MTE2, pto.Pipe.V, event_id=0) + pto.wait_flag(pto.Pipe.MTE2, pto.Pipe.V, event_id=0) + pto.tile.add(a_tile, b_tile, c_tile) + pto.set_flag(pto.Pipe.V, pto.Pipe.MTE3, event_id=1) + pto.wait_flag(pto.Pipe.V, pto.Pipe.MTE3, event_id=1) + pto.tile.store(c_tile, c_part) + pto.pipe_barrier(pto.Pipe.ALL) + + +# One decorated kernel per case, each binding a static shape at definition time +# (mirroring the per-case funcs in tadd.pto). +_tadd_kernels = {} +for _name, _shape in CASE_SHAPES: + _r, _c = _shape + + def _make(r=_r, c=_c): + @_flat_jit( + name=f"tadd_{_name}", + kernel_kind="vector", + target="a5", + ) + def _kernel( + a_ptr: pto.ptr(pto.f32, "gm"), + b_ptr: pto.ptr(pto.f32, "gm"), + c_ptr: pto.ptr(pto.f32, "gm"), + ): + _tadd_body(a_ptr, b_ptr, c_ptr, rows=r, cols=c) + + return _kernel + + _tadd_kernels[_name] = _make() + + +def _make_inputs(name, shape): + # Deterministic per-case seed, mirroring st_common.setup_case_rng which uses + # crc32(name). Original value range was randint(1, 10). + import zlib + np.random.seed(zlib.crc32(name.encode("utf-8")) & 0xFFFFFFFF) + a = np.random.randint(1, 10, size=shape).astype(np.float32) + b = np.random.randint(1, 10, size=shape).astype(np.float32) + return [a, b] + + +def _make_expected(a, b): + return (a + b).astype(np.float32) + + +CASES = [] +for _name, _shape in CASE_SHAPES: + CASES.append( + golden_output_case( + "tadd_" + _name, + _tadd_kernels[_name], + inputs=lambda _name=_name, _shape=_shape: _make_inputs(_name, _shape), + expected=_make_expected, + rtol=1e-6, + atol=1e-6, + ) + ) + + +EMIT_MLIR_FN = lambda: _merge_flat_modules(*[kernel.compile() for kernel in _tadd_kernels.values()]) + + +auto_main(globals()) diff --git a/test/dsl-st/npu_a5/tcolexpand.py b/test/dsl-st/npu_a5/tcolexpand.py new file mode 100644 index 000000000..4d30dd5db --- /dev/null +++ b/test/dsl-st/npu_a5/tcolexpand.py @@ -0,0 +1,151 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# Minimal PTODSL broadcast pilot for A5: +# tload(src) + tcolexpand(src)->dst + tstore(dst) + +from pathlib import Path +import sys + +import numpy as np +from mlir.ir import Attribute, InsertionPoint, Location, Module, StringAttr + +if __package__ in {None, ""}: + sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) + +from common import auto_main, golden_output_case +from ptodsl import pto +from ptodsl._kernel_compilation import KernelCompiler +from ptodsl._kernel_signature import parse_jit_kernel_signature +from ptodsl._tracing import KernelModuleSpec, ModuleStyle + + +SRC_ROWS = 1 +DST_ROWS = 8 +COLS = 128 +SRC_TILE_ADDR = 0 +DST_TILE_ADDR = 4096 + + +class _FlatKernelHandle: + def __init__(self, compiler): + self._compiler = compiler + + def compile(self, **constexpr_bindings): + compiled = self._compiler.compile(**constexpr_bindings) + _attach_flat_vpto_attrs(compiled.build(), self._compiler._module_spec) + return compiled + + +def _attach_flat_vpto_attrs(module, spec): + """Test-local flat containers must carry PTOAS-facing VPTO metadata.""" + with module.context: + module.operation.attributes["pto.backend"] = StringAttr.get(spec.backend) + if spec.backend == "vpto" and spec.kernel_kind in {"cube", "vector"}: + module.operation.attributes["pto.kernel_kind"] = Attribute.parse( + f"#pto.kernel_kind<{spec.kernel_kind}>" + ) + + +def _flat_jit(*, name, target="a5", kernel_kind="vector"): + def decorator(fn): + compiler = KernelCompiler( + fn.__name__, + KernelModuleSpec( + function_name=name, + target_arch=target, + kernel_kind=kernel_kind, + backend="vpto", + entry=True, + mode="explicit", + insert_sync=None, + module_style=ModuleStyle.FLAT_AICORE, + source_file=__file__, + source_line=fn.__code__.co_firstlineno, + ), + parse_jit_kernel_signature(fn, entry=True), + fn, + ast_rewrite=True, + ) + return _FlatKernelHandle(compiler) + + return decorator + + +def _merge_flat_modules(*compiled_kernels): + first = compiled_kernels[0].build() + with first.context, Location.unknown(): + merged = Module.create() + for named_attr in first.operation.attributes: + merged.operation.attributes[named_attr.name] = named_attr.attr + with InsertionPoint(merged.body): + for compiled in compiled_kernels: + module = compiled.build() + for op in module.body.operations: + op.operation.clone() + merged.operation.verify() + return merged + + +@_flat_jit( + name="tcolexpand_f32_1x8x128", + kernel_kind="vector", + target="a5", +) +def _tcolexpand_kernel( + src_ptr: pto.ptr(pto.f32, "gm"), + dst_ptr: pto.ptr(pto.f32, "gm"), +): + src_view = pto.make_tensor_view( + src_ptr, + shape=[1, 1, 1, SRC_ROWS, COLS], + strides=[COLS, COLS, COLS, COLS, 1], + ) + dst_view = pto.make_tensor_view( + dst_ptr, + shape=[1, 1, 1, DST_ROWS, COLS], + strides=[DST_ROWS * COLS, DST_ROWS * COLS, DST_ROWS * COLS, COLS, 1], + ) + + src_part = pto.partition_view(src_view, offsets=[0, 0, 0, 0, 0], sizes=[1, 1, 1, SRC_ROWS, COLS]) + dst_part = pto.partition_view(dst_view, offsets=[0, 0, 0, 0, 0], sizes=[1, 1, 1, DST_ROWS, COLS]) + + src_tile = pto.alloc_tile(shape=[SRC_ROWS, COLS], dtype=pto.f32, addr=SRC_TILE_ADDR) + dst_tile = pto.alloc_tile(shape=[DST_ROWS, COLS], dtype=pto.f32, addr=DST_TILE_ADDR) + + pto.tile.load(src_part, src_tile) + pto.tile.colexpand(src_tile, dst_tile) + pto.tile.store(dst_tile, dst_part) + + +def _make_input(): + rng = np.random.default_rng(0xC01E0A5) + return rng.uniform(-2.0, 2.0, size=(SRC_ROWS, COLS)).astype(np.float32) + + +def _make_expected(src): + return np.repeat(src, DST_ROWS, axis=0).astype(np.float32) + + +CASES = [ + golden_output_case( + "tcolexpand_f32_1x8x128", + _tcolexpand_kernel, + inputs=lambda: [_make_input()], + expected=_make_expected, + rtol=1e-6, + atol=1e-6, + ), +] + + +EMIT_MLIR_FN = lambda: _merge_flat_modules(_tcolexpand_kernel.compile()) + + +auto_main(globals()) diff --git a/test/dsl-st/npu_a5/tcolsum.py b/test/dsl-st/npu_a5/tcolsum.py new file mode 100644 index 000000000..a6d898e46 --- /dev/null +++ b/test/dsl-st/npu_a5/tcolsum.py @@ -0,0 +1,150 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# Minimal PTODSL reduction pilot for A5: +# tload(src) + tcolsum(src)->dst + tstore(dst) + +from pathlib import Path +import sys + +import numpy as np +from mlir.ir import Attribute, InsertionPoint, Location, Module, StringAttr + +if __package__ in {None, ""}: + sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) + +from common import auto_main, golden_output_case +from ptodsl import pto +from ptodsl._kernel_compilation import KernelCompiler +from ptodsl._kernel_signature import parse_jit_kernel_signature +from ptodsl._tracing import KernelModuleSpec, ModuleStyle + + +ROWS = 16 +COLS = 128 +SRC_TILE_ADDR = 0 +DST_TILE_ADDR = 8192 + + +class _FlatKernelHandle: + def __init__(self, compiler): + self._compiler = compiler + + def compile(self, **constexpr_bindings): + compiled = self._compiler.compile(**constexpr_bindings) + _attach_flat_vpto_attrs(compiled.build(), self._compiler._module_spec) + return compiled + + +def _attach_flat_vpto_attrs(module, spec): + """Test-local flat containers must carry PTOAS-facing VPTO metadata.""" + with module.context: + module.operation.attributes["pto.backend"] = StringAttr.get(spec.backend) + if spec.backend == "vpto" and spec.kernel_kind in {"cube", "vector"}: + module.operation.attributes["pto.kernel_kind"] = Attribute.parse( + f"#pto.kernel_kind<{spec.kernel_kind}>" + ) + + +def _flat_jit(*, name, target="a5", kernel_kind="vector"): + def decorator(fn): + compiler = KernelCompiler( + fn.__name__, + KernelModuleSpec( + function_name=name, + target_arch=target, + kernel_kind=kernel_kind, + backend="vpto", + entry=True, + mode="explicit", + insert_sync=None, + module_style=ModuleStyle.FLAT_AICORE, + source_file=__file__, + source_line=fn.__code__.co_firstlineno, + ), + parse_jit_kernel_signature(fn, entry=True), + fn, + ast_rewrite=True, + ) + return _FlatKernelHandle(compiler) + + return decorator + + +def _merge_flat_modules(*compiled_kernels): + first = compiled_kernels[0].build() + with first.context, Location.unknown(): + merged = Module.create() + for named_attr in first.operation.attributes: + merged.operation.attributes[named_attr.name] = named_attr.attr + with InsertionPoint(merged.body): + for compiled in compiled_kernels: + module = compiled.build() + for op in module.body.operations: + op.operation.clone() + merged.operation.verify() + return merged + + +@_flat_jit( + name="tcolsum_f32_16x128", + kernel_kind="vector", + target="a5", +) +def _tcolsum_kernel( + src_ptr: pto.ptr(pto.f32, "gm"), + dst_ptr: pto.ptr(pto.f32, "gm"), +): + src_view = pto.make_tensor_view( + src_ptr, + shape=[1, 1, 1, ROWS, COLS], + strides=[ROWS * COLS, ROWS * COLS, ROWS * COLS, COLS, 1], + ) + dst_view = pto.make_tensor_view( + dst_ptr, + shape=[1, 1, 1, 1, COLS], + strides=[COLS, COLS, COLS, COLS, 1], + ) + + src_part = pto.partition_view(src_view, offsets=[0, 0, 0, 0, 0], sizes=[1, 1, 1, ROWS, COLS]) + dst_part = pto.partition_view(dst_view, offsets=[0, 0, 0, 0, 0], sizes=[1, 1, 1, 1, COLS]) + + src_tile = pto.alloc_tile(shape=[ROWS, COLS], dtype=pto.f32, addr=SRC_TILE_ADDR) + dst_tile = pto.alloc_tile(shape=[1, COLS], dtype=pto.f32, addr=DST_TILE_ADDR) + + pto.tile.load(src_part, src_tile) + pto.tile.colsum(src_tile, dst_tile) + pto.tile.store(dst_tile, dst_part) + + +def _make_input(): + rng = np.random.default_rng(0xC01A5EED) + return rng.uniform(-3.0, 3.0, size=(ROWS, COLS)).astype(np.float32) + + +def _make_expected(src): + return np.sum(src, axis=0, keepdims=True, dtype=np.float32) + + +CASES = [ + golden_output_case( + "tcolsum_f32_16x128", + _tcolsum_kernel, + inputs=lambda: [_make_input()], + expected=_make_expected, + rtol=1e-5, + atol=1e-5, + ), +] + + +EMIT_MLIR_FN = lambda: _merge_flat_modules(_tcolsum_kernel.compile()) + + +auto_main(globals()) diff --git a/test/dsl-st/npu_a5/tload_store.py b/test/dsl-st/npu_a5/tload_store.py new file mode 100644 index 000000000..595ef1518 --- /dev/null +++ b/test/dsl-st/npu_a5/tload_store.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# PTODSL rewrite of the minimal GM -> tile -> GM coverage from +# test/tilelang_st/npu/a5/src/st/testcase/tload/tload.pto. +# +# Start with two static f32 round-trips: +# 1. ND / row-major +# 2. DN / col-major +# These are the smallest data-movement cases needed to validate that PTODSL can +# drive tload/tstore on A5 without the tilelang_st harness. + +from pathlib import Path +import sys + +import numpy as np +from mlir.ir import Attribute, InsertionPoint, Location, Module, StringAttr + +if __package__ in {None, ""}: + sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) + +from common import auto_main, golden_output_case +from ptodsl import pto +from ptodsl._kernel_compilation import KernelCompiler +from ptodsl._kernel_signature import parse_jit_kernel_signature +from ptodsl._tracing import KernelModuleSpec, ModuleStyle + + +CASE_SPECS = [ + { + "case_name": "nd_f32_16x64", + "kernel_name": "tload_store_nd_f32_16x64", + "shape": (16, 64), + "view_strides": None, + "tile_kwargs": {}, + }, + { + "case_name": "dn_f32_16x64", + "kernel_name": "tload_store_dn_f32_16x64", + "shape": (16, 64), + "view_strides": None, + "tile_kwargs": {"blayout": "ColMajor"}, + }, +] + +TILE_ADDR = 0 + + +class _FlatKernelHandle: + def __init__(self, compiler): + self._compiler = compiler + + def compile(self, **constexpr_bindings): + compiled = self._compiler.compile(**constexpr_bindings) + _attach_flat_vpto_attrs(compiled.build(), self._compiler._module_spec) + return compiled + + +def _attach_flat_vpto_attrs(module, spec): + """Test-local flat containers must carry PTOAS-facing VPTO metadata.""" + with module.context: + module.operation.attributes["pto.backend"] = StringAttr.get(spec.backend) + if spec.backend == "vpto" and spec.kernel_kind in {"cube", "vector"}: + module.operation.attributes["pto.kernel_kind"] = Attribute.parse( + f"#pto.kernel_kind<{spec.kernel_kind}>" + ) + + +def _flat_jit(*, name, target="a5", kernel_kind="vector"): + def decorator(fn): + compiler = KernelCompiler( + fn.__name__, + KernelModuleSpec( + function_name=name, + target_arch=target, + kernel_kind=kernel_kind, + backend="vpto", + entry=True, + mode="explicit", + insert_sync=None, + module_style=ModuleStyle.FLAT_AICORE, + source_file=__file__, + source_line=fn.__code__.co_firstlineno, + ), + parse_jit_kernel_signature(fn, entry=True), + fn, + ast_rewrite=True, + ) + return _FlatKernelHandle(compiler) + + return decorator + + +def _merge_flat_modules(*compiled_kernels): + first = compiled_kernels[0].build() + with first.context, Location.unknown(): + merged = Module.create() + for named_attr in first.operation.attributes: + merged.operation.attributes[named_attr.name] = named_attr.attr + with InsertionPoint(merged.body): + for compiled in compiled_kernels: + module = compiled.build() + for op in module.body.operations: + op.operation.clone() + merged.operation.verify() + return merged + + +def _roundtrip_body(src_ptr, dst_ptr, *, rows, cols, view_strides=None, tile_kwargs=None): + total = rows * cols + if view_strides is None: + view_strides = [total, total, total, cols, 1] + + src_view = pto.make_tensor_view( + src_ptr, + shape=[1, 1, 1, rows, cols], + strides=view_strides, + ) + dst_view = pto.make_tensor_view( + dst_ptr, + shape=[1, 1, 1, rows, cols], + strides=view_strides, + ) + + src_part = pto.partition_view(src_view, offsets=[0, 0, 0, 0, 0], sizes=[1, 1, 1, rows, cols]) + dst_part = pto.partition_view(dst_view, offsets=[0, 0, 0, 0, 0], sizes=[1, 1, 1, rows, cols]) + + tile = pto.alloc_tile( + shape=[rows, cols], + dtype=pto.f32, + addr=TILE_ADDR, + **(tile_kwargs or {}), + ) + + pto.tile.load(src_part, tile) + pto.tile.store(tile, dst_part) + + +_tload_store_kernels = {} +for _spec in CASE_SPECS: + _rows, _cols = _spec["shape"] + _view_strides = _spec["view_strides"] + if _view_strides is None and _spec["tile_kwargs"].get("blayout") == "ColMajor": + _view_strides = [_rows * _cols, _rows * _cols, _rows * _cols, 1, _rows] + _tile_kwargs = dict(_spec["tile_kwargs"]) + _kernel_name = _spec["kernel_name"] + _case_name = _spec["case_name"] + + def _make(rows=_rows, cols=_cols, view_strides=_view_strides, tile_kwargs=_tile_kwargs, kernel_name=_kernel_name): + @_flat_jit( + name=kernel_name, + kernel_kind="vector", + target="a5", + ) + def _kernel( + src_ptr: pto.ptr(pto.f32, "gm"), + dst_ptr: pto.ptr(pto.f32, "gm"), + ): + _roundtrip_body( + src_ptr, + dst_ptr, + rows=rows, + cols=cols, + view_strides=view_strides, + tile_kwargs=tile_kwargs, + ) + + return _kernel + + _tload_store_kernels[_case_name] = _make() + + +def _make_input(name, shape): + import zlib + + np.random.seed(zlib.crc32(name.encode("utf-8")) & 0xFFFFFFFF) + return np.random.randint(1, 32, size=shape).astype(np.float32) + + +def _make_expected(src): + return np.asarray(src, dtype=np.float32).copy() + + +CASES = [] +for _spec in CASE_SPECS: + _case_name = _spec["case_name"] + _shape = _spec["shape"] + CASES.append( + golden_output_case( + "tload_store_" + _case_name, + _tload_store_kernels[_case_name], + inputs=lambda _case_name=_case_name, _shape=_shape: [_make_input(_case_name, _shape)], + expected=_make_expected, + rtol=1e-6, + atol=1e-6, + ) + ) + + +EMIT_MLIR_FN = lambda: _merge_flat_modules(*[kernel.compile() for kernel in _tload_store_kernels.values()]) + + +auto_main(globals()) diff --git a/test/dsl-st/npu_a5/tmatmul.py b/test/dsl-st/npu_a5/tmatmul.py new file mode 100644 index 000000000..df2c5fc06 --- /dev/null +++ b/test/dsl-st/npu_a5/tmatmul.py @@ -0,0 +1,199 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# Minimal PTODSL cube/tmatmul pilot for A5. +# Goal: validate plain cube tile.matmul lowering/runtime first, without mixing +# MX-specific scale/bias handling or @pto.cube helper boundaries. + +from pathlib import Path +import sys + +import numpy as np + +if __package__ in {None, ""}: + sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) + +from common import auto_main, golden_output_case +from ptodsl import pto + + +M = 16 +K = 32 +N = 64 +ELEM_BYTES = 4 + +L1_A_ADDR = 0 +L1_B_ADDR = 4096 +L0A_ADDR = 0 +L0B_ADDR = 0 +L0C_ADDR = 0 + + +@pto.cube +def cube_matmul_tile( + a_mat: pto.Tile, + b_mat: pto.Tile, + o_tile: pto.Tile, + a_l0a: pto.Tile, + b_l0b: pto.Tile, + c_acc: pto.Tile, +): + m = a_mat.valid_shape[0] + k = a_mat.valid_shape[1] + n = b_mat.valid_shape[1] + + pto.mte_l1_l0a(a_mat.as_ptr(), a_l0a.as_ptr(), m, k) + pto.mte_l1_l0b(b_mat.as_ptr(), b_l0b.as_ptr(), k, n, transpose=True) + pto.set_flag(pto.Pipe.MTE1, pto.Pipe.M, event_id=1) + pto.wait_flag(pto.Pipe.MTE1, pto.Pipe.M, event_id=1) + pto.tile.matmul(a_l0a, b_l0b, c_acc) + pto.set_flag(pto.Pipe.M, pto.Pipe.FIX, event_id=2) + pto.wait_flag(pto.Pipe.M, pto.Pipe.FIX, event_id=2) + pto.mte_l0c_ub( + c_acc.as_ptr(), + o_tile.as_ptr(), + m, + n, + n, + n, + ) + + +@pto.jit( + name="tmatmul_f32_16x32x64", + kernel_kind="cube", + target="a5", + mode="explicit", + insert_sync=False, +) +def _tmatmul_kernel( + a_ptr: pto.ptr(pto.f32, "gm"), + b_ptr: pto.ptr(pto.f32, "gm"), + c_ptr: pto.ptr(pto.f32, "gm"), +): + a_mat = pto.alloc_tile( + shape=[M, K], + dtype=pto.f32, + memory_space=pto.MemorySpace.MAT, + addr=L1_A_ADDR, + valid_shape=[M, K], + blayout="ColMajor", + slayout="RowMajor", + ) + b_mat = pto.alloc_tile( + shape=[K, N], + dtype=pto.f32, + memory_space=pto.MemorySpace.MAT, + addr=L1_B_ADDR, + valid_shape=[K, N], + blayout="ColMajor", + slayout="RowMajor", + ) + a_l0a = pto.alloc_tile( + shape=[M, K], + dtype=pto.f32, + memory_space=pto.MemorySpace.LEFT, + addr=L0A_ADDR, + valid_shape=[M, K], + blayout="ColMajor", + slayout="RowMajor", + ) + b_l0b = pto.alloc_tile( + shape=[K, N], + dtype=pto.f32, + memory_space=pto.MemorySpace.RIGHT, + addr=L0B_ADDR, + valid_shape=[K, N], + blayout="RowMajor", + slayout="ColMajor", + ) + c_acc = pto.alloc_tile( + shape=[M, N], + dtype=pto.f32, + memory_space=pto.MemorySpace.ACC, + addr=L0C_ADDR, + valid_shape=[M, N], + blayout="ColMajor", + slayout="RowMajor", + fractal_size=1024, + ) + + a_l1_ptr = pto.castptr(pto.ui64(L1_A_ADDR), pto.ptr(pto.f32, "mat")) + b_l1_ptr = pto.castptr(pto.ui64(L1_B_ADDR), pto.ptr(pto.f32, "mat")) + + pto.mte_gm_l1_frac( + a_ptr, + a_l1_ptr, + pto.FractalMode.ND2NZ, + shape=(M, K), + src_layout=(K * ELEM_BYTES,), + dst_group=(1, 1, M, 0), + ctrl=(0, False), + ) + pto.set_flag(pto.Pipe.MTE2, pto.Pipe.MTE1, event_id=0) + pto.wait_flag(pto.Pipe.MTE2, pto.Pipe.MTE1, event_id=0) + pto.mte_l1_l0a(a_l1_ptr, a_l0a.as_ptr(), M, K) + + pto.mte_gm_l1_frac( + b_ptr, + b_l1_ptr, + pto.FractalMode.ND2NZ, + shape=(K, N), + src_layout=(N * ELEM_BYTES,), + dst_group=(1, 1, K, 0), + ctrl=(0, False), + ) + pto.set_flag(pto.Pipe.MTE2, pto.Pipe.MTE1, event_id=1) + pto.wait_flag(pto.Pipe.MTE2, pto.Pipe.MTE1, event_id=1) + pto.mte_l1_l0b(b_l1_ptr, b_l0b.as_ptr(), K, N, transpose=True) + + pto.set_flag(pto.Pipe.MTE1, pto.Pipe.M, event_id=0) + pto.wait_flag(pto.Pipe.MTE1, pto.Pipe.M, event_id=0) + pto.tile.matmul(a_l0a, b_l0b, c_acc) + + pto.set_flag(pto.Pipe.M, pto.Pipe.FIX, event_id=1) + pto.wait_flag(pto.Pipe.M, pto.Pipe.FIX, event_id=1) + pto.mte_l0c_gm( + c_acc.as_ptr(), + c_ptr, + M, + N, + M, + N, + 0, + 0, + layout="nz2nd", + ) + pto.pipe_barrier(pto.Pipe.ALL) + + +def _make_inputs(): + rng = np.random.default_rng(0x7A7A7A71) + a = rng.uniform(-2.0, 2.0, size=(M, K)).astype(np.float32) + b = rng.uniform(-2.0, 2.0, size=(K, N)).astype(np.float32) + return [a, b] + + +def _make_expected(a, b): + return (a @ b).astype(np.float32) + + +CASES = [ + golden_output_case( + "tmatmul_f32_16x32x64", + _tmatmul_kernel, + inputs=_make_inputs, + expected=_make_expected, + rtol=1e-4, + atol=1e-4, + ), +] + + +auto_main(globals()) diff --git a/test/dsl-st/vmulscvt.py b/test/dsl-st/vmulscvt.py index 4eb6a6644..eaa236870 100644 --- a/test/dsl-st/vmulscvt.py +++ b/test/dsl-st/vmulscvt.py @@ -15,11 +15,12 @@ - `pto.vmulscvt(..., part=EVEN)` - `pto.vbitcast(..., pto.ui32)` - `pto.vpack(..., LOWER)` -- UB materialization via `pto.vsts` +- UB materialization via `pto.vsts` with a `PAT_VL64` mask -The observable is the packed `u16` register image after the `vmulscvt + vpack` +The observable is the lower 64-lane payload produced by the `vmulscvt + vpack` sequence. That keeps the test close to the C++ authoring style without relying -on `vsstb.post`, which is not available on the current PTODSL surface yet. +on `vsstb.post`, which is not available on the current PTODSL surface yet, and +without asserting anything about lanes outside the authored packed payload. """ from pathlib import Path @@ -35,7 +36,7 @@ SRC_COLS = 64 -OUT_COLS = 128 +OUT_COLS = 64 SCALE = -0.5 @@ -87,7 +88,7 @@ def vmulscvt_pack_kernel( with pto.simd(): mask32 = pto.pset_b32(pto.MaskPattern.ALL) - mask16 = pto.pset_b16(pto.MaskPattern.ALL) + mask16 = pto.pset_b16(pto.MaskPattern.VL64) src = pto.vlds(src_tile[0, 0:]) packed_f16 = pto.vmulscvt( @@ -115,9 +116,7 @@ def make_inputs(): def make_expected(inp): scaled = (inp.astype(np.float32) * np.float32(SCALE)).astype(np.float16).reshape(-1) - packed = np.zeros((OUT_COLS,), dtype=np.uint16) - packed[:SRC_COLS] = scaled.view(np.uint16) - return packed.reshape(1, OUT_COLS) + return scaled.view(np.uint16).reshape(1, OUT_COLS) CASES = [ diff --git a/tools/ptoas/driver.cpp b/tools/ptoas/driver.cpp index b8de95aa2..fcb919738 100644 --- a/tools/ptoas/driver.cpp +++ b/tools/ptoas/driver.cpp @@ -937,12 +937,28 @@ LogicalResult EmitCBackendJob::run(PTOASContext &context) { } LogicalResult VPTOBackendJob::run(PTOASContext &context) { + OwningOpRef singleChildJobModule; + OwningOpRef *compileUnit = &module; ModuleOp op = module.get(); op->setAttr("pto.backend", StringAttr::get(op.getContext(), "vpto")); + SmallVector children(op.getOps()); + if (children.size() == 1 && isBackendPartitionedContainer(op)) { + FailureOr> jobModuleOr = + buildBackendChildCompileUnit(op, children.front()); + if (failed(jobModuleOr)) + return failure(); + singleChildJobModule = std::move(*jobModuleOr); + singleChildJobModule.get()->setAttr( + "pto.backend", + StringAttr::get(singleChildJobModule.get()->getContext(), "vpto")); + compileUnit = &singleChildJobModule; + op = singleChildJobModule.get(); + } + bool emitHostStub = hasPTOEntry(op); if (mlir::pto::compilePTOASModule( - module, context, mlir::pto::PTOBackend::VPTO, result, + *compileUnit, context, mlir::pto::PTOBackend::VPTO, result, emitHostStub) != 0) return failure(); if (result.kind == mlir::pto::PTOASCompileResultKind::Text) diff --git a/tools/ptoas/ptoas.cpp b/tools/ptoas/ptoas.cpp index d5ba2b131..a310fe8b3 100644 --- a/tools/ptoas/ptoas.cpp +++ b/tools/ptoas/ptoas.cpp @@ -1769,6 +1769,16 @@ static void lowerPTOToVPTOBackend(PassManager &pm, ModuleOp module, int argc, kernelModulePM.addPass(pto::createExpandTileOpPass(expandOpts)); kernelModulePM.addPass(pto::createPTOInlineLibCallPass()); + pto::PTOViewToMemrefOptions viewOnlyRerunOpts; + viewOnlyRerunOpts.viewOnly = true; + // ExpandTileOp materializes fresh TileLang helper IR after the shared + // mainline has already lowered the original module's tensor_view surface. + // Re-run the shared view lowering here so helper-local + // pto.make_tensor_view/pto.partition_view chains do not leak into + // FoldTileBufIntrinsics. + kernelModulePM.addPass(pto::createPTOViewToMemrefPass(viewOnlyRerunOpts)); + kernelModulePM.addPass(mlir::createCanonicalizerPass()); + kernelModulePM.addPass(mlir::createCSEPass()); kernelModulePM.addNestedPass( pto::createFoldTileBufIntrinsicsPass("shape-only")); if (enableA5VPTOPostLoweringFusionLifecycle) {