From 44a25ba9a8f8257ad5fe0088f66f6ab55d5cd668 Mon Sep 17 00:00:00 2001 From: jimmychou <47636600+jimmychou0@users.noreply.github.com> Date: Tue, 30 Jun 2026 20:19:22 +0800 Subject: [PATCH 01/10] Add PTODSL A5 DSL ST coverage --- include/PTO/Transforms/Passes.td | 24 +- lib/PTO/Transforms/ExpandTileOp.cpp | 136 ++++++- lib/PTO/Transforms/FoldTileBufIntrinsics.cpp | 382 ++++++++++++++---- .../PTOInstantiateAndInlineOpLib.cpp | 18 +- ptodsl/docs/user_guide/01-introduction.md | 6 +- .../03-kernel-entry-and-subkernels.md | 6 +- ptodsl/ptodsl/_diagnostics.py | 10 + ptodsl/ptodsl/_jit.py | 28 +- ptodsl/ptodsl/_runtime/cache.py | 4 + ptodsl/ptodsl/_runtime/native_build.py | 51 ++- ptodsl/ptodsl/_tracing/module_builder.py | 1 + ptodsl/ptodsl/_tracing/session.py | 47 ++- ptodsl/tests/test_jit_compile.py | 98 ++++- test/dsl-st/cube_matrix_pipeline.py | 113 +++--- test/dsl-st/gemv_mx_pipeline.py | 16 - test/dsl-st/npu_a5/__main__.py | 32 ++ test/dsl-st/npu_a5/tadd.py | 197 +++++++++ test/dsl-st/npu_a5/tcolexpand.py | 151 +++++++ test/dsl-st/npu_a5/tcolsum.py | 150 +++++++ test/dsl-st/npu_a5/tload_store.py | 209 ++++++++++ test/dsl-st/npu_a5/tmatmul.py | 199 +++++++++ test/dsl-st/vmulscvt.py | 15 +- 22 files changed, 1693 insertions(+), 200 deletions(-) create mode 100644 test/dsl-st/npu_a5/__main__.py create mode 100644 test/dsl-st/npu_a5/tadd.py create mode 100644 test/dsl-st/npu_a5/tcolexpand.py create mode 100644 test/dsl-st/npu_a5/tcolsum.py create mode 100644 test/dsl-st/npu_a5/tload_store.py create mode 100644 test/dsl-st/npu_a5/tmatmul.py diff --git a/include/PTO/Transforms/Passes.td b/include/PTO/Transforms/Passes.td index 63b06b6dbf..7c43e4af2a 100644 --- a/include/PTO/Transforms/Passes.td +++ b/include/PTO/Transforms/Passes.td @@ -507,17 +507,19 @@ def FoldTileBufIntrinsics : Pass<"pto-fold-tile-buf-intrinsics", "mlir::func::Fu - pto.tile_valid_cols → same as above for v_col tensor_view family: - - pto.tensor_view_addr → traces through unrealized_conversion_cast → - 
subview → reinterpret_cast, then folds to the base memref or to - pto.castptr/pto.addptr on the base memref - - pto.get_tensor_view_dim → folded to arith.constant for static subview - sizes, or to the subview size SSA operand for dynamic dims - - pto.get_tensor_view_stride → folded to the reinterpret_cast stride - operand, multiplied by the subview stride when needed - - Dead unrealized_conversion_cast, memref.subview, and - memref.reinterpret_cast ops exposed by folding are cleaned up after the - rewrite. + - pto.tensor_view_addr → traces through either + unrealized_conversion_cast → subview → reinterpret_cast or native + pto.partition_view → pto.make_tensor_view, then folds to the base memref + or to pto.castptr/pto.addptr on the base pointer + - pto.get_tensor_view_dim → folded to arith.constant for static view sizes, + or to the source size SSA operand for dynamic dims + - pto.get_tensor_view_stride → folded to the lowered reinterpret_cast + stride, multiplied by the subview stride when needed, or to the native + make_tensor_view stride operand + + Dead unrealized_conversion_cast, memref.subview, memref.reinterpret_cast, + pto.partition_view, and pto.make_tensor_view ops exposed by folding are + cleaned up after the rewrite. }]; let constructor = "mlir::pto::createFoldTileBufIntrinsicsPass()"; let options = [ diff --git a/lib/PTO/Transforms/ExpandTileOp.cpp b/lib/PTO/Transforms/ExpandTileOp.cpp index 087d8d62c6..ab903af043 100644 --- a/lib/PTO/Transforms/ExpandTileOp.cpp +++ b/lib/PTO/Transforms/ExpandTileOp.cpp @@ -82,10 +82,10 @@ namespace { // Four kinds of operands: // Tile — from TileBufType. dtype + shape + memorySpace + config // all participate in the specialization key (SpecKey). -// View — from MemRefType (lowered PartitionTensorViewType). The element -// dtype and optional explicit layout participate in SpecKey; -// shape/strides/memorySpace remain JSON-only metadata for Python -// constraint checking and must not perturb C++ codegen caching. 
+// View — from TensorViewType / PartitionTensorViewType or MemRefType. +// dtype, shape, strides, memorySpace, and optional explicit layout +// participate in SpecKey because they affect template selection and +// generated DMA parameters for tload/tstore. // Vector — from builtin VectorType. The element dtype and vector shape // participate in SpecKey so helper-side schema filtering can // distinguish auxiliary vector operands such as tmrgsort's @@ -107,7 +107,7 @@ struct OperandTypeInfo { int32_t fractal = 0; uint64_t pad = 0; - // --- View-only (MemRefType) — for JSON / constraint checking only --- + // --- View-only --- SmallVector viewShape; SmallVector viewStrides; std::string viewMemorySpace; // "gm" or "ub" @@ -133,8 +133,8 @@ struct OperandTypeInfo { return vectorShape == rhs.vectorShape; if (kind == OperandKind::Scalar) return scalarValue == rhs.scalarValue; - // View: dtype + explicit layout are sufficient for template caching. - return viewLayout == rhs.viewLayout; + return viewShape == rhs.viewShape && viewStrides == rhs.viewStrides && + viewMemorySpace == rhs.viewMemorySpace && viewLayout == rhs.viewLayout; } }; @@ -178,7 +178,11 @@ struct SpecKeyInfo : public llvm::DenseMapInfo { h = llvm::hash_combine(h, *op.scalarValue); } if (op.kind == OperandKind::View) { - h = llvm::hash_combine(h, op.viewLayout.has_value()); + h = llvm::hash_combine(h, op.viewMemorySpace, op.viewLayout.has_value()); + for (int64_t d : op.viewShape) + h = llvm::hash_combine(h, d); + for (int64_t d : op.viewStrides) + h = llvm::hash_combine(h, d); if (op.viewLayout) h = llvm::hash_combine(h, static_cast(*op.viewLayout)); } @@ -543,6 +547,18 @@ static void recordStaticSizes(ArrayRef inputs, out.push_back(getStaticIntOrDynamic(ofr)); } +static void recordStaticValues(ValueRange inputs, SmallVectorImpl &out) { + out.clear(); + out.reserve(inputs.size()); + for (Value value : inputs) { + int64_t dim = ShapedType::kDynamic; + if (getStaticIntFromValue(value, dim)) + 
out.push_back(dim); + else + out.push_back(ShapedType::kDynamic); + } +} + static SmallVector combineSubviewStrides(ArrayRef baseStrides, ArrayRef steps) { SmallVector result; @@ -606,6 +622,40 @@ static void populateViewShapeAndStrides(Value value, } } +static void populateTensorViewShapeAndStrides(Value value, + SmallVectorImpl &shape, + SmallVectorImpl &strides) { + if (!value) + return; + + Operation *def = value.getDefiningOp(); + if (auto partition = dyn_cast_or_null(def)) { + recordStaticValues(partition.getSizes(), shape); + SmallVector sourceShape; + populateTensorViewShapeAndStrides(partition.getSource(), sourceShape, + strides); + return; + } + + if (auto makeView = dyn_cast_or_null(def)) { + recordStaticValues(makeView.getShape(), shape); + recordStaticValues(makeView.getStrides(), strides); + return; + } + + Type ty = value.getType(); + if (auto partTy = dyn_cast(ty)) { + if (shape.empty()) + shape.assign(partTy.getShape().begin(), partTy.getShape().end()); + return; + } + if (auto tvTy = dyn_cast(ty)) { + if (shape.empty()) + shape.assign(tvTy.getShape().begin(), tvTy.getShape().end()); + return; + } +} + static std::optional buildOperandTypeInfo(Value value) { Type ty = value.getType(); // Tile operand — from TileBufType. @@ -633,6 +683,36 @@ static std::optional buildOperandTypeInfo(Value value) { return info; } + // View operand — from native TensorViewType / PartitionTensorViewType before + // PTOViewToMemref has rewritten the view to memref. 
+ if (auto tvTy = dyn_cast(ty)) { + OperandTypeInfo info; + info.kind = OperandKind::View; + info.dtype = getDtypeString(tvTy.getElementType()); + if (info.dtype.empty()) + return std::nullopt; + info.viewMemorySpace = "gm"; + info.viewLayout = resolveViewLayout(value); + populateTensorViewShapeAndStrides(value, info.viewShape, info.viewStrides); + if (info.viewShape.empty()) + info.viewShape.assign(tvTy.getShape().begin(), tvTy.getShape().end()); + return info; + } + + if (auto partTy = dyn_cast(ty)) { + OperandTypeInfo info; + info.kind = OperandKind::View; + info.dtype = getDtypeString(partTy.getElementType()); + if (info.dtype.empty()) + return std::nullopt; + info.viewMemorySpace = "gm"; + info.viewLayout = resolveViewLayout(value); + populateTensorViewShapeAndStrides(value, info.viewShape, info.viewStrides); + if (info.viewShape.empty()) + info.viewShape.assign(partTy.getShape().begin(), partTy.getShape().end()); + return info; + } + // View operand — from MemRefType (lowered PartitionTensorViewType). 
if (auto mrTy = dyn_cast(ty)) { OperandTypeInfo info; @@ -843,6 +923,11 @@ static std::string buildUniqueFunctionBaseName(const SpecKey &key) { uniqueName += "_fr" + std::to_string(op.fractal); uniqueName += "_pd" + llvm::utohexstr(op.pad, /*LowerCase=*/false); } else if (op.kind == OperandKind::View) { + for (int64_t d : op.viewShape) + uniqueName += "_s" + std::to_string(d); + for (int64_t d : op.viewStrides) + uniqueName += "_st" + std::to_string(d); + uniqueName += "_ms_" + op.viewMemorySpace; if (op.viewLayout) uniqueName += "_vl_" + stringifyLayout(*op.viewLayout).str(); } else if (op.kind == OperandKind::Vector) { @@ -873,6 +958,39 @@ static std::string buildContextAttrsJson(const SpecKey &key) { return json; } +static bool isViewLikeType(Type type) { + return isa(type); +} + +static void specializeTemplateEntryArgumentTypes(func::FuncOp fn, + Operation *tileOp) { + if (!fn || fn.isExternal()) + return; + + FunctionType fnTy = fn.getFunctionType(); + SmallVector inputs(fnTy.getInputs().begin(), fnTy.getInputs().end()); + bool changed = false; + unsigned operandCount = std::min(tileOp->getNumOperands(), + inputs.size()); + for (unsigned i = 0; i < operandCount; ++i) { + Type callerTy = tileOp->getOperand(i).getType(); + Type calleeTy = inputs[i]; + if (callerTy == calleeTy) + continue; + if (!isViewLikeType(callerTy) || !isViewLikeType(calleeTy)) + continue; + inputs[i] = callerTy; + fn.getArgument(i).setType(callerTy); + changed = true; + } + + if (!changed) + return; + + fn.setFunctionType(FunctionType::get(fn.getContext(), inputs, + fnTy.getResults())); +} + // ============================================================================ // Invoke Python DSL daemon RPC to generate a specialized template function. 
// ============================================================================ @@ -1030,6 +1148,7 @@ func::FuncOp ExpandState::invokeTilelangDaemon(const SpecKey &key, } auto cloned = clonedFuncs.front(); + specializeTemplateEntryArgumentTypes(cloned, tileOp); if (!cloned->hasAttr("pto.tilelang.instance")) { llvm::errs() << "ExpandTileOp: warning: daemon output function @" << cloned.getSymName() @@ -1229,6 +1348,7 @@ func::FuncOp ExpandState::invokeTilelangDSL(const SpecKey &key, } auto cloned = clonedFuncs.front(); + specializeTemplateEntryArgumentTypes(cloned, tileOp); // The pto.tilelang.instance attribute should already be set by the // TileLang DSL frontend in the generated MLIR. Verify it exists. if (!cloned->hasAttr("pto.tilelang.instance")) { diff --git a/lib/PTO/Transforms/FoldTileBufIntrinsics.cpp b/lib/PTO/Transforms/FoldTileBufIntrinsics.cpp index 37cdc785d3..1badfd006e 100644 --- a/lib/PTO/Transforms/FoldTileBufIntrinsics.cpp +++ b/lib/PTO/Transforms/FoldTileBufIntrinsics.cpp @@ -25,11 +25,11 @@ // For tile_buf intrinsics, the active VPTO path folds against materialized tile // handles produced by the shared tile-handle bridge (`pto.alloc_tile` or // `pto.materialize_tile`). -// For tensor_view intrinsics, the pass traces through the full -// unrealized_conversion_cast → memref.subview → memref.reinterpret_cast -// chain to fold directly to constants or SSA operands from the -// reinterpret_cast, without generating intermediate memref.dim / -// memref.extract_strided_metadata ops. +// For tensor_view intrinsics, the pass traces either through the lowered +// unrealized_conversion_cast → memref.subview → memref.reinterpret_cast chain +// or through the native pto.partition_view → pto.make_tensor_view chain to fold +// directly to constants or SSA operands, without generating intermediate +// memref.dim / memref.extract_strided_metadata ops. 
// //===----------------------------------------------------------------------===// @@ -42,6 +42,7 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/IR/SymbolTable.h" #include "mlir/Pass/Pass.h" using namespace mlir; @@ -90,6 +91,19 @@ static void eraseDeadAllocTileOps(func::FuncOp func) { alloc.erase(); } +static bool isDeadPTODSLSubkernelHelper(func::FuncOp func) { + if (!func->hasAttr("pto.ptodsl.subkernel_helper")) + return false; + + auto module = func->getParentOfType(); + if (!module) + return false; + + SymbolTable symbolTable(module); + auto uses = symbolTable.getSymbolUses(func, module); + return uses && uses->empty(); +} + struct TileHandleInfo { Value sourceMemref; Value addr; @@ -209,37 +223,117 @@ static MemRefType getCanonicalMemRefTypeForTileBuf(pto::TileBufType tileTy) { AffineMap(), tileTy.getMemorySpace()); } +enum class ViewChainKind { + MemRef, + Native, +}; + struct ViewChain { + ViewChainKind kind = ViewChainKind::MemRef; + + // Lowered memref view chain. UnrealizedConversionCastOp cast; memref::SubViewOp subview; memref::ReinterpretCastOp reinterpretCast; Value baseMemref; + + // Native tensor_view / partition_tensor_view chain. 
+ pto::MakeTensorViewOp makeView; + pto::PartitionViewOp partitionView; }; +static bool validateNativeViewChain(pto::MakeTensorViewOp makeView, + pto::PartitionViewOp partitionView, + Operation *user) { + if (!makeView) { + user->emitError("FoldTileBufIntrinsics: native tensor_view must be " + "defined by pto.make_tensor_view"); + return false; + } + + size_t rank = makeView.getShape().size(); + if (makeView.getStrides().size() != rank) { + user->emitError("FoldTileBufIntrinsics: pto.make_tensor_view shape/stride " + "rank mismatch"); + return false; + } + + if (auto tvTy = dyn_cast(makeView.getResult().getType())) { + if (static_cast(tvTy.getRank()) != rank) { + user->emitError("FoldTileBufIntrinsics: pto.make_tensor_view result rank " + "does not match shape operands"); + return false; + } + } + + if (!partitionView) + return true; + + if (partitionView.getOffsets().size() != rank || + partitionView.getSizes().size() != rank) { + user->emitError("FoldTileBufIntrinsics: pto.partition_view rank must match " + "its source tensor_view rank"); + return false; + } + + if (auto partTy = dyn_cast( + partitionView.getResult().getType())) { + if (static_cast(partTy.getRank()) != rank) { + user->emitError("FoldTileBufIntrinsics: pto.partition_view result rank " + "does not match its operands"); + return false; + } + } + + return true; +} + static std::optional traceViewChain(Value tensorView, Operation *user) { - Value memrefVal; + Value view = tensorView; UnrealizedConversionCastOp castOp; - if (isa(tensorView.getType())) { - memrefVal = tensorView; - } else { - castOp = tensorView.getDefiningOp(); - if (!castOp || castOp.getNumOperands() != 1) { - user->emitError( - "FoldTileBufIntrinsics: expected tensor_view to be defined by a " - "single-operand builtin.unrealized_conversion_cast"); - return std::nullopt; + if (auto cast = view.getDefiningOp()) { + if (cast.getNumOperands() == 1 && cast.getNumResults() == 1) { + castOp = cast; + view = cast.getOperand(0); } - memrefVal 
= castOp.getOperand(0); - if (!isa(memrefVal.getType())) { - user->emitError( - "FoldTileBufIntrinsics: expected cast operand to be a memref, got ") - << memrefVal.getType(); - return std::nullopt; + } + + if (!isa(view.getType())) { + if (auto partition = view.getDefiningOp()) { + auto makeView = + partition.getSource().getDefiningOp(); + if (!validateNativeViewChain(makeView, partition, user)) + return std::nullopt; + ViewChain chain; + chain.kind = ViewChainKind::Native; + chain.cast = castOp; + chain.makeView = makeView; + chain.partitionView = partition; + return chain; + } + + if (auto makeView = view.getDefiningOp()) { + if (!validateNativeViewChain(makeView, pto::PartitionViewOp(), user)) + return std::nullopt; + ViewChain chain; + chain.kind = ViewChainKind::Native; + chain.cast = castOp; + chain.makeView = makeView; + return chain; } + + user->emitError("FoldTileBufIntrinsics: expected tensor_view to be defined " + "by a lowered memref.subview chain or native " + "pto.partition_view/pto.make_tensor_view chain, got ") + << (view.getDefiningOp() + ? 
view.getDefiningOp()->getName().getStringRef() + : StringRef("block argument")); + return std::nullopt; } + Value memrefVal = view; auto subviewOp = memrefVal.getDefiningOp(); if (!subviewOp) { user->emitError("FoldTileBufIntrinsics: expected memref to be defined by " @@ -261,7 +355,13 @@ static std::optional traceViewChain(Value tensorView, return std::nullopt; } - return ViewChain{castOp, subviewOp, rcOp, rcOp.getSource()}; + ViewChain chain; + chain.kind = ViewChainKind::MemRef; + chain.cast = castOp; + chain.subview = subviewOp; + chain.reinterpretCast = rcOp; + chain.baseMemref = rcOp.getSource(); + return chain; } static bool getConstIndexValue(Value v, int64_t &out) { @@ -299,13 +399,32 @@ static Value getValueOrCreateConstant(OpBuilder &builder, Location loc, return builder.create(loc, intAttr.getInt()); } +static bool getConstIndexValue(OpFoldResult ofr, int64_t &out) { + if (auto value = dyn_cast(ofr)) + return getConstIndexValue(value, out); + auto intAttr = dyn_cast(cast(ofr)); + if (!intAttr) + return false; + out = intAttr.getInt(); + return true; +} + +static bool isStaticIndexValue(OpFoldResult ofr, int64_t expected) { + int64_t value = 0; + return getConstIndexValue(ofr, value) && value == expected; +} + +static SmallVector valuesToFoldResults(ValueRange values) { + SmallVector result; + result.reserve(values.size()); + for (Value value : values) + result.push_back(value); + return result; +} + static bool isAllStaticZero(ArrayRef ofrs) { for (OpFoldResult ofr : ofrs) { - auto attr = dyn_cast(ofr); - if (!attr) - return false; - auto intAttr = dyn_cast(attr); - if (!intAttr || intAttr.getInt() != 0) + if (!isStaticIndexValue(ofr, 0)) return false; } return true; @@ -314,11 +433,8 @@ static bool isAllStaticZero(ArrayRef ofrs) { static Value computeResultStride(OpBuilder &builder, Location loc, OpFoldResult rcStride, OpFoldResult svStride) { - if (auto attr = dyn_cast(svStride)) { - auto intAttr = dyn_cast(attr); - if (intAttr && intAttr.getInt() == 
1) - return getValueOrCreateConstant(builder, loc, rcStride); - } + if (isStaticIndexValue(svStride, 1)) + return getValueOrCreateConstant(builder, loc, rcStride); Value lhs = getValueOrCreateConstant(builder, loc, rcStride); Value rhs = getValueOrCreateConstant(builder, loc, svStride); @@ -338,11 +454,8 @@ static Value computeLinearOffset(OpBuilder &builder, Location loc, Value svPart; if (!svAllZero) { for (auto [svOffset, rcStride] : llvm::zip(svOffsets, rcStrides)) { - if (auto attr = dyn_cast(svOffset)) { - auto intAttr = dyn_cast(attr); - if (intAttr && intAttr.getInt() == 0) - continue; - } + if (isStaticIndexValue(svOffset, 0)) + continue; Value off = getValueOrCreateConstant(builder, loc, svOffset); Value stride = getValueOrCreateConstant(builder, loc, rcStride); @@ -363,6 +476,106 @@ static Value computeLinearOffset(OpBuilder &builder, Location loc, return rcPart ? rcPart : svPart; } +static int64_t getStaticNativeViewDim(ViewChain &chain, int64_t dimIdx) { + if (chain.partitionView) { + auto partTy = dyn_cast( + chain.partitionView.getResult().getType()); + if (partTy && partTy.getDimSize(dimIdx) != ShapedType::kDynamic) + return partTy.getDimSize(dimIdx); + return ShapedType::kDynamic; + } + + auto tvTy = + dyn_cast(chain.makeView.getResult().getType()); + if (tvTy && tvTy.getDimSize(dimIdx) != ShapedType::kDynamic) + return tvTy.getDimSize(dimIdx); + return ShapedType::kDynamic; +} + +static unsigned getViewRank(ViewChain &chain) { + if (chain.kind == ViewChainKind::MemRef) + return cast(chain.subview.getType()).getRank(); + return chain.makeView.getShape().size(); +} + +static std::optional buildTensorViewDimValue(OpBuilder &builder, + Location loc, + ViewChain &chain, + int64_t dimIdx, + Operation *user) { + if (chain.kind == ViewChainKind::MemRef) { + auto svTy = cast(chain.subview.getType()); + if (!svTy.isDynamicDim(dimIdx)) + return builder.create(loc, + svTy.getDimSize(dimIdx)); + return getValueOrCreateConstant(builder, loc, + 
chain.subview.getMixedSizes()[dimIdx]); + } + + int64_t staticDim = getStaticNativeViewDim(chain, dimIdx); + if (staticDim != ShapedType::kDynamic) + return builder.create(loc, staticDim); + + ValueRange sizes = chain.partitionView ? chain.partitionView.getSizes() + : chain.makeView.getShape(); + if (dimIdx < 0 || static_cast(dimIdx) >= sizes.size()) { + user->emitError("FoldTileBufIntrinsics: native tensor_view dim index out " + "of bounds"); + return std::nullopt; + } + return sizes[dimIdx]; +} + +static std::optional buildTensorViewStrideValue(OpBuilder &builder, + Location loc, + ViewChain &chain, + int64_t dimIdx, + Operation *user) { + if (chain.kind == ViewChainKind::MemRef) + return computeResultStride( + builder, loc, chain.reinterpretCast.getMixedStrides()[dimIdx], + chain.subview.getMixedStrides()[dimIdx]); + + ValueRange strides = chain.makeView.getStrides(); + if (dimIdx < 0 || static_cast(dimIdx) >= strides.size()) { + user->emitError("FoldTileBufIntrinsics: native tensor_view stride index " + "out of bounds"); + return std::nullopt; + } + return strides[dimIdx]; +} + +static Value computeNativeLinearOffset(OpBuilder &builder, Location loc, + ViewChain &chain) { + if (!chain.partitionView) + return Value(); + + SmallVector offsets = + valuesToFoldResults(chain.partitionView.getOffsets()); + SmallVector strides = + valuesToFoldResults(chain.makeView.getStrides()); + return computeLinearOffset(builder, loc, /*rcOffsets=*/{}, offsets, strides); +} + +static std::optional buildNativeTensorViewBasePtr(OpBuilder &builder, + Location loc, + ViewChain &chain, + pto::PtrType resultTy, + Operation *user) { + Value base = chain.makeView.getPtr(); + if (base.getType() == resultTy) + return base; + + if (!isa(base.getType())) { + user->emitError("FoldTileBufIntrinsics: native tensor_view_addr base must " + "be !pto.ptr, memref, or integer, got ") + << base.getType(); + return std::nullopt; + } + + return builder.create(loc, resultTy, base).getResult(); +} + struct 
FoldTileBufIntrinsicsPass : public pto::impl::FoldTileBufIntrinsicsBase { using FoldTileBufIntrinsicsBase::FoldTileBufIntrinsicsBase; @@ -380,12 +593,13 @@ struct FoldTileBufIntrinsicsPass return signalPassFailure(); } - // Leftover TileLang template instances (private, uncalled after - // PTOInlineLibCall) still contain pto.tile_buf_addr / tile_valid_* - // ops on tile_buf function arguments — they have no materialized tile - // handle anchor to fold against and will be removed by later DCE. Skip - // them. - if (func->hasAttr("pto.tilelang.instance")) + // Leftover TileLang template instances and already-inlined PTODSL + // subkernel helpers may still contain structured-view intrinsics on + // function arguments. Those formal arguments have no materialized + // call-site handle to fold against; the live caller body has already been + // inlined and folded separately. + if (func->hasAttr("pto.tilelang.instance") || + isDeadPTODSLSubkernelHelper(func)) return; SmallVector addrOps; @@ -587,8 +801,8 @@ struct FoldTileBufIntrinsicsPass return signalPassFailure(); } - auto svTy = cast(chain->subview.getType()); - if (dimIdx < 0 || dimIdx >= svTy.getRank()) { + unsigned rank = getViewRank(*chain); + if (dimIdx < 0 || static_cast(dimIdx) >= rank) { dimOp.emitError( "FoldTileBufIntrinsics: get_tensor_view_dim dim index out of " "bounds"); @@ -596,17 +810,13 @@ struct FoldTileBufIntrinsicsPass } builder.setInsertionPoint(dimOp); - Value replacement; - if (!svTy.isDynamicDim(dimIdx)) { - replacement = - builder.create(dimOp.getLoc(), - svTy.getDimSize(dimIdx)); - } else { - replacement = getValueOrCreateConstant( - builder, dimOp.getLoc(), chain->subview.getMixedSizes()[dimIdx]); - } + std::optional replacement = + buildTensorViewDimValue(builder, dimOp.getLoc(), *chain, dimIdx, + dimOp.getOperation()); + if (!replacement) + return signalPassFailure(); - dimOp.getResult().replaceAllUsesWith(replacement); + dimOp.getResult().replaceAllUsesWith(*replacement); dimOp.erase(); } 
@@ -623,8 +833,8 @@ struct FoldTileBufIntrinsicsPass return signalPassFailure(); } - auto svTy = cast(chain->subview.getType()); - if (dimIdx < 0 || dimIdx >= svTy.getRank()) { + unsigned rank = getViewRank(*chain); + if (dimIdx < 0 || static_cast(dimIdx) >= rank) { strideOp.emitError( "FoldTileBufIntrinsics: get_tensor_view_stride dim index out of " "bounds"); @@ -632,12 +842,13 @@ struct FoldTileBufIntrinsicsPass } builder.setInsertionPoint(strideOp); - Value replacement = computeResultStride( - builder, strideOp.getLoc(), - chain->reinterpretCast.getMixedStrides()[dimIdx], - chain->subview.getMixedStrides()[dimIdx]); + std::optional replacement = buildTensorViewStrideValue( + builder, strideOp.getLoc(), *chain, dimIdx, + strideOp.getOperation()); + if (!replacement) + return signalPassFailure(); - strideOp.getResult().replaceAllUsesWith(replacement); + strideOp.getResult().replaceAllUsesWith(*replacement); strideOp.erase(); } } @@ -654,6 +865,12 @@ struct FoldTileBufIntrinsicsPass if (!resultPtrType) { if (auto resultMemrefType = dyn_cast(addrOp.getDst().getType())) { + if (chain->kind == ViewChainKind::Native) { + addrOp.emitError("FoldTileBufIntrinsics: native tensor_view_addr " + "cannot fold to memref without first lowering " + "the view to memref"); + return signalPassFailure(); + } Value base = chain->baseMemref; if (base.getType() != resultMemrefType) addrOp.getDst().setType(cast(base.getType())); @@ -667,14 +884,27 @@ struct FoldTileBufIntrinsicsPass return signalPassFailure(); } - Value linearOffset = - computeLinearOffset(builder, addrOp.getLoc(), - chain->reinterpretCast.getMixedOffsets(), - chain->subview.getMixedOffsets(), - chain->reinterpretCast.getMixedStrides()); + Value linearOffset; + Value basePtr; + if (chain->kind == ViewChainKind::MemRef) { + linearOffset = + computeLinearOffset(builder, addrOp.getLoc(), + chain->reinterpretCast.getMixedOffsets(), + chain->subview.getMixedOffsets(), + chain->reinterpretCast.getMixedStrides()); + basePtr = 
builder.create( + addrOp.getLoc(), resultPtrType, chain->baseMemref); + } else { + std::optional nativeBase = buildNativeTensorViewBasePtr( + builder, addrOp.getLoc(), *chain, resultPtrType, + addrOp.getOperation()); + if (!nativeBase) + return signalPassFailure(); + basePtr = *nativeBase; + linearOffset = computeNativeLinearOffset(builder, addrOp.getLoc(), + *chain); + } - Value basePtr = builder.create( - addrOp.getLoc(), resultPtrType, chain->baseMemref); Value replacement = linearOffset ? builder.create(addrOp.getLoc(), resultPtrType, @@ -714,6 +944,20 @@ struct FoldTileBufIntrinsicsPass op->erase(); } + while (true) { + SmallVector deadViewOps; + func.walk([&](Operation *op) { + if ((isa(op) || + isa(op)) && + op->use_empty()) + deadViewOps.push_back(op); + }); + if (deadViewOps.empty()) + break; + for (auto *op : llvm::reverse(deadViewOps)) + op->erase(); + } + eraseDeadAllocTileOps(func); } }; diff --git a/lib/PTO/Transforms/PTOInstantiateAndInlineOpLib.cpp b/lib/PTO/Transforms/PTOInstantiateAndInlineOpLib.cpp index c278bd196a..6ab5c4f605 100644 --- a/lib/PTO/Transforms/PTOInstantiateAndInlineOpLib.cpp +++ b/lib/PTO/Transforms/PTOInstantiateAndInlineOpLib.cpp @@ -347,6 +347,22 @@ static void eraseDeadMatchingPrivateFuncs(ModuleOp module, } } +static void eraseDeadPTODSLSubkernelHelpers(ModuleOp module) { + for (ModuleOp funcModule : collectFuncModules(module)) { + SymbolTable symbolTable(funcModule); + SmallVector deadFuncs; + for (func::FuncOp func : funcModule.getOps()) { + if (!isPTODSLSubkernelHelperFunc(func)) + continue; + auto uses = symbolTable.getSymbolUses(func, funcModule); + if (uses && uses->empty()) + deadFuncs.push_back(func); + } + for (func::FuncOp func : deadFuncs) + func.erase(); + } +} + struct PTOInlineBackendHelpersPass : public pto::impl::PTOInlineBackendHelpersBase< PTOInlineBackendHelpersPass> { @@ -371,7 +387,7 @@ struct PTOInlineBackendHelpersPass << " call(s)\n"; } - eraseDeadMatchingPrivateFuncs(module, 
isInlineableBackendHelperFunc); + eraseDeadPTODSLSubkernelHelpers(module); } }; diff --git a/ptodsl/docs/user_guide/01-introduction.md b/ptodsl/docs/user_guide/01-introduction.md index 12eba5ce58..7435052cd2 100644 --- a/ptodsl/docs/user_guide/01-introduction.md +++ b/ptodsl/docs/user_guide/01-introduction.md @@ -208,8 +208,10 @@ def my_kernel( barriers by hand, and work with raw pointers — useful when you need to hand-tune instruction schedules or overlap DMA with compute. -`mode` only affects what you can write inside the function body. It doesn't -change how you compile or launch the kernel. +For native launch builds, `mode` also selects the default PTOAS build policy: +`mode="auto"` keeps the PTOAS default build level and enables sync insertion, +while `mode="explicit"` uses `--pto-level=level3` and leaves synchronization +under user control by default. #### `backend`: VPTO vs EmitC diff --git a/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md b/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md index 61e4b8e50a..a7a566624d 100644 --- a/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md +++ b/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md @@ -50,8 +50,10 @@ The **`backend`** parameter selects the compilation target: rejected at decoration time with an actionable diagnostic. The **`mode`** parameter selects the programming model within the kernel body -(see Section 3.4). `mode` only affects what you can write inside the function — -it doesn't change how you compile or launch the kernel. +(see Section 3.4). For native launch builds, `mode="auto"` keeps PTOAS default +build level and enables sync insertion by default, while `mode="explicit"` uses +`--pto-level=level3` and disables sync insertion by default. This matches the +manual-address, user-managed staging contract of explicit kernels. `@pto.jit` owns compilation (tracing + lowering), caching, and — for `entry=True` — runtime launch binding. 
The compute-unit decorators diff --git a/ptodsl/ptodsl/_diagnostics.py b/ptodsl/ptodsl/_diagnostics.py index 572b1d5950..92226e000d 100644 --- a/ptodsl/ptodsl/_diagnostics.py +++ b/ptodsl/ptodsl/_diagnostics.py @@ -341,6 +341,15 @@ def illegal_inline_subkernel_placement_error(role: str, outer_role: str | None) ) +def subkernel_kernel_kind_mismatch_error(role: str, kernel_kind: str) -> RuntimeError: + """Return one diagnostic for mixing explicit @pto.jit kernel kind with the opposite subkernel kind.""" + return RuntimeError( + f"@pto.{role} cannot be lowered inside an explicit @pto.jit(kernel_kind={kernel_kind!r}) " + "module. Remove the explicit kernel_kind so PTOAS can split cube/vector sections, " + "or keep subkernel scopes in the same physical kind." + ) + + def inline_subkernel_value_escape_error(role: str, type_text: str) -> RuntimeError: """Return one diagnostic for outlined inline-scope values escaping their helper boundary.""" return RuntimeError( @@ -503,6 +512,7 @@ def unsupported_public_surface_error(name: str) -> AttributeError: "subkernel_host_tensor_boundary_error", "subkernel_illegal_annotation_error", "subkernel_illegal_parameter_kind_error", + "subkernel_kernel_kind_mismatch_error", "subkernel_missing_annotation_error", "subkernel_signature_boundary_error", "tile_row_alignment_error", diff --git a/ptodsl/ptodsl/_jit.py b/ptodsl/ptodsl/_jit.py index fc86a21b9f..26fb8311e5 100644 --- a/ptodsl/ptodsl/_jit.py +++ b/ptodsl/ptodsl/_jit.py @@ -34,6 +34,15 @@ _MODULE_ATTRS = ("pto.target_arch",) _SUPPORTED_FRONTEND_OPTION_KEYS = {"ast_rewrite", "rewrite_part", "dump_rewritten_source"} _SUPPORTED_REWRITE_PARTS = {"control_flow"} +_DEFAULT_KERNEL_KIND = "vector" + + +class _DefaultKernelKindSentinel: + def __repr__(self) -> str: + return repr(_DEFAULT_KERNEL_KIND) + + +_DEFAULT_KERNEL_KIND_SENTINEL = _DefaultKernelKindSentinel() def _normalize_mode(mode: str, *, fn=None) -> str: @@ -164,7 +173,7 @@ def jit( name=None, *, target: str = "a5", - kernel_kind: 
str = "vector", + kernel_kind: str = _DEFAULT_KERNEL_KIND_SENTINEL, backend: str = "vpto", entry: bool = True, mode: str = "auto", @@ -180,10 +189,10 @@ def jit( ---------- name: IR function name (defaults to the Python function name). target: Target architecture string, e.g. ``"a5"``. - kernel_kind: authored default physical kind, used for native build selection - and VPTO authoring intent. PTODSL now expresses physical regions - through ``pto.section.vector/cube`` instead of child-module - ``pto.kernel_kind`` attributes. + kernel_kind: optional authored physical kind, used for native build selection + and explicit single-kind VPTO authoring intent. When omitted, + PTODSL keeps the historical vector default while allowing + subkernel sections to express mixed cube/vector regions. backend: ``"vpto"`` or ``"emitc"`` – records the intended backend. entry: ``True`` for launchable kernel entries, ``False`` for helpers. mode: ``"auto"`` or ``"explicit"`` – feeds child compile policy. @@ -238,12 +247,15 @@ def decorator(fn): source_file = inspect.getsourcefile(fn) or inspect.getfile(fn) except (OSError, TypeError): source_file = None + kernel_kind_explicit = kernel_kind is not _DEFAULT_KERNEL_KIND_SENTINEL + effective_kernel_kind = kernel_kind if kernel_kind_explicit else _DEFAULT_KERNEL_KIND compiler = KernelCompiler( fn.__name__, KernelModuleSpec( function_name=fn_name, target_arch=target, - kernel_kind=kernel_kind, + kernel_kind=effective_kernel_kind, + kernel_kind_explicit=kernel_kind_explicit, backend=normalized_backend, entry=entry, mode=normalized_mode, @@ -307,6 +319,10 @@ def __ptodsl_cache_signature__(self): self._compiler._kernel_identity, module_spec.function_name, module_spec.entry, + module_spec.backend, + module_spec.mode, + module_spec.kernel_kind, + module_spec.kernel_kind_explicit, ) def _build_default_module(self): diff --git a/ptodsl/ptodsl/_runtime/cache.py b/ptodsl/ptodsl/_runtime/cache.py index 6231c16b02..9552399401 100644 --- 
a/ptodsl/ptodsl/_runtime/cache.py +++ b/ptodsl/ptodsl/_runtime/cache.py @@ -67,6 +67,7 @@ def write_manifest( launch_symbol: str, mlir_digest: str, launch_cpp_digest: str, + compile_config_digest: str, link_config_digest: str, ) -> None: artifacts.cache_dir.mkdir(parents=True, exist_ok=True) @@ -76,6 +77,7 @@ def write_manifest( "shared_library": str(artifacts.shared_library), "mlir_digest": mlir_digest, "launch_cpp_digest": launch_cpp_digest, + "compile_config_digest": compile_config_digest, "link_config_digest": link_config_digest, } artifacts.manifest_path.write_text(json.dumps(manifest, indent=2) + "\n", encoding="utf-8") @@ -90,6 +92,7 @@ def is_native_build_current( *, mlir_text: str, launch_cpp_text: str, + compile_config_text: str, link_config_text: str, ) -> bool: required = ( @@ -110,6 +113,7 @@ def is_native_build_current( return ( manifest.get("mlir_digest") == _content_digest(mlir_text) and manifest.get("launch_cpp_digest") == _content_digest(launch_cpp_text) + and manifest.get("compile_config_digest") == _content_digest(compile_config_text) and manifest.get("link_config_digest") == _content_digest(link_config_text) ) diff --git a/ptodsl/ptodsl/_runtime/native_build.py b/ptodsl/ptodsl/_runtime/native_build.py index 0821c326be..410a3997ae 100644 --- a/ptodsl/ptodsl/_runtime/native_build.py +++ b/ptodsl/ptodsl/_runtime/native_build.py @@ -76,13 +76,34 @@ def _effective_insert_sync(*, mode: str, insert_sync: bool | None) -> bool: return mode != "explicit" +def _effective_pto_level(*, mode: str) -> str | None: + return "level3" if mode == "explicit" else None + + def _source_ptoas_overrides(module_spec) -> dict: if getattr(module_spec, "jit_source", None) is None: return {} - overrides = {"backend": module_spec.backend} - if module_spec.mode == "explicit": - overrides["pto_level"] = "level3" - return overrides + return {"backend": module_spec.backend} + + +def _compile_config_text( + *, + module_spec, + effective_insert_sync: bool, + effective_pto_level: 
str | None, + ptoas_overrides: dict, +) -> str: + return "\n".join( + [ + f"target_arch={module_spec.target_arch}", + f"kernel_kind={module_spec.kernel_kind}", + f"mode={module_spec.mode}", + f"insert_sync={effective_insert_sync}", + f"pto_level={effective_pto_level}", + f"backend={ptoas_overrides.get('backend')}", + "enable_tile_op_expand=True", + ] + ) def _host_compile_flags() -> list[str]: @@ -191,6 +212,18 @@ def build_native_library( ir_function_name=ir_function_name, kernel_signature=kernel_signature, ) + effective_insert_sync = _effective_insert_sync( + mode=module_spec.mode, + insert_sync=module_spec.insert_sync, + ) + effective_pto_level = _effective_pto_level(mode=module_spec.mode) + ptoas_overrides = _source_ptoas_overrides(module_spec) + compile_config_text = _compile_config_text( + module_spec=module_spec, + effective_insert_sync=effective_insert_sync, + effective_pto_level=effective_pto_level, + ptoas_overrides=ptoas_overrides, + ) sim_mode = bool(os.environ.get("MSPROF_SIMULATOR_MODE")) link_config_text = "\n".join(runtime_library_flags(sim_mode=sim_mode)) @@ -198,6 +231,7 @@ def build_native_library( artifacts, mlir_text=mlir_text, launch_cpp_text=launch_cpp_text, + compile_config_text=compile_config_text, link_config_text=link_config_text, ): return artifacts.shared_library, launch_symbol @@ -210,11 +244,9 @@ def build_native_library( artifacts.mlir_path, artifacts.kernel_object, target_arch=module_spec.target_arch, - insert_sync=_effective_insert_sync( - mode=module_spec.mode, - insert_sync=module_spec.insert_sync, - ), - **_source_ptoas_overrides(module_spec), + insert_sync=effective_insert_sync, + pto_level=effective_pto_level, + **ptoas_overrides, ) launch_object = artifacts.cache_dir / "launch.o" @@ -237,6 +269,7 @@ def build_native_library( launch_symbol=launch_symbol, mlir_digest=_content_digest(mlir_text), launch_cpp_digest=_content_digest(launch_cpp_text), + compile_config_digest=_content_digest(compile_config_text), 
link_config_digest=_content_digest(link_config_text), ) return artifacts.shared_library, launch_symbol diff --git a/ptodsl/ptodsl/_tracing/module_builder.py b/ptodsl/ptodsl/_tracing/module_builder.py index f4108724c5..1625b07e9b 100644 --- a/ptodsl/ptodsl/_tracing/module_builder.py +++ b/ptodsl/ptodsl/_tracing/module_builder.py @@ -31,6 +31,7 @@ class KernelModuleSpec: function_name: str target_arch: str kernel_kind: str + kernel_kind_explicit: bool = False backend: str = "vpto" entry: bool = True mode: str = "auto" diff --git a/ptodsl/ptodsl/_tracing/session.py b/ptodsl/ptodsl/_tracing/session.py index c2390f1d63..8f81836996 100644 --- a/ptodsl/ptodsl/_tracing/session.py +++ b/ptodsl/ptodsl/_tracing/session.py @@ -13,7 +13,7 @@ from dataclasses import dataclass import hashlib -from .._diagnostics import inline_subkernel_value_escape_error +from .._diagnostics import inline_subkernel_value_escape_error, subkernel_kernel_kind_mismatch_error from .._kernel_signature import RuntimeScalarParameterSpec from .._ops import const from .._surface_values import unwrap_surface_value, wrap_like_surface_value @@ -239,16 +239,50 @@ def _create_subkernel_section_op(self, role: str): return None def _create_inline_subkernel_wrapper(self, role: str): - wrapper_op = self._create_subkernel_section_op(role) + wrapper_op = None + if self._subkernel_section_policy(role) != "function_kind": + wrapper_op = self._create_subkernel_section_op(role) if wrapper_op is None: wrapper_op = _pto.VecScopeOp() body_block = wrapper_op.body.blocks.append() return wrapper_op, body_block + def _subkernel_role_kernel_kind(self, role: str) -> str | None: + if role == "simd": + return "vector" + if role == "cube": + return "cube" + return None + + def _current_explicit_kernel_kind(self) -> str | None: + module_spec = self.current_function_module_spec + if not getattr(module_spec, "kernel_kind_explicit", False): + return None + kind = getattr(module_spec, "kernel_kind", None) + return kind if kind in 
{"cube", "vector"} else None + + def _subkernel_section_policy(self, role: str) -> str: + role_kind = self._subkernel_role_kernel_kind(role) + explicit_kind = self._current_explicit_kernel_kind() + if role_kind is None or explicit_kind is None: + return "section" + if explicit_kind != role_kind: + raise subkernel_kernel_kind_mismatch_error(role, explicit_kind) + return "function_kind" + def _subkernel_helper_attributes(self, role: str) -> tuple[tuple[str, object], ...]: attrs: list[tuple[str, object]] = [] if role in {"simd", "cube"}: attrs.append(("pto.ptodsl.subkernel_helper", StringAttr.get(role))) + if self._subkernel_section_policy(role) == "function_kind": + attrs.append( + ( + "pto.kernel_kind", + Attribute.parse( + f"#pto.kernel_kind<{self._subkernel_role_kernel_kind(role)}>" + ), + ) + ) if role == "simt": attrs.append(("pto.simt_entry", UnitAttr.get())) return tuple(attrs) @@ -275,6 +309,10 @@ def enter_subkernel_body(self, role: str, symbol_name: str, target: str): ) self._subkernel_stack.append(frame) try: + if self._subkernel_section_policy(role) == "function_kind": + yield frame + return + section_op = self._create_subkernel_section_op(role) if section_op is None: yield frame @@ -391,7 +429,8 @@ def _remap_captured_operands(self, root_ops, capture_mapping) -> None: def _outline_inline_subkernel(self, outline_frame: InlineSubkernelOutlineFrame) -> None: role = outline_frame.trace_frame.role - if role in {"simd", "cube"}: + section_policy = self._subkernel_section_policy(role) + if role in {"simd", "cube"} and section_policy != "function_kind": root_ops = (outline_frame.wrapper_op,) else: root_ops = tuple(outline_frame.body_block.operations) @@ -422,7 +461,7 @@ def _outline_inline_subkernel(self, outline_frame: InlineSubkernelOutlineFrame) terminator = func.ReturnOp([]) return_anchor = terminator.operation.opview - if role in {"simd", "cube"}: + if role in {"simd", "cube"} and section_policy != "function_kind": 
outline_frame.wrapper_op.move_before(return_anchor) outlined_roots = (outline_frame.wrapper_op,) else: diff --git a/ptodsl/tests/test_jit_compile.py b/ptodsl/tests/test_jit_compile.py index c7ed86ead9..210fb43330 100644 --- a/ptodsl/tests/test_jit_compile.py +++ b/ptodsl/tests/test_jit_compile.py @@ -151,6 +151,25 @@ def host_vec_copy_explicit( pto.tile.store(o_tile, out) +@pto.jit(target="a5", mode="explicit") +def host_vec_copy_explicit_addr( + A_ptr: pto.ptr(pto.f32, "gm"), + O_ptr: pto.ptr(pto.f32, "gm"), + rows: pto.i32, + cols: pto.i32, + *, + BLOCK: pto.const_expr = 128, +): + a_view = pto.make_tensor_view(A_ptr, shape=[rows, cols], strides=[cols, 1]) + o_view = pto.make_tensor_view(O_ptr, shape=[rows, cols], strides=[cols, 1]) + a_tile = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32, addr=0) + o_tile = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32, addr=4096) + part = pto.partition_view(a_view, offsets=[0, 0], sizes=[rows, cols]) + out = pto.partition_view(o_view, offsets=[0, 0], sizes=[rows, cols]) + pto.tile.load(part, a_tile) + pto.tile.store(o_tile, out) + + @pto.jit(target="a5", backend="emitc") def host_vec_copy_emitc( A_ptr: pto.ptr(pto.f32, "gm"), @@ -670,6 +689,16 @@ def top_level_simd_probe(): SUBKERNEL_OBSERVATIONS.append((frame.role, frame.symbol_name, session.subkernel_stack_depth)) +@pto.simd +def explicit_vector_simd_probe(): + pto.pipe_barrier(pto.Pipe.ALL) + + +@pto.cube +def explicit_vector_cube_probe(): + pto.pipe_barrier(pto.Pipe.ALL) + + @pto.jit(target="a5") def shared_subkernel_lowering_probe(*, TRACE_TOKEN: pto.const_expr = 0): top_level_cube_probe() @@ -677,6 +706,22 @@ def shared_subkernel_lowering_probe(*, TRACE_TOKEN: pto.const_expr = 0): nested_simd_probe() +@pto.jit(target="a5", kernel_kind="vector") +def explicit_vector_calls_simd_probe(*, TRACE_TOKEN: pto.const_expr = 0): + explicit_vector_simd_probe() + + +@pto.jit(target="a5", kernel_kind="vector") +def explicit_vector_calls_cube_probe(*, TRACE_TOKEN: pto.const_expr = 0): 
+ explicit_vector_cube_probe() + + +@pto.jit(target="a5", kernel_kind="vector") +def explicit_vector_inline_simd_probe(*, TRACE_TOKEN: pto.const_expr = 0): + with pto.simd(): + pto.pipe_barrier(pto.Pipe.ALL) + + @pto.jit(target="a5", mode="explicit") def inline_subkernel_scope_probe(*, TRACE_TOKEN: pto.const_expr = 0): session = current_session() @@ -3463,6 +3508,10 @@ def inline_source_backed_probe(ptr: pto.ptr(pto.f32, "gm"), rows: pto.i32): and helper_cache_signature[4] is False, "@pto.jit(entry=False) handles should expose an explicit, stable cache-signature protocol", ) + expect( + helper_cache_signature[7] == "vector" and helper_cache_signature[8] is False, + "default @pto.jit handles should keep vector as the effective kernel kind while recording that it was not explicit", + ) expect_raises( RuntimeError, kernel_module_return_probe.compile, @@ -3662,6 +3711,7 @@ def inline_source_backed_probe(ptr: pto.ptr(pto.f32, "gm"), rows: pto.i32): ) native_build_variants = ( ("pure-container", host_vec_copy.compile()), + ("explicit-level3-container", host_vec_copy_explicit_addr.compile()), ("same-backend-multi-child-container", kernel_module_compiled), ("mixed-backend-container", emitc_entry_calls_vpto_kernel_module_probe.compile()), ("source-auto", source_native_build_compiled), @@ -3750,14 +3800,14 @@ def fake_link_shared_library(launch_object, kernel_object, shared_library, *, ke f"{label} native build should forward the effective insert_sync policy to ptoas", ) expected_backend = compiled._module_spec.backend if compiled._module_spec.jit_source is not None else None - expected_pto_level = "level3" if compiled._module_spec.jit_source is not None and compiled._module_spec.mode == "explicit" else None + expected_pto_level = "level3" if compiled._module_spec.mode == "explicit" else None expect( observation["backend"] == expected_backend, f"{label} native build should only forward ptoas backend overrides for source-backed kernels", ) expect( observation["pto_level"] == 
expected_pto_level, - f"{label} native build should only map explicit mode to ptoas level3 for source-backed kernels", + f"{label} native build should derive the PTOAS level from the authored mode", ) expect( observation["mlir_text"] == compiled.mlir_text(), @@ -3799,7 +3849,7 @@ def fake_run_ptoas_cmd(cmd, *, cwd=None): ) expect( "--pto-level=level3" not in ptoas_cmd, - "native build should no longer reconstruct explicit mode through a global pto-level flag", + "native build should not pass a global pto-level flag by default", ) expect( "--enable-insert-sync" not in ptoas_cmd, @@ -3833,21 +3883,21 @@ def fake_run_ptoas_cmd(cmd, *, cwd=None): kernel_object, target_arch="a5", backend="vpto", - pto_level="level3", + pto_level=native_build_runtime._effective_pto_level(mode="explicit"), insert_sync=True, ) - expect(len(ptoas_cmds) == 1, "native build should issue exactly one ptoas command with source-backed overrides") - source_ptoas_cmd = ptoas_cmds[0] + expect(len(ptoas_cmds) == 1, "native build should issue exactly one ptoas command with explicit-mode PTOAS policy") + explicit_ptoas_cmd = ptoas_cmds[0] expect( - "--pto-backend=vpto" in source_ptoas_cmd, + "--pto-backend=vpto" in explicit_ptoas_cmd, "source-backed native build should pass the decorator backend to ptoas", ) expect( - "--pto-level=level3" in source_ptoas_cmd, - "source-backed explicit mode should pass --pto-level=level3 to ptoas", + "--pto-level=level3" in explicit_ptoas_cmd, + 'native build should pass --pto-level=level3 for mode="explicit"', ) expect( - "--enable-insert-sync" in source_ptoas_cmd, + "--enable-insert-sync" in explicit_ptoas_cmd, "source-backed native build should still pass explicit/effective insert-sync to ptoas", ) expect("valid=?" 
not in default_text, "default alloc_tile() should keep full static valid-shape when valid_shape= is omitted") @@ -4083,6 +4133,34 @@ def fake_run_ptoas_cmd(cmd, *, cwd=None): "outlined decorated helper bodies should still preserve their PTO unit sections", ) + explicit_vector_simd_text = explicit_vector_calls_simd_probe.compile(TRACE_TOKEN=1).mlir_text() + expect_parse_roundtrip_and_verify( + explicit_vector_simd_text, + "explicit vector jit calling simd subkernel specialization", + ) + expect( + "pto.kernel_kind = #pto.kernel_kind" in explicit_vector_simd_text + and "pto.section.vector {" not in explicit_vector_simd_text, + "same-kind @pto.simd helpers inside explicit vector kernels should use function/kernel kind metadata without redundant sections", + ) + expect_raises( + RuntimeError, + lambda: explicit_vector_calls_cube_probe.compile(TRACE_TOKEN=1).mlir_text(), + "@pto.cube cannot be lowered inside an explicit @pto.jit(kernel_kind='vector')", + ) + explicit_vector_inline_simd_text = explicit_vector_inline_simd_probe.compile( + TRACE_TOKEN=1 + ).mlir_text() + expect_parse_roundtrip_and_verify( + explicit_vector_inline_simd_text, + "explicit vector jit calling inline simd specialization", + ) + expect( + "pto.kernel_kind = #pto.kernel_kind" in explicit_vector_inline_simd_text + and "pto.section.vector {" not in explicit_vector_inline_simd_text, + "same-kind inline pto.simd() scopes inside explicit vector kernels should avoid redundant sections", + ) + INLINE_SUBKERNEL_SCOPE_OBSERVATIONS.clear() inline_subkernel_scope_text = inline_subkernel_scope_probe.compile(TRACE_TOKEN=1).mlir_text() expect_parse_roundtrip_and_verify(inline_subkernel_scope_text, "inline subkernel scope specialization") diff --git a/test/dsl-st/cube_matrix_pipeline.py b/test/dsl-st/cube_matrix_pipeline.py index 420d235e75..0138fae355 100644 --- a/test/dsl-st/cube_matrix_pipeline.py +++ b/test/dsl-st/cube_matrix_pipeline.py @@ -21,50 +21,19 @@ M = 16 K = 32 -N = 48 +N = 64 +ELEM_BYTES = 4 
L1_A_ADDR = 0 L1_B_ADDR = 4096 -UB_O_ADDR = 0 L0A_ADDR = 0 L0B_ADDR = 0 L0C_ADDR = 0 -@pto.cube -def cube_gemm_tile(a_mat, b_mat, o_tile, a_l0a, b_l0b, o_acc): - m = a_mat.valid_shape[0] - k = a_mat.valid_shape[1] - n = b_mat.valid_shape[1] - - pto.mte_l1_l0a(a_mat.as_ptr(), a_l0a.as_ptr(), m, k) - pto.mte_l1_l0b(b_mat.as_ptr(), b_l0b.as_ptr(), k, n) - pto.set_flag(pto.Pipe.MTE1, pto.Pipe.M, event_id=0) - pto.wait_flag(pto.Pipe.MTE1, pto.Pipe.M, event_id=0) - pto.mad( - a_l0a.as_ptr(), - b_l0b.as_ptr(), - o_acc.as_ptr(), - m, - n, - k, - unit_flag=pto.MadUnitFlagMode.CHECK_ONLY, - sat=pto.SatMode.OFF, - ) - pto.set_flag(pto.Pipe.M, pto.Pipe.FIX, event_id=1) - pto.wait_flag(pto.Pipe.M, pto.Pipe.FIX, event_id=1) - pto.mte_l0c_ub( - o_acc.as_ptr(), - o_tile.as_ptr(), - m, - n, - n, - n, - ) - - @pto.jit( name="cube_matrix_pipeline_kernel", + kernel_kind="cube", target="a5", mode="explicit", insert_sync=False, @@ -74,20 +43,14 @@ def cube_matrix_pipeline_kernel( b_ptr: pto.ptr(pto.f32, "gm"), o_ptr: pto.ptr(pto.f32, "gm"), ): - a_view = pto.make_tensor_view(a_ptr, shape=[M, K], strides=[K, 1]) - b_view = pto.make_tensor_view(b_ptr, shape=[K, N], strides=[N, 1]) - o_view = pto.make_tensor_view(o_ptr, shape=[M, N], strides=[N, 1]) - - a_part = pto.partition_view(a_view, offsets=[0, 0], sizes=[M, K]) - b_part = pto.partition_view(b_view, offsets=[0, 0], sizes=[K, N]) - o_part = pto.partition_view(o_view, offsets=[0, 0], sizes=[M, N]) - a_mat = pto.alloc_tile( shape=[M, K], dtype=pto.f32, memory_space=pto.MemorySpace.MAT, addr=L1_A_ADDR, valid_shape=[M, K], + blayout="ColMajor", + slayout="RowMajor", ) b_mat = pto.alloc_tile( shape=[K, N], @@ -95,12 +58,8 @@ def cube_matrix_pipeline_kernel( memory_space=pto.MemorySpace.MAT, addr=L1_B_ADDR, valid_shape=[K, N], - ) - o_tile = pto.alloc_tile( - shape=[M, N], - dtype=pto.f32, - addr=UB_O_ADDR, - valid_shape=[M, N], + blayout="ColMajor", + slayout="RowMajor", ) a_l0a = pto.alloc_tile( shape=[M, K], @@ -108,6 +67,8 @@ def 
cube_matrix_pipeline_kernel( memory_space=pto.MemorySpace.LEFT, addr=L0A_ADDR, valid_shape=[M, K], + blayout="ColMajor", + slayout="RowMajor", ) b_l0b = pto.alloc_tile( shape=[K, N], @@ -115,6 +76,8 @@ def cube_matrix_pipeline_kernel( memory_space=pto.MemorySpace.RIGHT, addr=L0B_ADDR, valid_shape=[K, N], + blayout="RowMajor", + slayout="ColMajor", ) o_acc = pto.alloc_tile( shape=[M, N], @@ -122,16 +85,58 @@ def cube_matrix_pipeline_kernel( memory_space=pto.MemorySpace.ACC, addr=L0C_ADDR, valid_shape=[M, N], + blayout="ColMajor", + slayout="RowMajor", + fractal_size=1024, ) - pto.tile.load(a_part, a_mat) - pto.tile.load(b_part, b_mat) + a_l1_ptr = pto.castptr(pto.ui64(L1_A_ADDR), pto.ptr(pto.f32, "mat")) + b_l1_ptr = pto.castptr(pto.ui64(L1_B_ADDR), pto.ptr(pto.f32, "mat")) + + pto.mte_gm_l1_frac( + a_ptr, + a_l1_ptr, + pto.FractalMode.ND2NZ, + shape=(M, K), + src_layout=(K * ELEM_BYTES,), + dst_group=(1, 1, M, 0), + ctrl=(0, False), + ) pto.set_flag(pto.Pipe.MTE2, pto.Pipe.MTE1, event_id=0) pto.wait_flag(pto.Pipe.MTE2, pto.Pipe.MTE1, event_id=0) - cube_gemm_tile(a_mat, b_mat, o_tile, a_l0a, b_l0b, o_acc) - pto.set_flag(pto.Pipe.FIX, pto.Pipe.MTE3, event_id=2) - pto.wait_flag(pto.Pipe.FIX, pto.Pipe.MTE3, event_id=2) - pto.tile.store(o_tile, o_part) + pto.mte_l1_l0a(a_l1_ptr, a_l0a.as_ptr(), M, K) + + pto.mte_gm_l1_frac( + b_ptr, + b_l1_ptr, + pto.FractalMode.ND2NZ, + shape=(K, N), + src_layout=(N * ELEM_BYTES,), + dst_group=(1, 1, K, 0), + ctrl=(0, False), + ) + pto.set_flag(pto.Pipe.MTE2, pto.Pipe.MTE1, event_id=1) + pto.wait_flag(pto.Pipe.MTE2, pto.Pipe.MTE1, event_id=1) + pto.mte_l1_l0b(b_l1_ptr, b_l0b.as_ptr(), K, N, transpose=True) + + pto.set_flag(pto.Pipe.MTE1, pto.Pipe.M, event_id=0) + pto.wait_flag(pto.Pipe.MTE1, pto.Pipe.M, event_id=0) + pto.tile.matmul(a_l0a, b_l0b, o_acc) + + pto.set_flag(pto.Pipe.M, pto.Pipe.FIX, event_id=1) + pto.wait_flag(pto.Pipe.M, pto.Pipe.FIX, event_id=1) + pto.mte_l0c_gm( + o_acc.as_ptr(), + o_ptr, + M, + N, + M, + N, + 0, + 0, + 
layout="nz2nd", + ) + pto.pipe_barrier(pto.Pipe.ALL) def make_inputs(): diff --git a/test/dsl-st/gemv_mx_pipeline.py b/test/dsl-st/gemv_mx_pipeline.py index 89b1d80f7c..3d8b0f5f56 100644 --- a/test/dsl-st/gemv_mx_pipeline.py +++ b/test/dsl-st/gemv_mx_pipeline.py @@ -17,8 +17,6 @@ from common import assert_close, auto_main from ptodsl import pto -from ptodsl._surface_values import unwrap_surface_value -from mlir.dialects import pto as _pto M = 1 @@ -177,17 +175,6 @@ def _alloc_common_tiles(): return lhs_tile, lhs_scale_tile, rhs_tile, rhs_scale_tile, dst_tile -def _bind_mx_scale_tiles(lhs_tile, lhs_scale_tile, rhs_tile, rhs_scale_tile): - _pto.TGetScaleAddrOp( - unwrap_surface_value(lhs_tile), - unwrap_surface_value(lhs_scale_tile), - ) - _pto.TGetScaleAddrOp( - unwrap_surface_value(rhs_tile), - unwrap_surface_value(rhs_scale_tile), - ) - - def _alloc_bias_tile(): return pto.alloc_tile( shape=[M, N_STORAGE], @@ -271,7 +258,6 @@ def gemv_mx_fp8_pipeline_kernel( ): lhs_tile, lhs_scale_tile, rhs_tile, rhs_scale_tile, dst_tile = _alloc_common_tiles() _stage_fp8_tiles(a_ptr, b_ptr, a_scale_ptr, b_scale_ptr, lhs_tile, rhs_tile) - _bind_mx_scale_tiles(lhs_tile, lhs_scale_tile, rhs_tile, rhs_scale_tile) pto.tile.gemv_mx(lhs_tile, lhs_scale_tile, rhs_tile, rhs_scale_tile, dst_tile) _writeback_output(dst_tile, out_ptr) @@ -292,7 +278,6 @@ def gemv_mx_acc_fp8_pipeline_kernel( ): lhs_tile, lhs_scale_tile, rhs_tile, rhs_scale_tile, dst_tile = _alloc_common_tiles() _stage_fp8_tiles(a_ptr, b_ptr, a_scale_ptr, b_scale_ptr, lhs_tile, rhs_tile) - _bind_mx_scale_tiles(lhs_tile, lhs_scale_tile, rhs_tile, rhs_scale_tile) pto.tile.gemv_mx(lhs_tile, lhs_scale_tile, rhs_tile, rhs_scale_tile, dst_tile) pto.tile.gemv_mx_acc(dst_tile, lhs_tile, lhs_scale_tile, rhs_tile, rhs_scale_tile, dst_tile) _writeback_output(dst_tile, out_ptr) @@ -325,7 +310,6 @@ def gemv_mx_bias_fp8_pipeline_kernel( bias_ptr=bias_ptr, bias_tile=bias_tile, ) - _bind_mx_scale_tiles(lhs_tile, lhs_scale_tile, rhs_tile, 
rhs_scale_tile) pto.tile.gemv_mx_bias(lhs_tile, lhs_scale_tile, rhs_tile, rhs_scale_tile, bias_tile, dst_tile) _writeback_output(dst_tile, out_ptr) diff --git a/test/dsl-st/npu_a5/__main__.py b/test/dsl-st/npu_a5/__main__.py new file mode 100644 index 0000000000..03ec12f0ab --- /dev/null +++ b/test/dsl-st/npu_a5/__main__.py @@ -0,0 +1,32 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# Directory runner for the npu_a5 operator-level ST cases. +# +# Each operator (e.g. tadd, tmatmul, ...) is a single *.py file that authors its +# kernel with PTODSL and builds its CASES list through the helpers in +# ``test/dsl-st/common.py``. Running this directory discovers every *.py module +# and executes the cases against the torch_npu / simulator runtime. +# +# See test/dsl-st/README.md for the authoring conventions shared with the rest +# of the dsl-st suite. + +from pathlib import Path +import sys + + +if __package__ in {None, ""}: + # common.py lives one level up, in test/dsl-st/. 
+ sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) + +from common import run_discovered_cases + + +if __name__ == "__main__": + raise SystemExit(run_discovered_cases(Path(__file__).resolve().parent)) diff --git a/test/dsl-st/npu_a5/tadd.py b/test/dsl-st/npu_a5/tadd.py new file mode 100644 index 0000000000..7df2bfaeaa --- /dev/null +++ b/test/dsl-st/npu_a5/tadd.py @@ -0,0 +1,197 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# PTODSL rewrite of test/tilelang_st/npu/a5/src/st/testcase/tadd. +# +# Keep the kernel body close to the original semantic tile-op authoring: +# tload(a) + tload(b) + tadd(a,b)->c + tstore(c) +# +# The case uses explicit UB addresses to match the hand-authored ST contract. +# PTODSL native build derives PTOAS level3 from mode="explicit". + +from pathlib import Path +import sys + +import numpy as np +from mlir.ir import Attribute, InsertionPoint, Location, Module, StringAttr + +if __package__ in {None, ""}: + sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) + +from common import auto_main, golden_output_case +from ptodsl import pto +from ptodsl._kernel_compilation import KernelCompiler +from ptodsl._kernel_signature import parse_jit_kernel_signature +from ptodsl._tracing import KernelModuleSpec, ModuleStyle + + +# Each case is (name, shape). 
Both use fully-valid f32 tiles, matching the +# original tadd cases "f32_16x64" and "f32_32x32". +CASE_SHAPES = [ + ("f32_16x64", (16, 64)), + ("f32_32x32", (32, 32)), +] + +A_TILE_ADDR = 0 +B_TILE_ADDR = 4096 +C_TILE_ADDR = 8192 + + +class _FlatKernelHandle: + """Small test-local wrapper that compiles through a flat PTODSL container.""" + + def __init__(self, compiler): + self._compiler = compiler + + def compile(self, **constexpr_bindings): + compiled = self._compiler.compile(**constexpr_bindings) + _attach_flat_vpto_attrs(compiled.build(), self._compiler._module_spec) + return compiled + + +def _attach_flat_vpto_attrs(module, spec): + """Test-local flat containers must carry PTOAS-facing VPTO metadata.""" + with module.context: + module.operation.attributes["pto.backend"] = StringAttr.get(spec.backend) + if spec.backend == "vpto" and spec.kernel_kind in {"cube", "vector"}: + module.operation.attributes["pto.kernel_kind"] = Attribute.parse( + f"#pto.kernel_kind<{spec.kernel_kind}>" + ) + + +def _flat_jit(*, name, target="a5", kernel_kind="vector"): + """Mirror @pto.jit for this testcase, but force a flat module container.""" + + def decorator(fn): + compiler = KernelCompiler( + fn.__name__, + KernelModuleSpec( + function_name=name, + target_arch=target, + kernel_kind=kernel_kind, + backend="vpto", + entry=True, + mode="explicit", + insert_sync=None, + module_style=ModuleStyle.FLAT_AICORE, + source_file=__file__, + source_line=fn.__code__.co_firstlineno, + ), + parse_jit_kernel_signature(fn, entry=True), + fn, + ast_rewrite=True, + ) + return _FlatKernelHandle(compiler) + + return decorator + + +def _merge_flat_modules(*compiled_kernels): + first = compiled_kernels[0].build() + with first.context, Location.unknown(): + merged = Module.create() + for named_attr in first.operation.attributes: + merged.operation.attributes[named_attr.name] = named_attr.attr + with InsertionPoint(merged.body): + for compiled in compiled_kernels: + module = compiled.build() + for op in 
module.body.operations: + op.operation.clone() + merged.operation.verify() + return merged + + +def _tadd_body(a_ptr, b_ptr, c_ptr, *, rows, cols): + """Shared kernel body for the two tadd cases.""" + + # Keep the original 5D tilelang-style partition schema here. It matches the + # hand-authored tadd.pto layout and is already known-good for vec tile-op + # ST cases in this repository. + total = rows * cols + a_view = pto.make_tensor_view(a_ptr, shape=[1, 1, 1, rows, cols], strides=[total, total, total, cols, 1]) + b_view = pto.make_tensor_view(b_ptr, shape=[1, 1, 1, rows, cols], strides=[total, total, total, cols, 1]) + c_view = pto.make_tensor_view(c_ptr, shape=[1, 1, 1, rows, cols], strides=[total, total, total, cols, 1]) + + a_part = pto.partition_view(a_view, offsets=[0, 0, 0, 0, 0], sizes=[1, 1, 1, rows, cols]) + b_part = pto.partition_view(b_view, offsets=[0, 0, 0, 0, 0], sizes=[1, 1, 1, rows, cols]) + c_part = pto.partition_view(c_view, offsets=[0, 0, 0, 0, 0], sizes=[1, 1, 1, rows, cols]) + + # Use explicit UB addresses so direct level3 VPTO lowering has no memory + # planning dependency. + a_tile = pto.alloc_tile(shape=[rows, cols], dtype=pto.f32, addr=A_TILE_ADDR) + b_tile = pto.alloc_tile(shape=[rows, cols], dtype=pto.f32, addr=B_TILE_ADDR) + c_tile = pto.alloc_tile(shape=[rows, cols], dtype=pto.f32, addr=C_TILE_ADDR) + + pto.tile.load(a_part, a_tile) + pto.tile.load(b_part, b_tile) + pto.set_flag(pto.Pipe.MTE2, pto.Pipe.V, event_id=0) + pto.wait_flag(pto.Pipe.MTE2, pto.Pipe.V, event_id=0) + pto.tile.add(a_tile, b_tile, c_tile) + pto.set_flag(pto.Pipe.V, pto.Pipe.MTE3, event_id=1) + pto.wait_flag(pto.Pipe.V, pto.Pipe.MTE3, event_id=1) + pto.tile.store(c_tile, c_part) + pto.pipe_barrier(pto.Pipe.ALL) + + +# One decorated kernel per case, each binding a static shape at definition time +# (mirroring the per-case funcs in tadd.pto). 
+_tadd_kernels = {} +for _name, _shape in CASE_SHAPES: + _r, _c = _shape + + def _make(r=_r, c=_c): + @_flat_jit( + name=f"tadd_{_name}", + kernel_kind="vector", + target="a5", + ) + def _kernel( + a_ptr: pto.ptr(pto.f32, "gm"), + b_ptr: pto.ptr(pto.f32, "gm"), + c_ptr: pto.ptr(pto.f32, "gm"), + ): + _tadd_body(a_ptr, b_ptr, c_ptr, rows=r, cols=c) + + return _kernel + + _tadd_kernels[_name] = _make() + + +def _make_inputs(name, shape): + # Deterministic per-case seed, mirroring st_common.setup_case_rng which uses + # crc32(name). Original value range was randint(1, 10). + import zlib + np.random.seed(zlib.crc32(name.encode("utf-8")) & 0xFFFFFFFF) + a = np.random.randint(1, 10, size=shape).astype(np.float32) + b = np.random.randint(1, 10, size=shape).astype(np.float32) + return [a, b] + + +def _make_expected(a, b): + return (a + b).astype(np.float32) + + +CASES = [] +for _name, _shape in CASE_SHAPES: + CASES.append( + golden_output_case( + "tadd_" + _name, + _tadd_kernels[_name], + inputs=lambda _name=_name, _shape=_shape: _make_inputs(_name, _shape), + expected=_make_expected, + rtol=1e-6, + atol=1e-6, + ) + ) + + +EMIT_MLIR_FN = lambda: _merge_flat_modules(*[kernel.compile() for kernel in _tadd_kernels.values()]) + + +auto_main(globals()) diff --git a/test/dsl-st/npu_a5/tcolexpand.py b/test/dsl-st/npu_a5/tcolexpand.py new file mode 100644 index 0000000000..4d30dd5db4 --- /dev/null +++ b/test/dsl-st/npu_a5/tcolexpand.py @@ -0,0 +1,151 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. 
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# Minimal PTODSL broadcast pilot for A5: +# tload(src) + tcolexpand(src)->dst + tstore(dst) + +from pathlib import Path +import sys + +import numpy as np +from mlir.ir import Attribute, InsertionPoint, Location, Module, StringAttr + +if __package__ in {None, ""}: + sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) + +from common import auto_main, golden_output_case +from ptodsl import pto +from ptodsl._kernel_compilation import KernelCompiler +from ptodsl._kernel_signature import parse_jit_kernel_signature +from ptodsl._tracing import KernelModuleSpec, ModuleStyle + + +SRC_ROWS = 1 +DST_ROWS = 8 +COLS = 128 +SRC_TILE_ADDR = 0 +DST_TILE_ADDR = 4096 + + +class _FlatKernelHandle: + def __init__(self, compiler): + self._compiler = compiler + + def compile(self, **constexpr_bindings): + compiled = self._compiler.compile(**constexpr_bindings) + _attach_flat_vpto_attrs(compiled.build(), self._compiler._module_spec) + return compiled + + +def _attach_flat_vpto_attrs(module, spec): + """Test-local flat containers must carry PTOAS-facing VPTO metadata.""" + with module.context: + module.operation.attributes["pto.backend"] = StringAttr.get(spec.backend) + if spec.backend == "vpto" and spec.kernel_kind in {"cube", "vector"}: + module.operation.attributes["pto.kernel_kind"] = Attribute.parse( + f"#pto.kernel_kind<{spec.kernel_kind}>" + ) + + +def _flat_jit(*, name, target="a5", kernel_kind="vector"): + def decorator(fn): + compiler = KernelCompiler( + fn.__name__, + KernelModuleSpec( + function_name=name, + target_arch=target, + kernel_kind=kernel_kind, + backend="vpto", + entry=True, + mode="explicit", + insert_sync=None, + module_style=ModuleStyle.FLAT_AICORE, + 
source_file=__file__, + source_line=fn.__code__.co_firstlineno, + ), + parse_jit_kernel_signature(fn, entry=True), + fn, + ast_rewrite=True, + ) + return _FlatKernelHandle(compiler) + + return decorator + + +def _merge_flat_modules(*compiled_kernels): + first = compiled_kernels[0].build() + with first.context, Location.unknown(): + merged = Module.create() + for named_attr in first.operation.attributes: + merged.operation.attributes[named_attr.name] = named_attr.attr + with InsertionPoint(merged.body): + for compiled in compiled_kernels: + module = compiled.build() + for op in module.body.operations: + op.operation.clone() + merged.operation.verify() + return merged + + +@_flat_jit( + name="tcolexpand_f32_1x8x128", + kernel_kind="vector", + target="a5", +) +def _tcolexpand_kernel( + src_ptr: pto.ptr(pto.f32, "gm"), + dst_ptr: pto.ptr(pto.f32, "gm"), +): + src_view = pto.make_tensor_view( + src_ptr, + shape=[1, 1, 1, SRC_ROWS, COLS], + strides=[COLS, COLS, COLS, COLS, 1], + ) + dst_view = pto.make_tensor_view( + dst_ptr, + shape=[1, 1, 1, DST_ROWS, COLS], + strides=[DST_ROWS * COLS, DST_ROWS * COLS, DST_ROWS * COLS, COLS, 1], + ) + + src_part = pto.partition_view(src_view, offsets=[0, 0, 0, 0, 0], sizes=[1, 1, 1, SRC_ROWS, COLS]) + dst_part = pto.partition_view(dst_view, offsets=[0, 0, 0, 0, 0], sizes=[1, 1, 1, DST_ROWS, COLS]) + + src_tile = pto.alloc_tile(shape=[SRC_ROWS, COLS], dtype=pto.f32, addr=SRC_TILE_ADDR) + dst_tile = pto.alloc_tile(shape=[DST_ROWS, COLS], dtype=pto.f32, addr=DST_TILE_ADDR) + + pto.tile.load(src_part, src_tile) + pto.tile.colexpand(src_tile, dst_tile) + pto.tile.store(dst_tile, dst_part) + + +def _make_input(): + rng = np.random.default_rng(0xC01E0A5) + return rng.uniform(-2.0, 2.0, size=(SRC_ROWS, COLS)).astype(np.float32) + + +def _make_expected(src): + return np.repeat(src, DST_ROWS, axis=0).astype(np.float32) + + +CASES = [ + golden_output_case( + "tcolexpand_f32_1x8x128", + _tcolexpand_kernel, + inputs=lambda: [_make_input()], + 
expected=_make_expected, + rtol=1e-6, + atol=1e-6, + ), +] + + +EMIT_MLIR_FN = lambda: _merge_flat_modules(_tcolexpand_kernel.compile()) + + +auto_main(globals()) diff --git a/test/dsl-st/npu_a5/tcolsum.py b/test/dsl-st/npu_a5/tcolsum.py new file mode 100644 index 0000000000..a6d898e46f --- /dev/null +++ b/test/dsl-st/npu_a5/tcolsum.py @@ -0,0 +1,150 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+ +# Minimal PTODSL reduction pilot for A5: +# tload(src) + tcolsum(src)->dst + tstore(dst) + +from pathlib import Path +import sys + +import numpy as np +from mlir.ir import Attribute, InsertionPoint, Location, Module, StringAttr + +if __package__ in {None, ""}: + sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) + +from common import auto_main, golden_output_case +from ptodsl import pto +from ptodsl._kernel_compilation import KernelCompiler +from ptodsl._kernel_signature import parse_jit_kernel_signature +from ptodsl._tracing import KernelModuleSpec, ModuleStyle + + +ROWS = 16 +COLS = 128 +SRC_TILE_ADDR = 0 +DST_TILE_ADDR = 8192 + + +class _FlatKernelHandle: + def __init__(self, compiler): + self._compiler = compiler + + def compile(self, **constexpr_bindings): + compiled = self._compiler.compile(**constexpr_bindings) + _attach_flat_vpto_attrs(compiled.build(), self._compiler._module_spec) + return compiled + + +def _attach_flat_vpto_attrs(module, spec): + """Test-local flat containers must carry PTOAS-facing VPTO metadata.""" + with module.context: + module.operation.attributes["pto.backend"] = StringAttr.get(spec.backend) + if spec.backend == "vpto" and spec.kernel_kind in {"cube", "vector"}: + module.operation.attributes["pto.kernel_kind"] = Attribute.parse( + f"#pto.kernel_kind<{spec.kernel_kind}>" + ) + + +def _flat_jit(*, name, target="a5", kernel_kind="vector"): + def decorator(fn): + compiler = KernelCompiler( + fn.__name__, + KernelModuleSpec( + function_name=name, + target_arch=target, + kernel_kind=kernel_kind, + backend="vpto", + entry=True, + mode="explicit", + insert_sync=None, + module_style=ModuleStyle.FLAT_AICORE, + source_file=__file__, + source_line=fn.__code__.co_firstlineno, + ), + parse_jit_kernel_signature(fn, entry=True), + fn, + ast_rewrite=True, + ) + return _FlatKernelHandle(compiler) + + return decorator + + +def _merge_flat_modules(*compiled_kernels): + first = compiled_kernels[0].build() + with first.context, 
Location.unknown(): + merged = Module.create() + for named_attr in first.operation.attributes: + merged.operation.attributes[named_attr.name] = named_attr.attr + with InsertionPoint(merged.body): + for compiled in compiled_kernels: + module = compiled.build() + for op in module.body.operations: + op.operation.clone() + merged.operation.verify() + return merged + + +@_flat_jit( + name="tcolsum_f32_16x128", + kernel_kind="vector", + target="a5", +) +def _tcolsum_kernel( + src_ptr: pto.ptr(pto.f32, "gm"), + dst_ptr: pto.ptr(pto.f32, "gm"), +): + src_view = pto.make_tensor_view( + src_ptr, + shape=[1, 1, 1, ROWS, COLS], + strides=[ROWS * COLS, ROWS * COLS, ROWS * COLS, COLS, 1], + ) + dst_view = pto.make_tensor_view( + dst_ptr, + shape=[1, 1, 1, 1, COLS], + strides=[COLS, COLS, COLS, COLS, 1], + ) + + src_part = pto.partition_view(src_view, offsets=[0, 0, 0, 0, 0], sizes=[1, 1, 1, ROWS, COLS]) + dst_part = pto.partition_view(dst_view, offsets=[0, 0, 0, 0, 0], sizes=[1, 1, 1, 1, COLS]) + + src_tile = pto.alloc_tile(shape=[ROWS, COLS], dtype=pto.f32, addr=SRC_TILE_ADDR) + dst_tile = pto.alloc_tile(shape=[1, COLS], dtype=pto.f32, addr=DST_TILE_ADDR) + + pto.tile.load(src_part, src_tile) + pto.tile.colsum(src_tile, dst_tile) + pto.tile.store(dst_tile, dst_part) + + +def _make_input(): + rng = np.random.default_rng(0xC01A5EED) + return rng.uniform(-3.0, 3.0, size=(ROWS, COLS)).astype(np.float32) + + +def _make_expected(src): + return np.sum(src, axis=0, keepdims=True, dtype=np.float32) + + +CASES = [ + golden_output_case( + "tcolsum_f32_16x128", + _tcolsum_kernel, + inputs=lambda: [_make_input()], + expected=_make_expected, + rtol=1e-5, + atol=1e-5, + ), +] + + +EMIT_MLIR_FN = lambda: _merge_flat_modules(_tcolsum_kernel.compile()) + + +auto_main(globals()) diff --git a/test/dsl-st/npu_a5/tload_store.py b/test/dsl-st/npu_a5/tload_store.py new file mode 100644 index 0000000000..595ef15183 --- /dev/null +++ b/test/dsl-st/npu_a5/tload_store.py @@ -0,0 +1,209 @@ +#!/usr/bin/env 
python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# PTODSL rewrite of the minimal GM -> tile -> GM coverage from +# test/tilelang_st/npu/a5/src/st/testcase/tload/tload.pto. +# +# Start with two static f32 round-trips: +# 1. ND / row-major +# 2. DN / col-major +# These are the smallest data-movement cases needed to validate that PTODSL can +# drive tload/tstore on A5 without the tilelang_st harness. 
+ +from pathlib import Path +import sys + +import numpy as np +from mlir.ir import Attribute, InsertionPoint, Location, Module, StringAttr + +if __package__ in {None, ""}: + sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) + +from common import auto_main, golden_output_case +from ptodsl import pto +from ptodsl._kernel_compilation import KernelCompiler +from ptodsl._kernel_signature import parse_jit_kernel_signature +from ptodsl._tracing import KernelModuleSpec, ModuleStyle + + +CASE_SPECS = [ + { + "case_name": "nd_f32_16x64", + "kernel_name": "tload_store_nd_f32_16x64", + "shape": (16, 64), + "view_strides": None, + "tile_kwargs": {}, + }, + { + "case_name": "dn_f32_16x64", + "kernel_name": "tload_store_dn_f32_16x64", + "shape": (16, 64), + "view_strides": None, + "tile_kwargs": {"blayout": "ColMajor"}, + }, +] + +TILE_ADDR = 0 + + +class _FlatKernelHandle: + def __init__(self, compiler): + self._compiler = compiler + + def compile(self, **constexpr_bindings): + compiled = self._compiler.compile(**constexpr_bindings) + _attach_flat_vpto_attrs(compiled.build(), self._compiler._module_spec) + return compiled + + +def _attach_flat_vpto_attrs(module, spec): + """Test-local flat containers must carry PTOAS-facing VPTO metadata.""" + with module.context: + module.operation.attributes["pto.backend"] = StringAttr.get(spec.backend) + if spec.backend == "vpto" and spec.kernel_kind in {"cube", "vector"}: + module.operation.attributes["pto.kernel_kind"] = Attribute.parse( + f"#pto.kernel_kind<{spec.kernel_kind}>" + ) + + +def _flat_jit(*, name, target="a5", kernel_kind="vector"): + def decorator(fn): + compiler = KernelCompiler( + fn.__name__, + KernelModuleSpec( + function_name=name, + target_arch=target, + kernel_kind=kernel_kind, + backend="vpto", + entry=True, + mode="explicit", + insert_sync=None, + module_style=ModuleStyle.FLAT_AICORE, + source_file=__file__, + source_line=fn.__code__.co_firstlineno, + ), + parse_jit_kernel_signature(fn, entry=True), + 
fn, + ast_rewrite=True, + ) + return _FlatKernelHandle(compiler) + + return decorator + + +def _merge_flat_modules(*compiled_kernels): + first = compiled_kernels[0].build() + with first.context, Location.unknown(): + merged = Module.create() + for named_attr in first.operation.attributes: + merged.operation.attributes[named_attr.name] = named_attr.attr + with InsertionPoint(merged.body): + for compiled in compiled_kernels: + module = compiled.build() + for op in module.body.operations: + op.operation.clone() + merged.operation.verify() + return merged + + +def _roundtrip_body(src_ptr, dst_ptr, *, rows, cols, view_strides=None, tile_kwargs=None): + total = rows * cols + if view_strides is None: + view_strides = [total, total, total, cols, 1] + + src_view = pto.make_tensor_view( + src_ptr, + shape=[1, 1, 1, rows, cols], + strides=view_strides, + ) + dst_view = pto.make_tensor_view( + dst_ptr, + shape=[1, 1, 1, rows, cols], + strides=view_strides, + ) + + src_part = pto.partition_view(src_view, offsets=[0, 0, 0, 0, 0], sizes=[1, 1, 1, rows, cols]) + dst_part = pto.partition_view(dst_view, offsets=[0, 0, 0, 0, 0], sizes=[1, 1, 1, rows, cols]) + + tile = pto.alloc_tile( + shape=[rows, cols], + dtype=pto.f32, + addr=TILE_ADDR, + **(tile_kwargs or {}), + ) + + pto.tile.load(src_part, tile) + pto.tile.store(tile, dst_part) + + +_tload_store_kernels = {} +for _spec in CASE_SPECS: + _rows, _cols = _spec["shape"] + _view_strides = _spec["view_strides"] + if _view_strides is None and _spec["tile_kwargs"].get("blayout") == "ColMajor": + _view_strides = [_rows * _cols, _rows * _cols, _rows * _cols, 1, _rows] + _tile_kwargs = dict(_spec["tile_kwargs"]) + _kernel_name = _spec["kernel_name"] + _case_name = _spec["case_name"] + + def _make(rows=_rows, cols=_cols, view_strides=_view_strides, tile_kwargs=_tile_kwargs, kernel_name=_kernel_name): + @_flat_jit( + name=kernel_name, + kernel_kind="vector", + target="a5", + ) + def _kernel( + src_ptr: pto.ptr(pto.f32, "gm"), + dst_ptr: 
pto.ptr(pto.f32, "gm"), + ): + _roundtrip_body( + src_ptr, + dst_ptr, + rows=rows, + cols=cols, + view_strides=view_strides, + tile_kwargs=tile_kwargs, + ) + + return _kernel + + _tload_store_kernels[_case_name] = _make() + + +def _make_input(name, shape): + import zlib + + np.random.seed(zlib.crc32(name.encode("utf-8")) & 0xFFFFFFFF) + return np.random.randint(1, 32, size=shape).astype(np.float32) + + +def _make_expected(src): + return np.asarray(src, dtype=np.float32).copy() + + +CASES = [] +for _spec in CASE_SPECS: + _case_name = _spec["case_name"] + _shape = _spec["shape"] + CASES.append( + golden_output_case( + "tload_store_" + _case_name, + _tload_store_kernels[_case_name], + inputs=lambda _case_name=_case_name, _shape=_shape: [_make_input(_case_name, _shape)], + expected=_make_expected, + rtol=1e-6, + atol=1e-6, + ) + ) + + +EMIT_MLIR_FN = lambda: _merge_flat_modules(*[kernel.compile() for kernel in _tload_store_kernels.values()]) + + +auto_main(globals()) diff --git a/test/dsl-st/npu_a5/tmatmul.py b/test/dsl-st/npu_a5/tmatmul.py new file mode 100644 index 0000000000..df2c5fc06e --- /dev/null +++ b/test/dsl-st/npu_a5/tmatmul.py @@ -0,0 +1,199 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# Minimal PTODSL cube/tmatmul pilot for A5. 
+# Goal: validate plain cube tile.matmul lowering/runtime first, without mixing
+# MX-specific scale/bias handling or @pto.cube helper boundaries.
+
+from pathlib import Path
+import sys
+
+import numpy as np
+
+if __package__ in {None, ""}:
+    sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
+
+from common import auto_main, golden_output_case
+from ptodsl import pto
+
+
+# Problem size: A is (M, K), B is (K, N), C is (M, N); see _make_inputs.
+M = 16
+K = 32
+N = 64
+# Bytes per f32 element; used to express the GM row pitch passed as src_layout.
+ELEM_BYTES = 4
+
+# Explicit tile base addresses. The L1 pair backs the MAT tiles; the L0A/L0B/
+# L0C values are passed to the LEFT/RIGHT/ACC tiles respectively.
+# NOTE(review): presumably each memory_space has its own address range, which
+# is why three tiles can all use address 0 -- confirm.
+L1_A_ADDR = 0
+L1_B_ADDR = 4096
+L0A_ADDR = 0
+L0B_ADDR = 0
+L0C_ADDR = 0
+
+
+# NOTE(review): this @pto.cube helper is not called anywhere in this file;
+# _tmatmul_kernel issues a similar MTE1 -> M -> FIX sequence inline and writes
+# the result with mte_l0c_gm instead of mte_l0c_ub. The header comment above
+# says the pilot avoids @pto.cube helper boundaries, so this looks like
+# leftover code -- confirm whether it should be removed or exercised by a case.
+#
+# Reads the matmul extents from the tiles' valid_shape (m, k from a_mat; n from
+# b_mat), moves both operands from L1 into L0A/L0B, runs tile.matmul into
+# c_acc, then copies c_acc out to o_tile.
+@pto.cube
+def cube_matmul_tile(
+    a_mat: pto.Tile,
+    b_mat: pto.Tile,
+    o_tile: pto.Tile,
+    a_l0a: pto.Tile,
+    b_l0b: pto.Tile,
+    c_acc: pto.Tile,
+):
+    m = a_mat.valid_shape[0]
+    k = a_mat.valid_shape[1]
+    n = b_mat.valid_shape[1]
+
+    pto.mte_l1_l0a(a_mat.as_ptr(), a_l0a.as_ptr(), m, k)
+    pto.mte_l1_l0b(b_mat.as_ptr(), b_l0b.as_ptr(), k, n, transpose=True)
+    # Matmul must not start until both L1 -> L0 moves have completed.
+    pto.set_flag(pto.Pipe.MTE1, pto.Pipe.M, event_id=1)
+    pto.wait_flag(pto.Pipe.MTE1, pto.Pipe.M, event_id=1)
+    pto.tile.matmul(a_l0a, b_l0b, c_acc)
+    # The accumulator copy-out must not start until the matmul has completed.
+    pto.set_flag(pto.Pipe.M, pto.Pipe.FIX, event_id=2)
+    pto.wait_flag(pto.Pipe.M, pto.Pipe.FIX, event_id=2)
+    pto.mte_l0c_ub(
+        c_acc.as_ptr(),
+        o_tile.as_ptr(),
+        m,
+        n,
+        n,
+        n,
+    )
+
+
+# Entry kernel for the single tmatmul case: C = A @ B for the static M/K/N
+# above, with synchronization written out by hand (mode="explicit",
+# insert_sync=False).
+@pto.jit(
+    name="tmatmul_f32_16x32x64",
+    kernel_kind="cube",
+    target="a5",
+    mode="explicit",
+    insert_sync=False,
+)
+def _tmatmul_kernel(
+    a_ptr: pto.ptr(pto.f32, "gm"),
+    b_ptr: pto.ptr(pto.f32, "gm"),
+    c_ptr: pto.ptr(pto.f32, "gm"),
+):
+    # NOTE(review): a_mat and b_mat are not referenced again in this kernel;
+    # the L1 buffers are addressed through the castptr values created further
+    # down. Confirm whether these two alloc_tile calls are still required.
+    a_mat = pto.alloc_tile(
+        shape=[M, K],
+        dtype=pto.f32,
+        memory_space=pto.MemorySpace.MAT,
+        addr=L1_A_ADDR,
+        valid_shape=[M, K],
+        blayout="ColMajor",
+        slayout="RowMajor",
+    )
+    b_mat = pto.alloc_tile(
+        shape=[K, N],
+        dtype=pto.f32,
+        memory_space=pto.MemorySpace.MAT,
+        addr=L1_B_ADDR,
+        valid_shape=[K, N],
+        blayout="ColMajor",
+        slayout="RowMajor",
+    )
+    a_l0a = pto.alloc_tile(
+        shape=[M, K],
+        dtype=pto.f32,
+        memory_space=pto.MemorySpace.LEFT,
+        addr=L0A_ADDR,
+        valid_shape=[M, K],
+        blayout="ColMajor",
+
slayout="RowMajor", + ) + b_l0b = pto.alloc_tile( + shape=[K, N], + dtype=pto.f32, + memory_space=pto.MemorySpace.RIGHT, + addr=L0B_ADDR, + valid_shape=[K, N], + blayout="RowMajor", + slayout="ColMajor", + ) + c_acc = pto.alloc_tile( + shape=[M, N], + dtype=pto.f32, + memory_space=pto.MemorySpace.ACC, + addr=L0C_ADDR, + valid_shape=[M, N], + blayout="ColMajor", + slayout="RowMajor", + fractal_size=1024, + ) + + a_l1_ptr = pto.castptr(pto.ui64(L1_A_ADDR), pto.ptr(pto.f32, "mat")) + b_l1_ptr = pto.castptr(pto.ui64(L1_B_ADDR), pto.ptr(pto.f32, "mat")) + + pto.mte_gm_l1_frac( + a_ptr, + a_l1_ptr, + pto.FractalMode.ND2NZ, + shape=(M, K), + src_layout=(K * ELEM_BYTES,), + dst_group=(1, 1, M, 0), + ctrl=(0, False), + ) + pto.set_flag(pto.Pipe.MTE2, pto.Pipe.MTE1, event_id=0) + pto.wait_flag(pto.Pipe.MTE2, pto.Pipe.MTE1, event_id=0) + pto.mte_l1_l0a(a_l1_ptr, a_l0a.as_ptr(), M, K) + + pto.mte_gm_l1_frac( + b_ptr, + b_l1_ptr, + pto.FractalMode.ND2NZ, + shape=(K, N), + src_layout=(N * ELEM_BYTES,), + dst_group=(1, 1, K, 0), + ctrl=(0, False), + ) + pto.set_flag(pto.Pipe.MTE2, pto.Pipe.MTE1, event_id=1) + pto.wait_flag(pto.Pipe.MTE2, pto.Pipe.MTE1, event_id=1) + pto.mte_l1_l0b(b_l1_ptr, b_l0b.as_ptr(), K, N, transpose=True) + + pto.set_flag(pto.Pipe.MTE1, pto.Pipe.M, event_id=0) + pto.wait_flag(pto.Pipe.MTE1, pto.Pipe.M, event_id=0) + pto.tile.matmul(a_l0a, b_l0b, c_acc) + + pto.set_flag(pto.Pipe.M, pto.Pipe.FIX, event_id=1) + pto.wait_flag(pto.Pipe.M, pto.Pipe.FIX, event_id=1) + pto.mte_l0c_gm( + c_acc.as_ptr(), + c_ptr, + M, + N, + M, + N, + 0, + 0, + layout="nz2nd", + ) + pto.pipe_barrier(pto.Pipe.ALL) + + +def _make_inputs(): + rng = np.random.default_rng(0x7A7A7A71) + a = rng.uniform(-2.0, 2.0, size=(M, K)).astype(np.float32) + b = rng.uniform(-2.0, 2.0, size=(K, N)).astype(np.float32) + return [a, b] + + +def _make_expected(a, b): + return (a @ b).astype(np.float32) + + +CASES = [ + golden_output_case( + "tmatmul_f32_16x32x64", + _tmatmul_kernel, + inputs=_make_inputs, 
+ expected=_make_expected, + rtol=1e-4, + atol=1e-4, + ), +] + + +auto_main(globals()) diff --git a/test/dsl-st/vmulscvt.py b/test/dsl-st/vmulscvt.py index 4eb6a66449..eaa236870e 100644 --- a/test/dsl-st/vmulscvt.py +++ b/test/dsl-st/vmulscvt.py @@ -15,11 +15,12 @@ - `pto.vmulscvt(..., part=EVEN)` - `pto.vbitcast(..., pto.ui32)` - `pto.vpack(..., LOWER)` -- UB materialization via `pto.vsts` +- UB materialization via `pto.vsts` with a `PAT_VL64` mask -The observable is the packed `u16` register image after the `vmulscvt + vpack` +The observable is the lower 64-lane payload produced by the `vmulscvt + vpack` sequence. That keeps the test close to the C++ authoring style without relying -on `vsstb.post`, which is not available on the current PTODSL surface yet. +on `vsstb.post`, which is not available on the current PTODSL surface yet, and +without asserting anything about lanes outside the authored packed payload. """ from pathlib import Path @@ -35,7 +36,7 @@ SRC_COLS = 64 -OUT_COLS = 128 +OUT_COLS = 64 SCALE = -0.5 @@ -87,7 +88,7 @@ def vmulscvt_pack_kernel( with pto.simd(): mask32 = pto.pset_b32(pto.MaskPattern.ALL) - mask16 = pto.pset_b16(pto.MaskPattern.ALL) + mask16 = pto.pset_b16(pto.MaskPattern.VL64) src = pto.vlds(src_tile[0, 0:]) packed_f16 = pto.vmulscvt( @@ -115,9 +116,7 @@ def make_inputs(): def make_expected(inp): scaled = (inp.astype(np.float32) * np.float32(SCALE)).astype(np.float16).reshape(-1) - packed = np.zeros((OUT_COLS,), dtype=np.uint16) - packed[:SRC_COLS] = scaled.view(np.uint16) - return packed.reshape(1, OUT_COLS) + return scaled.view(np.uint16).reshape(1, OUT_COLS) CASES = [ From 52ad5b5389d3870bb61a7138f687d59312e51df3 Mon Sep 17 00:00:00 2001 From: jimmychou <47636600+jimmychou0@users.noreply.github.com> Date: Wed, 1 Jul 2026 16:43:22 +0800 Subject: [PATCH 02/10] Fix PTODSL simulator CI environment handling --- .github/workflows/ci_sim.yml | 100 ++++++++++++++++++++++++++++++++++- scripts/sim_dsl.sh | 45 +++++++++++++++- 2 files changed, 
142 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci_sim.yml b/.github/workflows/ci_sim.yml index 20862b519e..e0381fc96c 100644 --- a/.github/workflows/ci_sim.yml +++ b/.github/workflows/ci_sim.yml @@ -538,7 +538,104 @@ jobs: set -euo pipefail mkdir -p "${TILELANG_DSL_WORKSPACE}" export LLVM_BUILD_DIR="${LLVM_DIR}" - export PYTHON_BIN="python3" + export MLIR_PYTHON_ROOT="${MLIR_PYTHONPATH}" + source "${ASCEND_HOME_PATH}/bin/setenv.bash" + + add_python_candidate() { + local candidate="$1" + local resolved + [[ -n "${candidate}" ]] || return 0 + if [[ "${candidate}" != */* ]]; then + resolved="$(command -v "${candidate}" 2>/dev/null || true)" + else + resolved="${candidate}" + fi + [[ -n "${resolved}" && -x "${resolved}" ]] || return 0 + resolved="$(readlink -f "${resolved}")" + case ":${PYTHON_CANDIDATES[*]}:" in + *":${resolved}:"*) return 0 ;; + esac + PYTHON_CANDIDATES+=("${resolved}") + } + + ptoas_python_site() { + "$1" - <<'PY' + import os + import sysconfig + + prefix = os.environ["PTO_INSTALL_DIR"] + print(sysconfig.get_path("purelib", vars={"base": prefix, "platbase": prefix})) + PY + } + + probe_ptodsl_python() { + local candidate="$1" + local site_path="$2" + PYTHONPATH="${GITHUB_WORKSPACE}/ptodsl:${PTO_INSTALL_DIR}:${site_path}:${MLIR_PYTHONPATH}:${GITHUB_WORKSPACE}/build/python:${PYTHONPATH:-}" \ + "${candidate}" - <<'PY' + import sys + import torch + import torch_npu # noqa: F401 + from ptodsl import pto # noqa: F401 + from mlir.dialects import pto as _pto # noqa: F401 + + print(sys.executable) + print("torch", torch.__version__) + print("torch_npu", getattr(torch_npu, "__version__", "unknown")) + PY + } + + PYTHON_CANDIDATES=() + if [[ -n "${PTO_DSL_ST_PYTHON_BIN:-}" ]]; then + add_python_candidate "${PTO_DSL_ST_PYTHON_BIN}" + else + add_python_candidate python3 + add_python_candidate /home/mouliangyu/miniconda3/bin/python3 + add_python_candidate /home/mouliangyu/miniconda3/bin/python + add_python_candidate 
/home/zhoujiaming/miniconda3/bin/python3 + add_python_candidate /home/zhoujiaming/miniconda3/bin/python + shopt -s nullglob + for candidate in \ + /home/*/miniconda3/envs/*/bin/python \ + /home/*/anaconda3/envs/*/bin/python \ + /opt/conda/envs/*/bin/python + do + add_python_candidate "${candidate}" + done + shopt -u nullglob + fi + + SELECTED_PYTHON="" + SELECTED_PYTHON_SITE="" + PROBE_LOG="${TILELANG_DSL_WORKSPACE}/ptodsl-python-probe.log" + : > "${PROBE_LOG}" + for candidate in "${PYTHON_CANDIDATES[@]}"; do + echo "Probing PTODSL DSL ST Python: ${candidate}" | tee -a "${PROBE_LOG}" + if ! site_path="$(ptoas_python_site "${candidate}" 2>> "${PROBE_LOG}")"; then + echo "PTODSL Python site-path probe failed for ${candidate}" | tee -a "${PROBE_LOG}" + continue + fi + if probe_ptodsl_python "${candidate}" "${site_path}" >> "${PROBE_LOG}" 2>&1; then + SELECTED_PYTHON="${candidate}" + SELECTED_PYTHON_SITE="${site_path}" + break + fi + echo "PTODSL Python probe failed for ${candidate}" | tee -a "${PROBE_LOG}" + done + + if [[ -z "${SELECTED_PYTHON}" ]]; then + cat "${PROBE_LOG}" >&2 + echo "ERROR: PTODSL DSL ST requires an existing Python runtime with torch and torch_npu." >&2 + echo "ERROR: this workflow intentionally does not install torch_npu on every run." >&2 + echo "ERROR: set PTO_DSL_ST_PYTHON_BIN to a compatible pre-installed interpreter." 
>&2 + exit 1 + fi + + cat "${PROBE_LOG}" + export PYTHON_BIN="${SELECTED_PYTHON}" + export PTO_PYTHON_BIN="${SELECTED_PYTHON}" + export PTOAS_PYTHON_SITE="${SELECTED_PYTHON_SITE}" + ASCEND_HOME_PATH="${ASCEND_HOME_PATH}" \ PTOAS_BIN="${PTOAS_BIN}" \ scripts/sim_dsl.sh test/dsl-st \ @@ -552,6 +649,7 @@ jobs: path: | ${{ env.TILELANG_DSL_WORKSPACE }}/run_ci.log ${{ env.TILELANG_DSL_WORKSPACE }}/ptodsl-dsl-st.log + ${{ env.TILELANG_DSL_WORKSPACE }}/ptodsl-python-probe.log if-no-files-found: warn - name: Run TileLang DSL unit tests diff --git a/scripts/sim_dsl.sh b/scripts/sim_dsl.sh index cd0bccbdf8..5632fb2e92 100755 --- a/scripts/sim_dsl.sh +++ b/scripts/sim_dsl.sh @@ -35,6 +35,8 @@ Environment: Keep the private staging directory after a successful sync. PTOAS_MSPROF_LOG_MODE=quiet|verbose Override the default simulator log rendering mode. + PYTHON_BIN Python executable used for the PTODSL example. + Defaults to python3. Examples: scripts/sim_dsl.sh ptodsl/examples/jit/tadd_launch.py @@ -177,12 +179,32 @@ ensure_private_dir "${PRIVATE_ROOT}" RUNTIME_OUTPUT_DIR="$(mktemp -d "${PRIVATE_ROOT}/${EXAMPLE_STEM}.XXXXXX")" chmod 700 "${RUNTIME_OUTPUT_DIR}" MSPROF_STDIO_LOG="${RUNTIME_OUTPUT_DIR}/msprof.stdout.log" +EXAMPLE_EXIT_CODE_FILE="${RUNTIME_OUTPUT_DIR}/example.exitcode" +EXAMPLE_LAUNCHER="${RUNTIME_OUTPUT_DIR}/run_example.sh" +PYTHON_BIN="${PTO_PYTHON_BIN:-${PYTHON_BIN:-python3}}" source "${ASCEND_HOME_PATH}/bin/setenv.bash" source "${REPO_ROOT}/scripts/ptoas_env.sh" export LD_LIBRARY_PATH="${SIM_LIB_DIR}:${LD_LIBRARY_PATH:-}" ulimit -n 65535 +if ! command -v "${PYTHON_BIN}" >/dev/null 2>&1; then + die "PYTHON_BIN is not executable or not found on PATH: ${PYTHON_BIN}" +fi + +cat > "${EXAMPLE_LAUNCHER}" <<'EOF' +#!/usr/bin/env bash +set +e +"${PTOAS_SIM_DSL_PYTHON_BIN}" "${PTOAS_SIM_DSL_EXAMPLE_PATH}" "$@" +status=$? 
+printf '%s\n' "${status}" > "${PTOAS_SIM_DSL_EXIT_CODE_FILE}" +exit "${status}" +EOF +chmod 700 "${EXAMPLE_LAUNCHER}" +export PTOAS_SIM_DSL_PYTHON_BIN="${PYTHON_BIN}" +export PTOAS_SIM_DSL_EXAMPLE_PATH="${EXAMPLE_PATH}" +export PTOAS_SIM_DSL_EXIT_CODE_FILE="${EXAMPLE_EXIT_CODE_FILE}" + # msprof rejects group/other-writable working directories, so always launch # from a private directory and use an absolute path for the example script. cd "${HOME}" @@ -196,11 +218,30 @@ set +e msprof op simulator \ --soc-version="${SOC_VERSION}" \ --output="${RUNTIME_OUTPUT_DIR}" \ - python3 "${EXAMPLE_PATH}" "${EXAMPLE_ARGS[@]}" \ + "${EXAMPLE_LAUNCHER}" "${EXAMPLE_ARGS[@]}" \ > "${MSPROF_STDIO_LOG}" 2>&1 -STATUS=$? +MSPROF_STATUS=$? set -e +EXAMPLE_STATUS=0 +if [[ -f "${EXAMPLE_EXIT_CODE_FILE}" ]]; then + EXAMPLE_STATUS="$(< "${EXAMPLE_EXIT_CODE_FILE}")" + if [[ ! "${EXAMPLE_STATUS}" =~ ^[0-9]+$ ]]; then + log "invalid example exit code recorded in ${EXAMPLE_EXIT_CODE_FILE}: ${EXAMPLE_STATUS}" + EXAMPLE_STATUS=1 + fi +else + log "example exit code file was not produced: ${EXAMPLE_EXIT_CODE_FILE}" + EXAMPLE_STATUS=1 +fi + +STATUS=0 +if [[ ${MSPROF_STATUS} -ne 0 ]]; then + STATUS=${MSPROF_STATUS} +elif [[ ${EXAMPLE_STATUS} -ne 0 ]]; then + STATUS=${EXAMPLE_STATUS} +fi + print_msprof_log "${MSPROF_STDIO_LOG}" "${MSPROF_LOG_MODE}" "${STATUS}" SYNC_STATUS=0 From 972204f3b340ec18a386e0f9d0d80e62b7e71285 Mon Sep 17 00:00:00 2001 From: jimmychou <47636600+jimmychou0@users.noreply.github.com> Date: Wed, 1 Jul 2026 18:01:48 +0800 Subject: [PATCH 03/10] Fix PTODSL DSL ST simulator CI --- .github/workflows/ci_sim.yml | 373 ++++++++++++++++++++++------------- 1 file changed, 239 insertions(+), 134 deletions(-) diff --git a/.github/workflows/ci_sim.yml b/.github/workflows/ci_sim.yml index e0381fc96c..a7530e1d1e 100644 --- a/.github/workflows/ci_sim.yml +++ b/.github/workflows/ci_sim.yml @@ -125,37 +125,33 @@ jobs: fetch-depth: 1 persist-credentials: false - - name: Resolve LLVM directories + - 
name: Resolve simulator environment shell: bash run: | set -euo pipefail - echo "LLVM_ROOT=${RUNNER_TOOL_CACHE}/llvm-project" >> "${GITHUB_ENV}" - echo "LLVM_DIR=${RUNNER_TOOL_CACHE}/llvm-project/llvm/build-assert" >> "${GITHUB_ENV}" - echo "MLIR_PYTHONPATH=${RUNNER_TOOL_CACHE}/llvm-project/llvm/build-assert/tools/mlir/python_packages/mlir_core" >> "${GITHUB_ENV}" - - name: Resolve LLVM cache key - id: llvm-cache-key - shell: bash - run: | - set -euo pipefail - # Resolve to a Git object that ls-remote can handle: either a tag - # (LLVM_TAG) or a branch head (LLVM_REF). Only one is expected. - ref="${LLVM_TAG:-${LLVM_REF}}" - sha="$(git ls-remote "${LLVM_REPO}" "${ref}" | awk '{print $1}')" - if [[ -z "${sha}" ]]; then - echo "ERROR: failed to resolve ${LLVM_REPO} ${ref}" >&2 + detect_ascend_home() { + for d in \ + "${ASCEND_HOME_PATH:-}" \ + /usr/local/Ascend/cann \ + /usr/local/Ascend/cann-* \ + /usr/local/Ascend/ascend-toolkit/latest + do + [[ -n "${d}" && -d "${d}" ]] || continue + printf '%s\n' "${d}" + return 0 + done + return 1 + } + + ASCEND_HOME_PATH_DETECTED="$(detect_ascend_home || true)" + if [[ -z "${ASCEND_HOME_PATH_DETECTED}" ]]; then + echo "ERROR: failed to detect ASCEND_HOME_PATH on self-hosted runner" >&2 exit 1 fi - echo "sha=${sha}" >> "${GITHUB_OUTPUT}" - echo "key=llvm-build-${sha}-assert-v1" >> "${GITHUB_OUTPUT}" - - name: Restore LLVM build cache - id: llvm-cache - continue-on-error: true - uses: actions/cache/restore@v4 - with: - path: ${{ env.LLVM_DIR }} - key: ${{ steps.llvm-cache-key.outputs.key }} + echo "ASCEND_HOME_PATH=${ASCEND_HOME_PATH_DETECTED}" >> "${GITHUB_ENV}" + echo "PTOAS_BIN=${GITHUB_WORKSPACE}/build/tools/ptoas/ptoas" >> "${GITHUB_ENV}" - name: Ensure runner dependencies shell: bash @@ -201,6 +197,200 @@ jobs: python3 -m pip install setuptools wheel 'pybind11<3' nanobind numpy ml-dtypes fi + - name: Resolve PTODSL Python + shell: bash + run: | + set -euo pipefail + source "${ASCEND_HOME_PATH}/bin/setenv.bash" + + # PTODSL 
imports MLIR Python bindings while its ST runtime imports
+          # torch_npu. Resolve the torch_npu interpreter first, then build
+          # LLVM/PTOAS Python bindings with the same Python ABI.
+
+          # Append one interpreter to PYTHON_CANDIDATES. A bare command name is
+          # looked up on PATH; empty, missing, or non-executable candidates are
+          # ignored. The canonical (readlink -f) path is stored so that the same
+          # interpreter reached through different spellings is listed only once.
+          add_python_candidate() {
+            local candidate="$1"
+            local resolved
+            local existing
+            [[ -n "${candidate}" ]] || return 0
+            if [[ "${candidate}" != */* ]]; then
+              resolved="$(command -v "${candidate}" 2>/dev/null || true)"
+            else
+              resolved="${candidate}"
+            fi
+            [[ -n "${resolved}" && -x "${resolved}" ]] || return 0
+            resolved="$(readlink -f "${resolved}")"
+            # Compare element by element. The previous check matched
+            # ":${PYTHON_CANDIDATES[*]}:" against *":${resolved}:"*, but "[*]"
+            # joins with a space (the first IFS character), so the pattern could
+            # only match while the list held exactly one entry and duplicates
+            # were probed again. The ${arr[@]+...} form also keeps `set -u`
+            # quiet on bash versions that treat an empty array as unset.
+            for existing in ${PYTHON_CANDIDATES[@]+"${PYTHON_CANDIDATES[@]}"}; do
+              if [[ "${existing}" == "${resolved}" ]]; then
+                return 0
+              fi
+            done
+            PYTHON_CANDIDATES+=("${resolved}")
+          }
+
+          # Exit 0 when interpreter "$1" can locate both torch and torch_npu.
+          # Uses find_spec so neither package is actually imported.
+          has_torch_npu_packages() {
+            "$1" - <<'PY'
+          import importlib.util
+
+          missing = [
+              name for name in ("torch", "torch_npu")
+              if importlib.util.find_spec(name) is None
+          ]
+          raise SystemExit(1 if missing else 0)
+          PY
+          }
+
+          # Import the full runtime stack with interpreter "$1" and print the
+          # versions found; a failing import makes the function return non-zero.
+          probe_ptodsl_runtime_python() {
+            TORCH_DEVICE_BACKEND_AUTOLOAD=0 "$1" - <<'PY'
+          import sys
+          import numpy
+          import pybind11
+          import yaml
+          import torch
+          import torch_npu  # noqa: F401
+
+          print(sys.executable)
+          print(f"python {sys.version_info.major}.{sys.version_info.minor}")
+          print("torch", torch.__version__)
+          print("torch_npu", getattr(torch_npu, "__version__", "unknown"))
+          print("numpy", numpy.__version__)
+          print("pybind11", pybind11.__version__)
+          PY
+          }
+
+          # Print a space-separated list of pip requirements that interpreter
+          # "$1" still lacks. pybind11 is reported as "pybind11<3" when it is
+          # absent, unparsable, or already at major version 3 or newer.
+          missing_ptodsl_python_deps() {
+            "$1" - <<'PY'
+          import importlib.util
+
+          deps = [
+              ("setuptools", "setuptools"),
+              ("wheel", "wheel"),
+              ("numpy", "numpy"),
+              ("ml_dtypes", "ml-dtypes"),
+              ("yaml", "PyYAML"),
+          ]
+
+          missing = [
+              requirement for module_name, requirement in deps
+              if importlib.util.find_spec(module_name) is None
+          ]
+
+          try:
+              import pybind11
+              version = tuple(int(part) for part in pybind11.__version__.split(".")[:1])
+              if version >= (3,):
+                  missing.append("pybind11<3")
+          except Exception:
+              missing.append("pybind11<3")
+
+          print(" ".join(missing))
+          PY
+          }
+
+          PYTHON_CANDIDATES=()
+          if [[ -n
"${PTO_DSL_ST_PYTHON_BIN:-}" ]]; then + add_python_candidate "${PTO_DSL_ST_PYTHON_BIN}" + else + add_python_candidate python3 + add_python_candidate /home/mouliangyu/miniconda3/bin/python3 + add_python_candidate /home/mouliangyu/miniconda3/bin/python + add_python_candidate /home/zhoujiaming/miniconda3/bin/python3 + add_python_candidate /home/zhoujiaming/miniconda3/bin/python + shopt -s nullglob + for candidate in \ + /home/*/miniconda3/envs/*/bin/python \ + /home/*/anaconda3/envs/*/bin/python \ + /opt/conda/envs/*/bin/python + do + add_python_candidate "${candidate}" + done + shopt -u nullglob + fi + + SELECTED_BASE_PYTHON="" + for candidate in "${PYTHON_CANDIDATES[@]}"; do + echo "Probing PTODSL base Python package presence: ${candidate}" + if has_torch_npu_packages "${candidate}"; then + SELECTED_BASE_PYTHON="${candidate}" + break + fi + done + + if [[ -z "${SELECTED_BASE_PYTHON}" ]]; then + echo "ERROR: PTODSL DSL ST requires an existing Python runtime with torch and torch_npu." >&2 + echo "ERROR: this workflow intentionally does not install torch or torch_npu on every run." >&2 + echo "ERROR: set PTO_DSL_ST_PYTHON_BIN to a compatible pre-installed interpreter." >&2 + exit 1 + fi + + PYTHON_ABI_TAG="$("${SELECTED_BASE_PYTHON}" - <<'PY' + import sys + print(f"py{sys.version_info.major}{sys.version_info.minor}") + PY + )" + BASE_HASH="$(printf '%s' "${SELECTED_BASE_PYTHON}" | sha256sum | cut -c1-12)" + PYTHON_TAG="${PYTHON_ABI_TAG}-${BASE_HASH}" + PTODSL_PYTHON_ROOT="${RUNNER_TOOL_CACHE}/ptodsl-python/${PYTHON_TAG}" + if [[ ! -x "${PTODSL_PYTHON_ROOT}/bin/python" ]]; then + rm -rf "${PTODSL_PYTHON_ROOT}" + "${SELECTED_BASE_PYTHON}" -m venv --system-site-packages "${PTODSL_PYTHON_ROOT}" + fi + + PTODSL_PYTHON="${PTODSL_PYTHON_ROOT}/bin/python" + if ! 
"${PTODSL_PYTHON}" -m pip --version >/dev/null 2>&1; then + "${PTODSL_PYTHON}" -m ensurepip --upgrade + fi + missing_deps="$(missing_ptodsl_python_deps "${PTODSL_PYTHON}")" + if [[ -n "${missing_deps}" ]]; then + "${PTODSL_PYTHON}" -m pip install ${missing_deps} + fi + + probe_ptodsl_runtime_python "${PTODSL_PYTHON}" + + PTODSL_PYTHON_SITE="$( + PTO_INSTALL_DIR="${PTO_INSTALL_DIR}" "${PTODSL_PYTHON}" - <<'PY' + import os + import sysconfig + + prefix = os.environ["PTO_INSTALL_DIR"] + print(sysconfig.get_path("purelib", vars={"base": prefix, "platbase": prefix})) + PY + )" + + echo "PTO_DSL_ST_BASE_PYTHON=${SELECTED_BASE_PYTHON}" >> "${GITHUB_ENV}" + echo "PTO_DSL_ST_PYTHON_BIN=${PTODSL_PYTHON}" >> "${GITHUB_ENV}" + echo "PTO_DSL_ST_PYTHON_TAG=${PYTHON_TAG}" >> "${GITHUB_ENV}" + echo "PTO_DSL_ST_PYTHON_SITE=${PTODSL_PYTHON_SITE}" >> "${GITHUB_ENV}" + + - name: Resolve LLVM directories + shell: bash + run: | + set -euo pipefail + python_tag="${PTO_DSL_ST_PYTHON_TAG:-default}" + echo "LLVM_ROOT=${RUNNER_TOOL_CACHE}/llvm-project" >> "${GITHUB_ENV}" + echo "LLVM_DIR=${RUNNER_TOOL_CACHE}/llvm-project/llvm/build-assert-${python_tag}" >> "${GITHUB_ENV}" + echo "MLIR_PYTHONPATH=${RUNNER_TOOL_CACHE}/llvm-project/llvm/build-assert-${python_tag}/tools/mlir/python_packages/mlir_core" >> "${GITHUB_ENV}" + + - name: Resolve LLVM cache key + id: llvm-cache-key + shell: bash + run: | + set -euo pipefail + # Resolve to a Git object that ls-remote can handle: either a tag + # (LLVM_TAG) or a branch head (LLVM_REF). Only one is expected. 
+ ref="${LLVM_TAG:-${LLVM_REF}}" + sha="$(git ls-remote "${LLVM_REPO}" "${ref}" | awk '{print $1}')" + if [[ -z "${sha}" ]]; then + echo "ERROR: failed to resolve ${LLVM_REPO} ${ref}" >&2 + exit 1 + fi + echo "sha=${sha}" >> "${GITHUB_OUTPUT}" + echo "key=llvm-build-${sha}-assert-${PTO_DSL_ST_PYTHON_TAG:-default}-v1" >> "${GITHUB_OUTPUT}" + + - name: Restore LLVM build cache + id: llvm-cache + continue-on-error: true + uses: actions/cache/restore@v4 + with: + path: ${{ env.LLVM_DIR }} + key: ${{ steps.llvm-cache-key.outputs.key }} + - name: Clean CI work dirs shell: bash run: | @@ -245,20 +435,20 @@ jobs: # Clean the build directory so that stale generated files from a # previous run (e.g. _smt_ops_gen.py left behind when the ref # changed) do not leak into the fresh build. - rm -rf llvm/build-assert - cmake -G Ninja -S llvm -B llvm/build-assert \ + rm -rf "${LLVM_DIR}" + cmake -G Ninja -S llvm -B "${LLVM_DIR}" \ -DLLVM_ENABLE_PROJECTS="mlir;clang" \ -DBUILD_SHARED_LIBS=ON \ -DLLVM_ENABLE_ASSERTIONS=ON \ -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ - -DPython3_EXECUTABLE=python3 \ - -DPython_EXECUTABLE=python3 \ - -Dpybind11_DIR="$(python3 -m pybind11 --cmakedir)" \ - -Dnanobind_DIR="$(python3 -m nanobind --cmake_dir)" \ + -DPython3_EXECUTABLE="${PTO_DSL_ST_PYTHON_BIN}" \ + -DPython_EXECUTABLE="${PTO_DSL_ST_PYTHON_BIN}" \ + -Dpybind11_DIR="$("${PTO_DSL_ST_PYTHON_BIN}" -m pybind11 --cmakedir)" \ + -Dnanobind_DIR="$("${PTO_DSL_ST_PYTHON_BIN}" -m nanobind --cmake_dir)" \ -DCMAKE_BUILD_TYPE=Release \ -DLLVM_TARGETS_TO_BUILD="host" - ninja -C llvm/build-assert + ninja -C "${LLVM_DIR}" - name: Save LLVM build cache if: steps.llvm-cache.outputs.cache-hit != 'true' @@ -279,35 +469,7 @@ jobs: # LLVM_BUILD_DIR is the env var read by the build backend (_ptoas_build_backend.py). LLVM_BUILD_DIR="${LLVM_DIR}" \ PTO_INSTALL_DIR="${PTO_INSTALL_DIR}" \ - python3 -m pip install . 
--no-build-isolation --no-deps --ignore-installed --prefix "${PTO_INSTALL_DIR}" - - - name: Resolve simulator environment - shell: bash - run: | - set -euo pipefail - - detect_ascend_home() { - for d in \ - "${ASCEND_HOME_PATH:-}" \ - /usr/local/Ascend/cann \ - /usr/local/Ascend/cann-* \ - /usr/local/Ascend/ascend-toolkit/latest - do - [[ -n "${d}" && -d "${d}" ]] || continue - printf '%s\n' "${d}" - return 0 - done - return 1 - } - - ASCEND_HOME_PATH_DETECTED="$(detect_ascend_home || true)" - if [[ -z "${ASCEND_HOME_PATH_DETECTED}" ]]; then - echo "ERROR: failed to detect ASCEND_HOME_PATH on self-hosted runner" >&2 - exit 1 - fi - - echo "ASCEND_HOME_PATH=${ASCEND_HOME_PATH_DETECTED}" >> "${GITHUB_ENV}" - echo "PTOAS_BIN=${GITHUB_WORKSPACE}/build/tools/ptoas/ptoas" >> "${GITHUB_ENV}" + "${PTO_DSL_ST_PYTHON_BIN}" -m pip install . --no-build-isolation --no-deps --ignore-installed --prefix "${PTO_INSTALL_DIR}" - name: Checkout PyPTO uses: actions/checkout@v4 @@ -539,41 +701,17 @@ jobs: mkdir -p "${TILELANG_DSL_WORKSPACE}" export LLVM_BUILD_DIR="${LLVM_DIR}" export MLIR_PYTHON_ROOT="${MLIR_PYTHONPATH}" + export PYTHON_BIN="${PTO_DSL_ST_PYTHON_BIN}" + export PTO_PYTHON_BIN="${PTO_DSL_ST_PYTHON_BIN}" + export PTOAS_PYTHON_SITE="${PTO_DSL_ST_PYTHON_SITE}" + export TORCH_DEVICE_BACKEND_AUTOLOAD=0 source "${ASCEND_HOME_PATH}/bin/setenv.bash" - add_python_candidate() { - local candidate="$1" - local resolved - [[ -n "${candidate}" ]] || return 0 - if [[ "${candidate}" != */* ]]; then - resolved="$(command -v "${candidate}" 2>/dev/null || true)" - else - resolved="${candidate}" - fi - [[ -n "${resolved}" && -x "${resolved}" ]] || return 0 - resolved="$(readlink -f "${resolved}")" - case ":${PYTHON_CANDIDATES[*]}:" in - *":${resolved}:"*) return 0 ;; - esac - PYTHON_CANDIDATES+=("${resolved}") - } - - ptoas_python_site() { - "$1" - <<'PY' - import os - import sysconfig - - prefix = os.environ["PTO_INSTALL_DIR"] - print(sysconfig.get_path("purelib", vars={"base": prefix, 
"platbase": prefix})) - PY - } - probe_ptodsl_python() { - local candidate="$1" - local site_path="$2" - PYTHONPATH="${GITHUB_WORKSPACE}/ptodsl:${PTO_INSTALL_DIR}:${site_path}:${MLIR_PYTHONPATH}:${GITHUB_WORKSPACE}/build/python:${PYTHONPATH:-}" \ - "${candidate}" - <<'PY' + PYTHONPATH="${GITHUB_WORKSPACE}/ptodsl:${PTO_INSTALL_DIR}:${PTOAS_PYTHON_SITE}:${MLIR_PYTHONPATH}:${GITHUB_WORKSPACE}/build/python:${PYTHONPATH:-}" \ + "${PTO_DSL_ST_PYTHON_BIN}" - <<'PY' import sys + import numpy import torch import torch_npu # noqa: F401 from ptodsl import pto # noqa: F401 @@ -582,65 +720,31 @@ jobs: print(sys.executable) print("torch", torch.__version__) print("torch_npu", getattr(torch_npu, "__version__", "unknown")) + print("numpy", numpy.__version__) PY } - PYTHON_CANDIDATES=() - if [[ -n "${PTO_DSL_ST_PYTHON_BIN:-}" ]]; then - add_python_candidate "${PTO_DSL_ST_PYTHON_BIN}" - else - add_python_candidate python3 - add_python_candidate /home/mouliangyu/miniconda3/bin/python3 - add_python_candidate /home/mouliangyu/miniconda3/bin/python - add_python_candidate /home/zhoujiaming/miniconda3/bin/python3 - add_python_candidate /home/zhoujiaming/miniconda3/bin/python - shopt -s nullglob - for candidate in \ - /home/*/miniconda3/envs/*/bin/python \ - /home/*/anaconda3/envs/*/bin/python \ - /opt/conda/envs/*/bin/python - do - add_python_candidate "${candidate}" - done - shopt -u nullglob - fi - - SELECTED_PYTHON="" - SELECTED_PYTHON_SITE="" PROBE_LOG="${TILELANG_DSL_WORKSPACE}/ptodsl-python-probe.log" : > "${PROBE_LOG}" - for candidate in "${PYTHON_CANDIDATES[@]}"; do - echo "Probing PTODSL DSL ST Python: ${candidate}" | tee -a "${PROBE_LOG}" - if ! 
site_path="$(ptoas_python_site "${candidate}" 2>> "${PROBE_LOG}")"; then - echo "PTODSL Python site-path probe failed for ${candidate}" | tee -a "${PROBE_LOG}" - continue - fi - if probe_ptodsl_python "${candidate}" "${site_path}" >> "${PROBE_LOG}" 2>&1; then - SELECTED_PYTHON="${candidate}" - SELECTED_PYTHON_SITE="${site_path}" - break - fi - echo "PTODSL Python probe failed for ${candidate}" | tee -a "${PROBE_LOG}" - done - - if [[ -z "${SELECTED_PYTHON}" ]]; then + echo "Probing PTODSL DSL ST Python: ${PTO_DSL_ST_PYTHON_BIN}" | tee -a "${PROBE_LOG}" + if ! probe_ptodsl_python >> "${PROBE_LOG}" 2>&1; then cat "${PROBE_LOG}" >&2 - echo "ERROR: PTODSL DSL ST requires an existing Python runtime with torch and torch_npu." >&2 - echo "ERROR: this workflow intentionally does not install torch_npu on every run." >&2 - echo "ERROR: set PTO_DSL_ST_PYTHON_BIN to a compatible pre-installed interpreter." >&2 + echo "ERROR: selected PTODSL Python cannot import torch, torch_npu, ptodsl, and MLIR PTO bindings." 
>&2 exit 1 fi cat "${PROBE_LOG}" - export PYTHON_BIN="${SELECTED_PYTHON}" - export PTO_PYTHON_BIN="${SELECTED_PYTHON}" - export PTOAS_PYTHON_SITE="${SELECTED_PYTHON_SITE}" ASCEND_HOME_PATH="${ASCEND_HOME_PATH}" \ PTOAS_BIN="${PTOAS_BIN}" \ scripts/sim_dsl.sh test/dsl-st \ 2>&1 | tee "${TILELANG_DSL_WORKSPACE}/ptodsl-dsl-st.log" + ASCEND_HOME_PATH="${ASCEND_HOME_PATH}" \ + PTOAS_BIN="${PTOAS_BIN}" \ + scripts/sim_dsl.sh test/dsl-st/npu_a5 \ + 2>&1 | tee "${TILELANG_DSL_WORKSPACE}/ptodsl-dsl-st-npu-a5.log" + - name: Upload TileLang DSL logs if: always() uses: actions/upload-artifact@v4 @@ -649,6 +753,7 @@ jobs: path: | ${{ env.TILELANG_DSL_WORKSPACE }}/run_ci.log ${{ env.TILELANG_DSL_WORKSPACE }}/ptodsl-dsl-st.log + ${{ env.TILELANG_DSL_WORKSPACE }}/ptodsl-dsl-st-npu-a5.log ${{ env.TILELANG_DSL_WORKSPACE }}/ptodsl-python-probe.log if-no-files-found: warn From adf828258c897fbe25e54de05a1a2905526db0cd Mon Sep 17 00:00:00 2001 From: jimmychou <47636600+jimmychou0@users.noreply.github.com> Date: Thu, 2 Jul 2026 08:58:59 +0800 Subject: [PATCH 04/10] Isolate PTODSL simulator CI toolchain --- .github/workflows/ci_sim.yml | 396 ++++++++++++++++++++--------------- 1 file changed, 225 insertions(+), 171 deletions(-) diff --git a/.github/workflows/ci_sim.yml b/.github/workflows/ci_sim.yml index a7530e1d1e..8babd6ab24 100644 --- a/.github/workflows/ci_sim.yml +++ b/.github/workflows/ci_sim.yml @@ -197,175 +197,16 @@ jobs: python3 -m pip install setuptools wheel 'pybind11<3' nanobind numpy ml-dtypes fi - - name: Resolve PTODSL Python - shell: bash - run: | - set -euo pipefail - source "${ASCEND_HOME_PATH}/bin/setenv.bash" - - # PTODSL imports MLIR Python bindings while its ST runtime imports - # torch_npu. Resolve the torch_npu interpreter first, then build - # LLVM/PTOAS Python bindings with the same Python ABI. 
- - add_python_candidate() { - local candidate="$1" - local resolved - [[ -n "${candidate}" ]] || return 0 - if [[ "${candidate}" != */* ]]; then - resolved="$(command -v "${candidate}" 2>/dev/null || true)" - else - resolved="${candidate}" - fi - [[ -n "${resolved}" && -x "${resolved}" ]] || return 0 - resolved="$(readlink -f "${resolved}")" - case ":${PYTHON_CANDIDATES[*]}:" in - *":${resolved}:"*) return 0 ;; - esac - PYTHON_CANDIDATES+=("${resolved}") - } - - has_torch_npu_packages() { - "$1" - <<'PY' - import importlib.util - - missing = [ - name for name in ("torch", "torch_npu") - if importlib.util.find_spec(name) is None - ] - raise SystemExit(1 if missing else 0) - PY - } - - probe_ptodsl_runtime_python() { - TORCH_DEVICE_BACKEND_AUTOLOAD=0 "$1" - <<'PY' - import sys - import numpy - import pybind11 - import yaml - import torch - import torch_npu # noqa: F401 - - print(sys.executable) - print(f"python {sys.version_info.major}.{sys.version_info.minor}") - print("torch", torch.__version__) - print("torch_npu", getattr(torch_npu, "__version__", "unknown")) - print("numpy", numpy.__version__) - print("pybind11", pybind11.__version__) - PY - } - - missing_ptodsl_python_deps() { - "$1" - <<'PY' - import importlib.util - - deps = [ - ("setuptools", "setuptools"), - ("wheel", "wheel"), - ("numpy", "numpy"), - ("ml_dtypes", "ml-dtypes"), - ("yaml", "PyYAML"), - ] - - missing = [ - requirement for module_name, requirement in deps - if importlib.util.find_spec(module_name) is None - ] - - try: - import pybind11 - version = tuple(int(part) for part in pybind11.__version__.split(".")[:1]) - if version >= (3,): - missing.append("pybind11<3") - except Exception: - missing.append("pybind11<3") - - print(" ".join(missing)) - PY - } - - PYTHON_CANDIDATES=() - if [[ -n "${PTO_DSL_ST_PYTHON_BIN:-}" ]]; then - add_python_candidate "${PTO_DSL_ST_PYTHON_BIN}" - else - add_python_candidate python3 - add_python_candidate /home/mouliangyu/miniconda3/bin/python3 - 
add_python_candidate /home/mouliangyu/miniconda3/bin/python - add_python_candidate /home/zhoujiaming/miniconda3/bin/python3 - add_python_candidate /home/zhoujiaming/miniconda3/bin/python - shopt -s nullglob - for candidate in \ - /home/*/miniconda3/envs/*/bin/python \ - /home/*/anaconda3/envs/*/bin/python \ - /opt/conda/envs/*/bin/python - do - add_python_candidate "${candidate}" - done - shopt -u nullglob - fi - - SELECTED_BASE_PYTHON="" - for candidate in "${PYTHON_CANDIDATES[@]}"; do - echo "Probing PTODSL base Python package presence: ${candidate}" - if has_torch_npu_packages "${candidate}"; then - SELECTED_BASE_PYTHON="${candidate}" - break - fi - done - - if [[ -z "${SELECTED_BASE_PYTHON}" ]]; then - echo "ERROR: PTODSL DSL ST requires an existing Python runtime with torch and torch_npu." >&2 - echo "ERROR: this workflow intentionally does not install torch or torch_npu on every run." >&2 - echo "ERROR: set PTO_DSL_ST_PYTHON_BIN to a compatible pre-installed interpreter." >&2 - exit 1 - fi - - PYTHON_ABI_TAG="$("${SELECTED_BASE_PYTHON}" - <<'PY' - import sys - print(f"py{sys.version_info.major}{sys.version_info.minor}") - PY - )" - BASE_HASH="$(printf '%s' "${SELECTED_BASE_PYTHON}" | sha256sum | cut -c1-12)" - PYTHON_TAG="${PYTHON_ABI_TAG}-${BASE_HASH}" - PTODSL_PYTHON_ROOT="${RUNNER_TOOL_CACHE}/ptodsl-python/${PYTHON_TAG}" - if [[ ! -x "${PTODSL_PYTHON_ROOT}/bin/python" ]]; then - rm -rf "${PTODSL_PYTHON_ROOT}" - "${SELECTED_BASE_PYTHON}" -m venv --system-site-packages "${PTODSL_PYTHON_ROOT}" - fi - - PTODSL_PYTHON="${PTODSL_PYTHON_ROOT}/bin/python" - if ! 
"${PTODSL_PYTHON}" -m pip --version >/dev/null 2>&1; then - "${PTODSL_PYTHON}" -m ensurepip --upgrade - fi - missing_deps="$(missing_ptodsl_python_deps "${PTODSL_PYTHON}")" - if [[ -n "${missing_deps}" ]]; then - "${PTODSL_PYTHON}" -m pip install ${missing_deps} - fi - - probe_ptodsl_runtime_python "${PTODSL_PYTHON}" - - PTODSL_PYTHON_SITE="$( - PTO_INSTALL_DIR="${PTO_INSTALL_DIR}" "${PTODSL_PYTHON}" - <<'PY' - import os - import sysconfig - - prefix = os.environ["PTO_INSTALL_DIR"] - print(sysconfig.get_path("purelib", vars={"base": prefix, "platbase": prefix})) - PY - )" - - echo "PTO_DSL_ST_BASE_PYTHON=${SELECTED_BASE_PYTHON}" >> "${GITHUB_ENV}" - echo "PTO_DSL_ST_PYTHON_BIN=${PTODSL_PYTHON}" >> "${GITHUB_ENV}" - echo "PTO_DSL_ST_PYTHON_TAG=${PYTHON_TAG}" >> "${GITHUB_ENV}" - echo "PTO_DSL_ST_PYTHON_SITE=${PTODSL_PYTHON_SITE}" >> "${GITHUB_ENV}" - - name: Resolve LLVM directories shell: bash run: | set -euo pipefail - python_tag="${PTO_DSL_ST_PYTHON_TAG:-default}" echo "LLVM_ROOT=${RUNNER_TOOL_CACHE}/llvm-project" >> "${GITHUB_ENV}" - echo "LLVM_DIR=${RUNNER_TOOL_CACHE}/llvm-project/llvm/build-assert-${python_tag}" >> "${GITHUB_ENV}" - echo "MLIR_PYTHONPATH=${RUNNER_TOOL_CACHE}/llvm-project/llvm/build-assert-${python_tag}/tools/mlir/python_packages/mlir_core" >> "${GITHUB_ENV}" + echo "LLVM_DIR=${RUNNER_TOOL_CACHE}/llvm-project/llvm/build-assert" >> "${GITHUB_ENV}" + echo "MLIR_PYTHONPATH=${RUNNER_TOOL_CACHE}/llvm-project/llvm/build-assert/tools/mlir/python_packages/mlir_core" >> "${GITHUB_ENV}" + echo "PTO_DSL_ST_BUILD_DIR=${GITHUB_WORKSPACE}/build-ptodsl" >> "${GITHUB_ENV}" + echo "PTO_DSL_ST_INSTALL_DIR=${GITHUB_WORKSPACE}/install-ptodsl" >> "${GITHUB_ENV}" + echo "PTO_DSL_ST_PTOAS_BIN=${GITHUB_WORKSPACE}/build-ptodsl/tools/ptoas/ptoas" >> "${GITHUB_ENV}" - name: Resolve LLVM cache key id: llvm-cache-key @@ -381,7 +222,7 @@ jobs: exit 1 fi echo "sha=${sha}" >> "${GITHUB_OUTPUT}" - echo "key=llvm-build-${sha}-assert-${PTO_DSL_ST_PYTHON_TAG:-default}-v1" >> 
"${GITHUB_OUTPUT}" + echo "key=llvm-build-${sha}-assert-v1" >> "${GITHUB_OUTPUT}" - name: Restore LLVM build cache id: llvm-cache @@ -396,7 +237,9 @@ jobs: run: | set -euo pipefail rm -rf "${GITHUB_WORKSPACE}/build" + rm -rf "${GITHUB_WORKSPACE}/build-ptodsl" rm -rf "${PTO_INSTALL_DIR}" + rm -rf "${PTO_DSL_ST_INSTALL_DIR}" rm -rf "${VPTO_SIM_WORKSPACE}" rm -rf "${TILELANG_DSL_WORKSPACE}" rm -rf "${PYPTO_WORKSPACE}" @@ -469,7 +312,7 @@ jobs: # LLVM_BUILD_DIR is the env var read by the build backend (_ptoas_build_backend.py). LLVM_BUILD_DIR="${LLVM_DIR}" \ PTO_INSTALL_DIR="${PTO_INSTALL_DIR}" \ - "${PTO_DSL_ST_PYTHON_BIN}" -m pip install . --no-build-isolation --no-deps --ignore-installed --prefix "${PTO_INSTALL_DIR}" + python3 -m pip install . --no-build-isolation --no-deps --ignore-installed --prefix "${PTO_INSTALL_DIR}" - name: Checkout PyPTO uses: actions/checkout@v4 @@ -694,13 +537,224 @@ jobs: 2>&1 | tee "${TILELANG_DSL_WORKSPACE}/run_ci.log" fi + - name: Resolve PTODSL Python + shell: bash + run: | + set -euo pipefail + source "${ASCEND_HOME_PATH}/bin/setenv.bash" + + add_python_candidate() { + local candidate="$1" + local resolved + [[ -n "${candidate}" ]] || return 0 + if [[ "${candidate}" != */* ]]; then + resolved="$(command -v "${candidate}" 2>/dev/null || true)" + else + resolved="${candidate}" + fi + [[ -n "${resolved}" && -x "${resolved}" ]] || return 0 + resolved="$(readlink -f "${resolved}")" + case ":${PYTHON_CANDIDATES[*]}:" in + *":${resolved}:"*) return 0 ;; + esac + PYTHON_CANDIDATES+=("${resolved}") + } + + has_torch_npu_packages() { + "$1" - <<'PY' + import importlib.util + + missing = [ + name for name in ("torch", "torch_npu") + if importlib.util.find_spec(name) is None + ] + raise SystemExit(1 if missing else 0) + PY + } + + probe_ptodsl_runtime_python() { + TORCH_DEVICE_BACKEND_AUTOLOAD=0 "$1" - <<'PY' + import sys + import numpy + import pybind11 + import yaml + import torch + import torch_npu # noqa: F401 + + print(sys.executable) + 
print(f"python {sys.version_info.major}.{sys.version_info.minor}") + print("torch", torch.__version__) + print("torch_npu", getattr(torch_npu, "__version__", "unknown")) + print("numpy", numpy.__version__) + print("pybind11", pybind11.__version__) + PY + } + + missing_ptodsl_python_deps() { + "$1" - <<'PY' + import importlib.util + + deps = [ + ("setuptools", "setuptools"), + ("wheel", "wheel"), + ("numpy", "numpy"), + ("ml_dtypes", "ml-dtypes"), + ("yaml", "PyYAML"), + ] + + missing = [ + requirement for module_name, requirement in deps + if importlib.util.find_spec(module_name) is None + ] + + try: + import pybind11 + version = tuple(int(part) for part in pybind11.__version__.split(".")[:1]) + if version >= (3,): + missing.append("pybind11<3") + except Exception: + missing.append("pybind11<3") + + print(" ".join(missing)) + PY + } + + PYTHON_CANDIDATES=() + if [[ -n "${PTO_DSL_ST_PYTHON_BIN:-}" ]]; then + add_python_candidate "${PTO_DSL_ST_PYTHON_BIN}" + else + add_python_candidate python3 + add_python_candidate /home/mouliangyu/miniconda3/bin/python3 + add_python_candidate /home/mouliangyu/miniconda3/bin/python + add_python_candidate /home/zhoujiaming/miniconda3/bin/python3 + add_python_candidate /home/zhoujiaming/miniconda3/bin/python + shopt -s nullglob + for candidate in \ + /home/*/miniconda3/envs/*/bin/python \ + /home/*/anaconda3/envs/*/bin/python \ + /opt/conda/envs/*/bin/python + do + add_python_candidate "${candidate}" + done + shopt -u nullglob + fi + + SELECTED_BASE_PYTHON="" + for candidate in "${PYTHON_CANDIDATES[@]}"; do + echo "Probing PTODSL base Python package presence: ${candidate}" + if has_torch_npu_packages "${candidate}"; then + SELECTED_BASE_PYTHON="${candidate}" + break + fi + done + + if [[ -z "${SELECTED_BASE_PYTHON}" ]]; then + echo "ERROR: PTODSL DSL ST requires an existing Python runtime with torch and torch_npu." >&2 + echo "ERROR: this workflow intentionally does not install torch or torch_npu on every run." 
>&2 + echo "ERROR: set PTO_DSL_ST_PYTHON_BIN to a compatible pre-installed interpreter." >&2 + exit 1 + fi + + PYTHON_ABI_TAG="$("${SELECTED_BASE_PYTHON}" - <<'PY' + import sys + print(f"py{sys.version_info.major}{sys.version_info.minor}") + PY + )" + BASE_HASH="$(printf '%s' "${SELECTED_BASE_PYTHON}" | sha256sum | cut -c1-12)" + PYTHON_TAG="${PYTHON_ABI_TAG}-${BASE_HASH}" + PTODSL_PYTHON_ROOT="${RUNNER_TOOL_CACHE}/ptodsl-python/${PYTHON_TAG}" + if [[ ! -x "${PTODSL_PYTHON_ROOT}/bin/python" ]]; then + rm -rf "${PTODSL_PYTHON_ROOT}" + "${SELECTED_BASE_PYTHON}" -m venv --system-site-packages "${PTODSL_PYTHON_ROOT}" + fi + + PTODSL_PYTHON="${PTODSL_PYTHON_ROOT}/bin/python" + if ! "${PTODSL_PYTHON}" -m pip --version >/dev/null 2>&1; then + "${PTODSL_PYTHON}" -m ensurepip --upgrade + fi + missing_deps="$(missing_ptodsl_python_deps "${PTODSL_PYTHON}")" + if [[ -n "${missing_deps}" ]]; then + "${PTODSL_PYTHON}" -m pip install ${missing_deps} + fi + + probe_ptodsl_runtime_python "${PTODSL_PYTHON}" + + PTO_DSL_ST_LLVM_DIR="${RUNNER_TOOL_CACHE}/llvm-project/llvm/build-assert-${PYTHON_TAG}" + PTO_DSL_ST_MLIR_PYTHONPATH="${PTO_DSL_ST_LLVM_DIR}/tools/mlir/python_packages/mlir_core" + PTO_DSL_ST_PYTHON_SITE="$( + PTO_INSTALL_DIR="${PTO_DSL_ST_INSTALL_DIR}" "${PTODSL_PYTHON}" - <<'PY' + import os + import sysconfig + + prefix = os.environ["PTO_INSTALL_DIR"] + print(sysconfig.get_path("purelib", vars={"base": prefix, "platbase": prefix})) + PY + )" + + echo "PTO_DSL_ST_BASE_PYTHON=${SELECTED_BASE_PYTHON}" >> "${GITHUB_ENV}" + echo "PTO_DSL_ST_PYTHON_BIN=${PTODSL_PYTHON}" >> "${GITHUB_ENV}" + echo "PTO_DSL_ST_PYTHON_TAG=${PYTHON_TAG}" >> "${GITHUB_ENV}" + echo "PTO_DSL_ST_LLVM_DIR=${PTO_DSL_ST_LLVM_DIR}" >> "${GITHUB_ENV}" + echo "PTO_DSL_ST_MLIR_PYTHONPATH=${PTO_DSL_ST_MLIR_PYTHONPATH}" >> "${GITHUB_ENV}" + echo "PTO_DSL_ST_PYTHON_SITE=${PTO_DSL_ST_PYTHON_SITE}" >> "${GITHUB_ENV}" + + - name: Restore PTODSL LLVM build cache + id: ptodsl-llvm-cache + continue-on-error: true + uses: 
actions/cache/restore@v4 + with: + path: ${{ env.PTO_DSL_ST_LLVM_DIR }} + key: llvm-build-${{ steps.llvm-cache-key.outputs.sha }}-assert-${{ env.PTO_DSL_ST_PYTHON_TAG }}-v1 + + - name: Build PTODSL LLVM/MLIR + if: steps.ptodsl-llvm-cache.outputs.cache-hit != 'true' + shell: bash + run: | + set -euo pipefail + cd "${LLVM_ROOT}" + export CC=gcc + export CXX=g++ + rm -rf "${PTO_DSL_ST_LLVM_DIR}" + cmake -G Ninja -S llvm -B "${PTO_DSL_ST_LLVM_DIR}" \ + -DLLVM_ENABLE_PROJECTS="mlir;clang" \ + -DBUILD_SHARED_LIBS=ON \ + -DLLVM_ENABLE_ASSERTIONS=ON \ + -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ + -DPython3_EXECUTABLE="${PTO_DSL_ST_PYTHON_BIN}" \ + -DCMAKE_BUILD_TYPE=Release \ + -DLLVM_TARGETS_TO_BUILD="host" + + ninja -C "${PTO_DSL_ST_LLVM_DIR}" + + - name: Save PTODSL LLVM build cache + if: steps.ptodsl-llvm-cache.outputs.cache-hit != 'true' + continue-on-error: true + uses: actions/cache/save@v4 + with: + path: ${{ env.PTO_DSL_ST_LLVM_DIR }} + key: llvm-build-${{ steps.llvm-cache-key.outputs.sha }}-assert-${{ env.PTO_DSL_ST_PYTHON_TAG }}-v1 + + - name: Build PTODSL PTOAS + shell: bash + run: | + set -euo pipefail + rm -rf "${PTO_DSL_ST_BUILD_DIR}" "${PTO_DSL_ST_INSTALL_DIR}" + export CC=gcc + export CXX=g++ + LLVM_BUILD_DIR="${PTO_DSL_ST_LLVM_DIR}" \ + PTO_BUILD_DIR="${PTO_DSL_ST_BUILD_DIR}" \ + PTO_INSTALL_DIR="${PTO_DSL_ST_INSTALL_DIR}" \ + "${PTO_DSL_ST_PYTHON_BIN}" -m pip install . 
--no-build-isolation --no-deps --ignore-installed --prefix "${PTO_DSL_ST_INSTALL_DIR}" + - name: Run PTODSL DSL ST CI shell: bash run: | set -euo pipefail mkdir -p "${TILELANG_DSL_WORKSPACE}" - export LLVM_BUILD_DIR="${LLVM_DIR}" - export MLIR_PYTHON_ROOT="${MLIR_PYTHONPATH}" + export LLVM_BUILD_DIR="${PTO_DSL_ST_LLVM_DIR}" + export MLIR_PYTHON_ROOT="${PTO_DSL_ST_MLIR_PYTHONPATH}" + export PTO_INSTALL_DIR="${PTO_DSL_ST_INSTALL_DIR}" + export PTO_PYTHON_BUILD_ROOT="${PTO_DSL_ST_BUILD_DIR}/python" export PYTHON_BIN="${PTO_DSL_ST_PYTHON_BIN}" export PTO_PYTHON_BIN="${PTO_DSL_ST_PYTHON_BIN}" export PTOAS_PYTHON_SITE="${PTO_DSL_ST_PYTHON_SITE}" @@ -708,7 +762,7 @@ jobs: source "${ASCEND_HOME_PATH}/bin/setenv.bash" probe_ptodsl_python() { - PYTHONPATH="${GITHUB_WORKSPACE}/ptodsl:${PTO_INSTALL_DIR}:${PTOAS_PYTHON_SITE}:${MLIR_PYTHONPATH}:${GITHUB_WORKSPACE}/build/python:${PYTHONPATH:-}" \ + PYTHONPATH="${GITHUB_WORKSPACE}/ptodsl:${PTO_DSL_ST_INSTALL_DIR}:${PTOAS_PYTHON_SITE}:${PTO_DSL_ST_MLIR_PYTHONPATH}:${PTO_DSL_ST_BUILD_DIR}/python:${PYTHONPATH:-}" \ "${PTO_DSL_ST_PYTHON_BIN}" - <<'PY' import sys import numpy @@ -736,12 +790,12 @@ jobs: cat "${PROBE_LOG}" ASCEND_HOME_PATH="${ASCEND_HOME_PATH}" \ - PTOAS_BIN="${PTOAS_BIN}" \ + PTOAS_BIN="${PTO_DSL_ST_PTOAS_BIN}" \ scripts/sim_dsl.sh test/dsl-st \ 2>&1 | tee "${TILELANG_DSL_WORKSPACE}/ptodsl-dsl-st.log" ASCEND_HOME_PATH="${ASCEND_HOME_PATH}" \ - PTOAS_BIN="${PTOAS_BIN}" \ + PTOAS_BIN="${PTO_DSL_ST_PTOAS_BIN}" \ scripts/sim_dsl.sh test/dsl-st/npu_a5 \ 2>&1 | tee "${TILELANG_DSL_WORKSPACE}/ptodsl-dsl-st-npu-a5.log" From 8e0b127194ded528c5537f366bd7a026a3963de5 Mon Sep 17 00:00:00 2001 From: jimmychou <47636600+jimmychou0@users.noreply.github.com> Date: Thu, 2 Jul 2026 10:58:14 +0800 Subject: [PATCH 05/10] Rerun view lowering after tile expand --- include/PTO/Transforms/Passes.h | 2 + include/PTO/Transforms/Passes.td | 20 +- lib/PTO/Transforms/ExpandTileOp.cpp | 80 +----- lib/PTO/Transforms/FoldTileBufIntrinsics.cpp | 263 
++----------------- lib/PTO/Transforms/PTOViewToMemref.cpp | 65 +++-- tools/ptoas/ptoas.cpp | 10 + 6 files changed, 94 insertions(+), 346 deletions(-) diff --git a/include/PTO/Transforms/Passes.h b/include/PTO/Transforms/Passes.h index 3de31a89bf..0a69cfaa10 100644 --- a/include/PTO/Transforms/Passes.h +++ b/include/PTO/Transforms/Passes.h @@ -73,6 +73,8 @@ createPlanMemoryPass(const PlanMemoryOptions &planMemoryOption = {}); std::unique_ptr createPTORemoveRedundantBarrierPass(); std::unique_ptr createPTOViewToMemrefPass(); +std::unique_ptr +createPTOViewToMemrefPass(const PTOViewToMemrefOptions &options); std::unique_ptr createPTOValidateIntToPtrUsesPass(); std::unique_ptr createPTOMaterializeTileHandlesPass(); std::unique_ptr createPTOResolveBufferSelectPass(); diff --git a/include/PTO/Transforms/Passes.td b/include/PTO/Transforms/Passes.td index 7c43e4af2a..aae679b825 100644 --- a/include/PTO/Transforms/Passes.td +++ b/include/PTO/Transforms/Passes.td @@ -507,19 +507,17 @@ def FoldTileBufIntrinsics : Pass<"pto-fold-tile-buf-intrinsics", "mlir::func::Fu - pto.tile_valid_cols → same as above for v_col tensor_view family: - - pto.tensor_view_addr → traces through either - unrealized_conversion_cast → subview → reinterpret_cast or native - pto.partition_view → pto.make_tensor_view, then folds to the base memref - or to pto.castptr/pto.addptr on the base pointer + - pto.tensor_view_addr → traces through + unrealized_conversion_cast → subview → reinterpret_cast, then folds to + the base memref or to pto.castptr/pto.addptr on the base pointer - pto.get_tensor_view_dim → folded to arith.constant for static view sizes, or to the source size SSA operand for dynamic dims - pto.get_tensor_view_stride → folded to the lowered reinterpret_cast - stride, multiplied by the subview stride when needed, or to the native - make_tensor_view stride operand + stride, multiplied by the subview stride when needed - Dead unrealized_conversion_cast, memref.subview, memref.reinterpret_cast, 
- pto.partition_view, and pto.make_tensor_view ops exposed by folding are - cleaned up after the rewrite. + Dead unrealized_conversion_cast, memref.subview, and + memref.reinterpret_cast ops exposed by folding are cleaned up after the + rewrite. }]; let constructor = "mlir::pto::createFoldTileBufIntrinsicsPass()"; let options = [ @@ -648,6 +646,10 @@ def PTOViewToMemref : Pass<"pto-view-to-memref", "ModuleOp"> { }]; let constructor = "mlir::pto::createPTOViewToMemrefPass()"; + let options = [ + Option<"viewOnly", "view-only", "bool", /*default=*/"false", + "Only rerun structured tensor_view lowering without rewriting tile or compute surfaces"> + ]; let dependentDialects = [ "mlir::pto::PTODialect", diff --git a/lib/PTO/Transforms/ExpandTileOp.cpp b/lib/PTO/Transforms/ExpandTileOp.cpp index ab903af043..06a49b0437 100644 --- a/lib/PTO/Transforms/ExpandTileOp.cpp +++ b/lib/PTO/Transforms/ExpandTileOp.cpp @@ -82,7 +82,7 @@ namespace { // Four kinds of operands: // Tile — from TileBufType. dtype + shape + memorySpace + config // all participate in the specialization key (SpecKey). -// View — from TensorViewType / PartitionTensorViewType or MemRefType. +// View — from MemRefType (lowered TensorView/PartitionTensorView). // dtype, shape, strides, memorySpace, and optional explicit layout // participate in SpecKey because they affect template selection and // generated DMA parameters for tload/tstore. 
@@ -547,18 +547,6 @@ static void recordStaticSizes(ArrayRef inputs, out.push_back(getStaticIntOrDynamic(ofr)); } -static void recordStaticValues(ValueRange inputs, SmallVectorImpl &out) { - out.clear(); - out.reserve(inputs.size()); - for (Value value : inputs) { - int64_t dim = ShapedType::kDynamic; - if (getStaticIntFromValue(value, dim)) - out.push_back(dim); - else - out.push_back(ShapedType::kDynamic); - } -} - static SmallVector combineSubviewStrides(ArrayRef baseStrides, ArrayRef steps) { SmallVector result; @@ -622,40 +610,6 @@ static void populateViewShapeAndStrides(Value value, } } -static void populateTensorViewShapeAndStrides(Value value, - SmallVectorImpl &shape, - SmallVectorImpl &strides) { - if (!value) - return; - - Operation *def = value.getDefiningOp(); - if (auto partition = dyn_cast_or_null(def)) { - recordStaticValues(partition.getSizes(), shape); - SmallVector sourceShape; - populateTensorViewShapeAndStrides(partition.getSource(), sourceShape, - strides); - return; - } - - if (auto makeView = dyn_cast_or_null(def)) { - recordStaticValues(makeView.getShape(), shape); - recordStaticValues(makeView.getStrides(), strides); - return; - } - - Type ty = value.getType(); - if (auto partTy = dyn_cast(ty)) { - if (shape.empty()) - shape.assign(partTy.getShape().begin(), partTy.getShape().end()); - return; - } - if (auto tvTy = dyn_cast(ty)) { - if (shape.empty()) - shape.assign(tvTy.getShape().begin(), tvTy.getShape().end()); - return; - } -} - static std::optional buildOperandTypeInfo(Value value) { Type ty = value.getType(); // Tile operand — from TileBufType. @@ -683,37 +637,7 @@ static std::optional buildOperandTypeInfo(Value value) { return info; } - // View operand — from native TensorViewType / PartitionTensorViewType before - // PTOViewToMemref has rewritten the view to memref. 
- if (auto tvTy = dyn_cast(ty)) { - OperandTypeInfo info; - info.kind = OperandKind::View; - info.dtype = getDtypeString(tvTy.getElementType()); - if (info.dtype.empty()) - return std::nullopt; - info.viewMemorySpace = "gm"; - info.viewLayout = resolveViewLayout(value); - populateTensorViewShapeAndStrides(value, info.viewShape, info.viewStrides); - if (info.viewShape.empty()) - info.viewShape.assign(tvTy.getShape().begin(), tvTy.getShape().end()); - return info; - } - - if (auto partTy = dyn_cast(ty)) { - OperandTypeInfo info; - info.kind = OperandKind::View; - info.dtype = getDtypeString(partTy.getElementType()); - if (info.dtype.empty()) - return std::nullopt; - info.viewMemorySpace = "gm"; - info.viewLayout = resolveViewLayout(value); - populateTensorViewShapeAndStrides(value, info.viewShape, info.viewStrides); - if (info.viewShape.empty()) - info.viewShape.assign(partTy.getShape().begin(), partTy.getShape().end()); - return info; - } - - // View operand — from MemRefType (lowered PartitionTensorViewType). + // View operand — from MemRefType (lowered TensorView / PartitionTensorView). if (auto mrTy = dyn_cast(ty)) { OperandTypeInfo info; info.kind = OperandKind::View; diff --git a/lib/PTO/Transforms/FoldTileBufIntrinsics.cpp b/lib/PTO/Transforms/FoldTileBufIntrinsics.cpp index 1badfd006e..f87084939e 100644 --- a/lib/PTO/Transforms/FoldTileBufIntrinsics.cpp +++ b/lib/PTO/Transforms/FoldTileBufIntrinsics.cpp @@ -25,11 +25,10 @@ // For tile_buf intrinsics, the active VPTO path folds against materialized tile // handles produced by the shared tile-handle bridge (`pto.alloc_tile` or // `pto.materialize_tile`). 
-// For tensor_view intrinsics, the pass traces either through the lowered +// For tensor_view intrinsics, the pass traces through the lowered // unrealized_conversion_cast → memref.subview → memref.reinterpret_cast chain -// or through the native pto.partition_view → pto.make_tensor_view chain to fold -// directly to constants or SSA operands, without generating intermediate -// memref.dim / memref.extract_strided_metadata ops. +// to fold directly to constants or SSA operands, without generating +// intermediate memref.dim / memref.extract_strided_metadata ops. // //===----------------------------------------------------------------------===// @@ -223,113 +222,26 @@ static MemRefType getCanonicalMemRefTypeForTileBuf(pto::TileBufType tileTy) { AffineMap(), tileTy.getMemorySpace()); } -enum class ViewChainKind { - MemRef, - Native, -}; - struct ViewChain { - ViewChainKind kind = ViewChainKind::MemRef; - - // Lowered memref view chain. - UnrealizedConversionCastOp cast; memref::SubViewOp subview; memref::ReinterpretCastOp reinterpretCast; Value baseMemref; - - // Native tensor_view / partition_tensor_view chain. 
- pto::MakeTensorViewOp makeView; - pto::PartitionViewOp partitionView; }; -static bool validateNativeViewChain(pto::MakeTensorViewOp makeView, - pto::PartitionViewOp partitionView, - Operation *user) { - if (!makeView) { - user->emitError("FoldTileBufIntrinsics: native tensor_view must be " - "defined by pto.make_tensor_view"); - return false; - } - - size_t rank = makeView.getShape().size(); - if (makeView.getStrides().size() != rank) { - user->emitError("FoldTileBufIntrinsics: pto.make_tensor_view shape/stride " - "rank mismatch"); - return false; - } - - if (auto tvTy = dyn_cast(makeView.getResult().getType())) { - if (static_cast(tvTy.getRank()) != rank) { - user->emitError("FoldTileBufIntrinsics: pto.make_tensor_view result rank " - "does not match shape operands"); - return false; - } - } - - if (!partitionView) - return true; - - if (partitionView.getOffsets().size() != rank || - partitionView.getSizes().size() != rank) { - user->emitError("FoldTileBufIntrinsics: pto.partition_view rank must match " - "its source tensor_view rank"); - return false; - } - - if (auto partTy = dyn_cast( - partitionView.getResult().getType())) { - if (static_cast(partTy.getRank()) != rank) { - user->emitError("FoldTileBufIntrinsics: pto.partition_view result rank " - "does not match its operands"); - return false; - } - } - - return true; -} - static std::optional traceViewChain(Value tensorView, Operation *user) { Value view = tensorView; - UnrealizedConversionCastOp castOp; if (auto cast = view.getDefiningOp()) { - if (cast.getNumOperands() == 1 && cast.getNumResults() == 1) { - castOp = cast; + if (cast.getNumOperands() == 1 && cast.getNumResults() == 1) view = cast.getOperand(0); - } } if (!isa(view.getType())) { - if (auto partition = view.getDefiningOp()) { - auto makeView = - partition.getSource().getDefiningOp(); - if (!validateNativeViewChain(makeView, partition, user)) - return std::nullopt; - ViewChain chain; - chain.kind = ViewChainKind::Native; - chain.cast = 
castOp; - chain.makeView = makeView; - chain.partitionView = partition; - return chain; - } - - if (auto makeView = view.getDefiningOp()) { - if (!validateNativeViewChain(makeView, pto::PartitionViewOp(), user)) - return std::nullopt; - ViewChain chain; - chain.kind = ViewChainKind::Native; - chain.cast = castOp; - chain.makeView = makeView; - return chain; - } - - user->emitError("FoldTileBufIntrinsics: expected tensor_view to be defined " - "by a lowered memref.subview chain or native " - "pto.partition_view/pto.make_tensor_view chain, got ") - << (view.getDefiningOp() - ? view.getDefiningOp()->getName().getStringRef() - : StringRef("block argument")); + user->emitError("FoldTileBufIntrinsics: expected tensor_view to be lowered " + "to a memref.subview chain before folding, got ") + << (view.getDefiningOp() ? view.getDefiningOp()->getName().getStringRef() + : StringRef("block argument")); return std::nullopt; } @@ -356,8 +268,6 @@ static std::optional traceViewChain(Value tensorView, } ViewChain chain; - chain.kind = ViewChainKind::MemRef; - chain.cast = castOp; chain.subview = subviewOp; chain.reinterpretCast = rcOp; chain.baseMemref = rcOp.getSource(); @@ -414,14 +324,6 @@ static bool isStaticIndexValue(OpFoldResult ofr, int64_t expected) { return getConstIndexValue(ofr, value) && value == expected; } -static SmallVector valuesToFoldResults(ValueRange values) { - SmallVector result; - result.reserve(values.size()); - for (Value value : values) - result.push_back(value); - return result; -} - static bool isAllStaticZero(ArrayRef ofrs) { for (OpFoldResult ofr : ofrs) { if (!isStaticIndexValue(ofr, 0)) @@ -476,104 +378,28 @@ static Value computeLinearOffset(OpBuilder &builder, Location loc, return rcPart ? 
rcPart : svPart; } -static int64_t getStaticNativeViewDim(ViewChain &chain, int64_t dimIdx) { - if (chain.partitionView) { - auto partTy = dyn_cast( - chain.partitionView.getResult().getType()); - if (partTy && partTy.getDimSize(dimIdx) != ShapedType::kDynamic) - return partTy.getDimSize(dimIdx); - return ShapedType::kDynamic; - } - - auto tvTy = - dyn_cast(chain.makeView.getResult().getType()); - if (tvTy && tvTy.getDimSize(dimIdx) != ShapedType::kDynamic) - return tvTy.getDimSize(dimIdx); - return ShapedType::kDynamic; -} - static unsigned getViewRank(ViewChain &chain) { - if (chain.kind == ViewChainKind::MemRef) - return cast(chain.subview.getType()).getRank(); - return chain.makeView.getShape().size(); + return cast(chain.subview.getType()).getRank(); } static std::optional buildTensorViewDimValue(OpBuilder &builder, Location loc, ViewChain &chain, - int64_t dimIdx, - Operation *user) { - if (chain.kind == ViewChainKind::MemRef) { - auto svTy = cast(chain.subview.getType()); - if (!svTy.isDynamicDim(dimIdx)) - return builder.create(loc, - svTy.getDimSize(dimIdx)); - return getValueOrCreateConstant(builder, loc, - chain.subview.getMixedSizes()[dimIdx]); - } - - int64_t staticDim = getStaticNativeViewDim(chain, dimIdx); - if (staticDim != ShapedType::kDynamic) - return builder.create(loc, staticDim); - - ValueRange sizes = chain.partitionView ? 
chain.partitionView.getSizes() - : chain.makeView.getShape(); - if (dimIdx < 0 || static_cast(dimIdx) >= sizes.size()) { - user->emitError("FoldTileBufIntrinsics: native tensor_view dim index out " - "of bounds"); - return std::nullopt; - } - return sizes[dimIdx]; + int64_t dimIdx) { + auto svTy = cast(chain.subview.getType()); + if (!svTy.isDynamicDim(dimIdx)) + return builder.create(loc, svTy.getDimSize(dimIdx)); + return getValueOrCreateConstant(builder, loc, + chain.subview.getMixedSizes()[dimIdx]); } static std::optional buildTensorViewStrideValue(OpBuilder &builder, Location loc, ViewChain &chain, - int64_t dimIdx, - Operation *user) { - if (chain.kind == ViewChainKind::MemRef) - return computeResultStride( - builder, loc, chain.reinterpretCast.getMixedStrides()[dimIdx], - chain.subview.getMixedStrides()[dimIdx]); - - ValueRange strides = chain.makeView.getStrides(); - if (dimIdx < 0 || static_cast(dimIdx) >= strides.size()) { - user->emitError("FoldTileBufIntrinsics: native tensor_view stride index " - "out of bounds"); - return std::nullopt; - } - return strides[dimIdx]; -} - -static Value computeNativeLinearOffset(OpBuilder &builder, Location loc, - ViewChain &chain) { - if (!chain.partitionView) - return Value(); - - SmallVector offsets = - valuesToFoldResults(chain.partitionView.getOffsets()); - SmallVector strides = - valuesToFoldResults(chain.makeView.getStrides()); - return computeLinearOffset(builder, loc, /*rcOffsets=*/{}, offsets, strides); -} - -static std::optional buildNativeTensorViewBasePtr(OpBuilder &builder, - Location loc, - ViewChain &chain, - pto::PtrType resultTy, - Operation *user) { - Value base = chain.makeView.getPtr(); - if (base.getType() == resultTy) - return base; - - if (!isa(base.getType())) { - user->emitError("FoldTileBufIntrinsics: native tensor_view_addr base must " - "be !pto.ptr, memref, or integer, got ") - << base.getType(); - return std::nullopt; - } - - return builder.create(loc, resultTy, base).getResult(); + int64_t 
dimIdx) { + return computeResultStride(builder, loc, + chain.reinterpretCast.getMixedStrides()[dimIdx], + chain.subview.getMixedStrides()[dimIdx]); } struct FoldTileBufIntrinsicsPass @@ -811,8 +637,7 @@ struct FoldTileBufIntrinsicsPass builder.setInsertionPoint(dimOp); std::optional replacement = - buildTensorViewDimValue(builder, dimOp.getLoc(), *chain, dimIdx, - dimOp.getOperation()); + buildTensorViewDimValue(builder, dimOp.getLoc(), *chain, dimIdx); if (!replacement) return signalPassFailure(); @@ -843,8 +668,7 @@ struct FoldTileBufIntrinsicsPass builder.setInsertionPoint(strideOp); std::optional replacement = buildTensorViewStrideValue( - builder, strideOp.getLoc(), *chain, dimIdx, - strideOp.getOperation()); + builder, strideOp.getLoc(), *chain, dimIdx); if (!replacement) return signalPassFailure(); @@ -865,12 +689,6 @@ struct FoldTileBufIntrinsicsPass if (!resultPtrType) { if (auto resultMemrefType = dyn_cast(addrOp.getDst().getType())) { - if (chain->kind == ViewChainKind::Native) { - addrOp.emitError("FoldTileBufIntrinsics: native tensor_view_addr " - "cannot fold to memref without first lowering " - "the view to memref"); - return signalPassFailure(); - } Value base = chain->baseMemref; if (base.getType() != resultMemrefType) addrOp.getDst().setType(cast(base.getType())); @@ -886,24 +704,13 @@ struct FoldTileBufIntrinsicsPass Value linearOffset; Value basePtr; - if (chain->kind == ViewChainKind::MemRef) { - linearOffset = - computeLinearOffset(builder, addrOp.getLoc(), - chain->reinterpretCast.getMixedOffsets(), - chain->subview.getMixedOffsets(), - chain->reinterpretCast.getMixedStrides()); - basePtr = builder.create( - addrOp.getLoc(), resultPtrType, chain->baseMemref); - } else { - std::optional nativeBase = buildNativeTensorViewBasePtr( - builder, addrOp.getLoc(), *chain, resultPtrType, - addrOp.getOperation()); - if (!nativeBase) - return signalPassFailure(); - basePtr = *nativeBase; - linearOffset = computeNativeLinearOffset(builder, 
addrOp.getLoc(), - *chain); - } + linearOffset = + computeLinearOffset(builder, addrOp.getLoc(), + chain->reinterpretCast.getMixedOffsets(), + chain->subview.getMixedOffsets(), + chain->reinterpretCast.getMixedStrides()); + basePtr = builder.create( + addrOp.getLoc(), resultPtrType, chain->baseMemref); Value replacement = linearOffset @@ -944,20 +751,6 @@ struct FoldTileBufIntrinsicsPass op->erase(); } - while (true) { - SmallVector deadViewOps; - func.walk([&](Operation *op) { - if ((isa(op) || - isa(op)) && - op->use_empty()) - deadViewOps.push_back(op); - }); - if (deadViewOps.empty()) - break; - for (auto *op : llvm::reverse(deadViewOps)) - op->erase(); - } - eraseDeadAllocTileOps(func); } }; diff --git a/lib/PTO/Transforms/PTOViewToMemref.cpp b/lib/PTO/Transforms/PTOViewToMemref.cpp index e86a0a3ad8..e7999eaa83 100644 --- a/lib/PTO/Transforms/PTOViewToMemref.cpp +++ b/lib/PTO/Transforms/PTOViewToMemref.cpp @@ -1599,12 +1599,18 @@ static LogicalResult lowerTileBufViewLikeOps(func::FuncOp func, MLIRContext *ctx struct PTOViewToMemrefPass : public mlir::pto::impl::PTOViewToMemrefBase { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PTOViewToMemrefPass) + using mlir::pto::impl::PTOViewToMemrefBase< + PTOViewToMemrefPass>::PTOViewToMemrefBase; void runOnOperation() override { ModuleOp mod = getOperation(); MLIRContext *ctx = &getContext(); for (auto func : mod.getOps()) { + if (viewOnly) { + if (func.isExternal()) + continue; + } else { // ------------------------------------------------------------------ // Stage 0: ensure inttoptr values remain scalar-load/store only. 
// ------------------------------------------------------------------ @@ -2097,6 +2103,7 @@ struct PTOViewToMemrefPass signalPassFailure(); return; } + } // ------------------------------------------------------------------ // Stage 1: Lower pto.make_tensor_view -> memref.reinterpret_cast @@ -2171,25 +2178,27 @@ struct PTOViewToMemrefPass rewriter.replaceOp(op, rc.getResult()); } - // ------------------------------------------------------------------ - // Stage 1.25: Lower pto.get_tensor_view_dim -> memref.dim - // ------------------------------------------------------------------ - DefaultInlineVector tvDims; - func.walk([&](mlir::pto::GetTensorViewDimOp op) { tvDims.push_back(op); }); + if (!viewOnly) { + // ------------------------------------------------------------------ + // Stage 1.25: Lower pto.get_tensor_view_dim -> memref.dim + // ------------------------------------------------------------------ + DefaultInlineVector tvDims; + func.walk([&](mlir::pto::GetTensorViewDimOp op) { tvDims.push_back(op); }); - for (auto op : tvDims) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - Location loc = op.getLoc(); + for (auto op : tvDims) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + Location loc = op.getLoc(); - Value view = op.getTensorView(); - auto mrTy = dyn_cast(view.getType()); - if (!mrTy) - continue; // leave it to later passes if it hasn't been lowered yet + Value view = op.getTensorView(); + auto mrTy = dyn_cast(view.getType()); + if (!mrTy) + continue; // leave it to later passes if it hasn't been lowered yet - Value dimIdx = op.getDimIndex(); - Value dim = rewriter.create(loc, view, dimIdx); - rewriter.replaceOp(op, dim); + Value dimIdx = op.getDimIndex(); + Value dim = rewriter.create(loc, view, dimIdx); + rewriter.replaceOp(op, dim); + } } // ------------------------------------------------------------------ @@ -2220,6 +2229,9 @@ struct PTOViewToMemrefPass return; } + if (viewOnly) + continue; + // 
------------------------------------------------------------------ // Stage 1.5: Lower pto.get_tensor_view_stride -> strided memref metadata // ------------------------------------------------------------------ @@ -2378,7 +2390,11 @@ struct PTOViewToMemrefPass } } - // Clean up: addptr should be folded into make_tensor_view. + // Clean up dead addptr after folding the view/scalar patterns above. + // Live addptr users are legal low-level pointer arithmetic on the VPTO + // path (for example helper-local DMA pointer bumps that appear after + // ExpandTileOp + inline). Leave them in place so a second + // PTOViewToMemref run stays idempotent over already-lowered helper IR. DefaultInlineVector addPtrs; func.walk([&](mlir::pto::AddPtrOp op) { addPtrs.push_back(op.getOperation()); }); bool changed = true; @@ -2394,13 +2410,9 @@ struct PTOViewToMemrefPass } } } - for (auto *op : addPtrs) { - if (!op) - continue; - op->emitError("addptr must feed make_tensor_view, initialize_l2g2l_pipe(gm_addr) or load/store_scalar for lowering"); - signalPassFailure(); - return; - } + + if (viewOnly) + continue; // ------------------------------------------------------------------ // Stage 3: Rewrite Compute Ops @@ -4307,5 +4319,10 @@ std::unique_ptr createPTOViewToMemrefPass() { return std::make_unique(); } +std::unique_ptr +createPTOViewToMemrefPass(const PTOViewToMemrefOptions &options) { + return std::make_unique(options); +} + } // namespace pto } // namespace mlir diff --git a/tools/ptoas/ptoas.cpp b/tools/ptoas/ptoas.cpp index d5ba2b1319..a310fe8b31 100644 --- a/tools/ptoas/ptoas.cpp +++ b/tools/ptoas/ptoas.cpp @@ -1769,6 +1769,16 @@ static void lowerPTOToVPTOBackend(PassManager &pm, ModuleOp module, int argc, kernelModulePM.addPass(pto::createExpandTileOpPass(expandOpts)); kernelModulePM.addPass(pto::createPTOInlineLibCallPass()); + pto::PTOViewToMemrefOptions viewOnlyRerunOpts; + viewOnlyRerunOpts.viewOnly = true; + // ExpandTileOp materializes fresh TileLang helper IR after the 
shared + // mainline has already lowered the original module's tensor_view surface. + // Re-run the shared view lowering here so helper-local + // pto.make_tensor_view/pto.partition_view chains do not leak into + // FoldTileBufIntrinsics. + kernelModulePM.addPass(pto::createPTOViewToMemrefPass(viewOnlyRerunOpts)); + kernelModulePM.addPass(mlir::createCanonicalizerPass()); + kernelModulePM.addPass(mlir::createCSEPass()); kernelModulePM.addNestedPass( pto::createFoldTileBufIntrinsicsPass("shape-only")); if (enableA5VPTOPostLoweringFusionLifecycle) { From e4776bf4ed4a576e8ec3b7fc9e96bd1eb9b3ab90 Mon Sep 17 00:00:00 2001 From: jimmychou <47636600+jimmychou0@users.noreply.github.com> Date: Thu, 2 Jul 2026 11:04:50 +0800 Subject: [PATCH 06/10] Trim fold tile buf cleanup helpers --- lib/PTO/Transforms/FoldTileBufIntrinsics.cpp | 92 ++++++++------------ 1 file changed, 34 insertions(+), 58 deletions(-) diff --git a/lib/PTO/Transforms/FoldTileBufIntrinsics.cpp b/lib/PTO/Transforms/FoldTileBufIntrinsics.cpp index f87084939e..d6ae21050d 100644 --- a/lib/PTO/Transforms/FoldTileBufIntrinsics.cpp +++ b/lib/PTO/Transforms/FoldTileBufIntrinsics.cpp @@ -309,24 +309,13 @@ static Value getValueOrCreateConstant(OpBuilder &builder, Location loc, return builder.create(loc, intAttr.getInt()); } -static bool getConstIndexValue(OpFoldResult ofr, int64_t &out) { - if (auto value = dyn_cast(ofr)) - return getConstIndexValue(value, out); - auto intAttr = dyn_cast(cast(ofr)); - if (!intAttr) - return false; - out = intAttr.getInt(); - return true; -} - -static bool isStaticIndexValue(OpFoldResult ofr, int64_t expected) { - int64_t value = 0; - return getConstIndexValue(ofr, value) && value == expected; -} - static bool isAllStaticZero(ArrayRef ofrs) { for (OpFoldResult ofr : ofrs) { - if (!isStaticIndexValue(ofr, 0)) + auto attr = dyn_cast(ofr); + if (!attr) + return false; + auto intAttr = dyn_cast(attr); + if (!intAttr || intAttr.getInt() != 0) return false; } return true; @@ -335,8 +324,11 
@@ static bool isAllStaticZero(ArrayRef ofrs) { static Value computeResultStride(OpBuilder &builder, Location loc, OpFoldResult rcStride, OpFoldResult svStride) { - if (isStaticIndexValue(svStride, 1)) - return getValueOrCreateConstant(builder, loc, rcStride); + if (auto attr = dyn_cast(svStride)) { + auto intAttr = dyn_cast(attr); + if (intAttr && intAttr.getInt() == 1) + return getValueOrCreateConstant(builder, loc, rcStride); + } Value lhs = getValueOrCreateConstant(builder, loc, rcStride); Value rhs = getValueOrCreateConstant(builder, loc, svStride); @@ -356,8 +348,11 @@ static Value computeLinearOffset(OpBuilder &builder, Location loc, Value svPart; if (!svAllZero) { for (auto [svOffset, rcStride] : llvm::zip(svOffsets, rcStrides)) { - if (isStaticIndexValue(svOffset, 0)) - continue; + if (auto attr = dyn_cast(svOffset)) { + auto intAttr = dyn_cast(attr); + if (intAttr && intAttr.getInt() == 0) + continue; + } Value off = getValueOrCreateConstant(builder, loc, svOffset); Value stride = getValueOrCreateConstant(builder, loc, rcStride); @@ -378,30 +373,6 @@ static Value computeLinearOffset(OpBuilder &builder, Location loc, return rcPart ? 
rcPart : svPart; } -static unsigned getViewRank(ViewChain &chain) { - return cast(chain.subview.getType()).getRank(); -} - -static std::optional buildTensorViewDimValue(OpBuilder &builder, - Location loc, - ViewChain &chain, - int64_t dimIdx) { - auto svTy = cast(chain.subview.getType()); - if (!svTy.isDynamicDim(dimIdx)) - return builder.create(loc, svTy.getDimSize(dimIdx)); - return getValueOrCreateConstant(builder, loc, - chain.subview.getMixedSizes()[dimIdx]); -} - -static std::optional buildTensorViewStrideValue(OpBuilder &builder, - Location loc, - ViewChain &chain, - int64_t dimIdx) { - return computeResultStride(builder, loc, - chain.reinterpretCast.getMixedStrides()[dimIdx], - chain.subview.getMixedStrides()[dimIdx]); -} - struct FoldTileBufIntrinsicsPass : public pto::impl::FoldTileBufIntrinsicsBase { using FoldTileBufIntrinsicsBase::FoldTileBufIntrinsicsBase; @@ -627,8 +598,8 @@ struct FoldTileBufIntrinsicsPass return signalPassFailure(); } - unsigned rank = getViewRank(*chain); - if (dimIdx < 0 || static_cast(dimIdx) >= rank) { + auto svTy = cast(chain->subview.getType()); + if (dimIdx < 0 || dimIdx >= svTy.getRank()) { dimOp.emitError( "FoldTileBufIntrinsics: get_tensor_view_dim dim index out of " "bounds"); @@ -636,12 +607,17 @@ struct FoldTileBufIntrinsicsPass } builder.setInsertionPoint(dimOp); - std::optional replacement = - buildTensorViewDimValue(builder, dimOp.getLoc(), *chain, dimIdx); - if (!replacement) - return signalPassFailure(); + Value replacement; + if (!svTy.isDynamicDim(dimIdx)) { + replacement = + builder.create(dimOp.getLoc(), + svTy.getDimSize(dimIdx)); + } else { + replacement = getValueOrCreateConstant( + builder, dimOp.getLoc(), chain->subview.getMixedSizes()[dimIdx]); + } - dimOp.getResult().replaceAllUsesWith(*replacement); + dimOp.getResult().replaceAllUsesWith(replacement); dimOp.erase(); } @@ -658,8 +634,8 @@ struct FoldTileBufIntrinsicsPass return signalPassFailure(); } - unsigned rank = getViewRank(*chain); - if (dimIdx < 
0 || static_cast(dimIdx) >= rank) { + auto svTy = cast(chain->subview.getType()); + if (dimIdx < 0 || dimIdx >= svTy.getRank()) { strideOp.emitError( "FoldTileBufIntrinsics: get_tensor_view_stride dim index out of " "bounds"); @@ -667,12 +643,12 @@ struct FoldTileBufIntrinsicsPass } builder.setInsertionPoint(strideOp); - std::optional replacement = buildTensorViewStrideValue( - builder, strideOp.getLoc(), *chain, dimIdx); - if (!replacement) - return signalPassFailure(); + Value replacement = computeResultStride( + builder, strideOp.getLoc(), + chain->reinterpretCast.getMixedStrides()[dimIdx], + chain->subview.getMixedStrides()[dimIdx]); - strideOp.getResult().replaceAllUsesWith(*replacement); + strideOp.getResult().replaceAllUsesWith(replacement); strideOp.erase(); } } From 85a9627b2e45c052a9599e1e4acfba9ba7b89db9 Mon Sep 17 00:00:00 2001 From: jimmychou <47636600+jimmychou0@users.noreply.github.com> Date: Thu, 2 Jul 2026 14:32:49 +0800 Subject: [PATCH 07/10] Fix PTODSL VPTO container compile and CI --- .github/workflows/ci_sim.yml | 9 +++ ptodsl/tests/test_ptoas_frontend_verify.py | 79 +++++++++++++++++++++- tools/ptoas/driver.cpp | 18 ++++- 3 files changed, 103 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci_sim.yml b/.github/workflows/ci_sim.yml index 8babd6ab24..0c227afb60 100644 --- a/.github/workflows/ci_sim.yml +++ b/.github/workflows/ci_sim.yml @@ -721,6 +721,15 @@ jobs: -DLLVM_ENABLE_ASSERTIONS=ON \ -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ -DPython3_EXECUTABLE="${PTO_DSL_ST_PYTHON_BIN}" \ + -DPython_EXECUTABLE="${PTO_DSL_ST_PYTHON_BIN}" \ + -DPython3_ROOT_DIR="$(dirname "$(dirname "${PTO_DSL_ST_PYTHON_BIN}")")" \ + -DPython_ROOT_DIR="$(dirname "$(dirname "${PTO_DSL_ST_PYTHON_BIN}")")" \ + -DPython3_FIND_STRATEGY=LOCATION \ + -DPython_FIND_STRATEGY=LOCATION \ + -DPython3_FIND_VIRTUALENV=ONLY \ + -DPython_FIND_VIRTUALENV=ONLY \ + -Dpybind11_DIR="$("${PTO_DSL_ST_PYTHON_BIN}" -m pybind11 --cmakedir)" \ + -Dnanobind_DIR="$("${PTO_DSL_ST_PYTHON_BIN}" 
-m nanobind --cmake_dir)" \ -DCMAKE_BUILD_TYPE=Release \ -DLLVM_TARGETS_TO_BUILD="host" diff --git a/ptodsl/tests/test_ptoas_frontend_verify.py b/ptodsl/tests/test_ptoas_frontend_verify.py index 608bb9a0e2..7b8a02df0c 100644 --- a/ptodsl/tests/test_ptoas_frontend_verify.py +++ b/ptodsl/tests/test_ptoas_frontend_verify.py @@ -151,14 +151,24 @@ def run_ptoas_frontend_verify(ptoas_bin: Path, mlir_text: str, label: str) -> li return frontend_texts -def run_ptoas_frontend_verify_whole(ptoas_bin: Path, mlir_text: str, label: str) -> str: +def run_ptoas_frontend_verify_whole( + ptoas_bin: Path, + mlir_text: str, + label: str, + *, + extra_ptoas_args: list[str] | None = None, +) -> str: with tempfile.NamedTemporaryFile("w", suffix=".mlir", delete=False, encoding="utf-8") as handle: handle.write(mlir_text) input_path = Path(handle.name) try: + cmd = [str(ptoas_bin), str(input_path)] + if extra_ptoas_args: + cmd.extend(extra_ptoas_args) + cmd.extend(["--emit-pto-ir", "-o", "-"]) result = subprocess.run( - [str(ptoas_bin), str(input_path), "--emit-pto-ir", "-o", "-"], + cmd, capture_output=True, text=True, check=False, @@ -174,6 +184,19 @@ def run_ptoas_frontend_verify_whole(ptoas_bin: Path, mlir_text: str, label: str) return result.stdout +def expect_no_raw_partition_tensor_view(frontend_text: str, label: str) -> None: + expect( + "!pto.partition_tensor_view" not in frontend_text, + f"{label} should not leak raw !pto.partition_tensor_view into PTOAS frontend output.\n" + f"frontend output:\n{frontend_text}", + ) + expect( + "memref.subview" in frontend_text or "memref.reinterpret_cast" in frontend_text, + f"{label} should materialize memref-backed view lowering in PTOAS frontend output.\n" + f"frontend output:\n{frontend_text}", + ) + + def run_ptoas_frontend_expect_failure( ptoas_bin: Path, mlir_text: str, @@ -223,6 +246,21 @@ def host_vec_copy( pto.tile.store(o_tile, out) +@pto.jit(target="a5", mode="explicit", insert_sync=False) +def explicit_addr_vec_copy( + A_ptr: 
pto.ptr(pto.f32, "gm"), + O_ptr: pto.ptr(pto.f32, "gm"), +): + a_view = pto.make_tensor_view(A_ptr, shape=[1, 1, 1, 1, 64], strides=[64, 64, 64, 64, 1]) + o_view = pto.make_tensor_view(O_ptr, shape=[1, 1, 1, 1, 64], strides=[64, 64, 64, 64, 1]) + a_tile = pto.alloc_tile(shape=[1, 64], dtype=pto.f32, addr=0, valid_shape=[1, 64], blayout="RowMajor") + o_tile = pto.alloc_tile(shape=[1, 64], dtype=pto.f32, addr=2048, valid_shape=[1, 64], blayout="RowMajor") + part = pto.partition_view(a_view, offsets=[0, 0, 0, 0, 0], sizes=[1, 1, 1, 1, 64]) + out = pto.partition_view(o_view, offsets=[0, 0, 0, 0, 0], sizes=[1, 1, 1, 1, 64]) + pto.tile.load(part, a_tile) + pto.tile.store(o_tile, out) + + @pto.simt def simt_gm_memory_core_body(gm: pto.ptr(pto.i32, "gm")): tx = pto.get_tid_x() @@ -339,6 +377,43 @@ def main() -> None: "pto.tload" in simple_frontend_text and "pto.tstore" in simple_frontend_text, "host_vec_copy frontend verification output should keep the tile IO contract visible", ) + simple_whole_frontend_text = run_ptoas_frontend_verify_whole( + ptoas_bin, + simple_text, + "host_vec_copy PTODSL whole-container artifact", + ) + expect( + "func.func @host_vec_copy" in simple_whole_frontend_text, + "host_vec_copy whole-container frontend verification should preserve the kernel symbol", + ) + expect( + "pto.tload" in simple_whole_frontend_text and "pto.tstore" in simple_whole_frontend_text, + "host_vec_copy whole-container frontend verification should keep the tile IO contract visible", + ) + expect_no_raw_partition_tensor_view( + simple_whole_frontend_text, + "host_vec_copy whole-container frontend verification", + ) + + explicit_whole_text = explicit_addr_vec_copy.compile().mlir_text() + explicit_whole_frontend_text = run_ptoas_frontend_verify_whole( + ptoas_bin, + explicit_whole_text, + "explicit_addr_vec_copy PTODSL whole-container artifact", + extra_ptoas_args=["--pto-level=level3"], + ) + expect( + "func.func @explicit_addr_vec_copy" in explicit_whole_frontend_text, + 
"explicit_addr_vec_copy whole-container frontend verification should preserve the kernel symbol", + ) + expect( + "pto.tload" in explicit_whole_frontend_text and "pto.tstore" in explicit_whole_frontend_text, + "explicit_addr_vec_copy whole-container frontend verification should keep the tile IO contract visible", + ) + expect_no_raw_partition_tensor_view( + explicit_whole_frontend_text, + "explicit_addr_vec_copy whole-container frontend verification", + ) simt_gm_memory_text = simt_gm_memory_core_kernel.compile().mlir_text() simt_frontend_texts = run_ptoas_frontend_verify( diff --git a/tools/ptoas/driver.cpp b/tools/ptoas/driver.cpp index b8de95aa2c..fcb9197388 100644 --- a/tools/ptoas/driver.cpp +++ b/tools/ptoas/driver.cpp @@ -937,12 +937,28 @@ LogicalResult EmitCBackendJob::run(PTOASContext &context) { } LogicalResult VPTOBackendJob::run(PTOASContext &context) { + OwningOpRef singleChildJobModule; + OwningOpRef *compileUnit = &module; ModuleOp op = module.get(); op->setAttr("pto.backend", StringAttr::get(op.getContext(), "vpto")); + SmallVector children(op.getOps()); + if (children.size() == 1 && isBackendPartitionedContainer(op)) { + FailureOr> jobModuleOr = + buildBackendChildCompileUnit(op, children.front()); + if (failed(jobModuleOr)) + return failure(); + singleChildJobModule = std::move(*jobModuleOr); + singleChildJobModule.get()->setAttr( + "pto.backend", + StringAttr::get(singleChildJobModule.get()->getContext(), "vpto")); + compileUnit = &singleChildJobModule; + op = singleChildJobModule.get(); + } + bool emitHostStub = hasPTOEntry(op); if (mlir::pto::compilePTOASModule( - module, context, mlir::pto::PTOBackend::VPTO, result, + *compileUnit, context, mlir::pto::PTOBackend::VPTO, result, emitHostStub) != 0) return failure(); if (result.kind == mlir::pto::PTOASCompileResultKind::Text) From d53026f85ddaafff74a96840c0a120bc215a78e4 Mon Sep 17 00:00:00 2001 From: jimmychou <47636600+jimmychou0@users.noreply.github.com> Date: Thu, 2 Jul 2026 14:56:03 +0800 
Subject: [PATCH 08/10] Fix ci-sim LLVM Python setup order --- .github/workflows/ci_sim.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/ci_sim.yml b/.github/workflows/ci_sim.yml index 0c227afb60..659ae5bfaa 100644 --- a/.github/workflows/ci_sim.yml +++ b/.github/workflows/ci_sim.yml @@ -284,10 +284,10 @@ jobs: -DBUILD_SHARED_LIBS=ON \ -DLLVM_ENABLE_ASSERTIONS=ON \ -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ - -DPython3_EXECUTABLE="${PTO_DSL_ST_PYTHON_BIN}" \ - -DPython_EXECUTABLE="${PTO_DSL_ST_PYTHON_BIN}" \ - -Dpybind11_DIR="$("${PTO_DSL_ST_PYTHON_BIN}" -m pybind11 --cmakedir)" \ - -Dnanobind_DIR="$("${PTO_DSL_ST_PYTHON_BIN}" -m nanobind --cmake_dir)" \ + -DPython3_EXECUTABLE=python3 \ + -DPython_EXECUTABLE=python3 \ + -Dpybind11_DIR="$(python3 -m pybind11 --cmakedir)" \ + -Dnanobind_DIR="$(python3 -m nanobind --cmake_dir)" \ -DCMAKE_BUILD_TYPE=Release \ -DLLVM_TARGETS_TO_BUILD="host" From 551d75ebfd3e1e7f8d66121398a1cd84a5c14e50 Mon Sep 17 00:00:00 2001 From: jimmychou <47636600+jimmychou0@users.noreply.github.com> Date: Thu, 2 Jul 2026 16:59:34 +0800 Subject: [PATCH 09/10] Fix PTODSL CI and VPTO debug regressions --- .github/workflows/ci_sim.yml | 39 ++++++++++++++++--- lib/PTO/Transforms/PTOViewToMemref.cpp | 22 +++++++++-- lib/PTO/Transforms/VPTONormalizeContainer.cpp | 14 +++++++ tools/ptoas/driver.cpp | 15 ++++--- tools/ptoas/ptoas.cpp | 2 - 5 files changed, 75 insertions(+), 17 deletions(-) diff --git a/.github/workflows/ci_sim.yml b/.github/workflows/ci_sim.yml index 659ae5bfaa..78a8133781 100644 --- a/.github/workflows/ci_sim.yml +++ b/.github/workflows/ci_sim.yml @@ -575,6 +575,7 @@ jobs: probe_ptodsl_runtime_python() { TORCH_DEVICE_BACKEND_AUTOLOAD=0 "$1" - <<'PY' import sys + import nanobind import numpy import pybind11 import yaml @@ -586,6 +587,7 @@ jobs: print("torch", torch.__version__) print("torch_npu", getattr(torch_npu, "__version__", "unknown")) print("numpy", numpy.__version__) + print("nanobind", 
getattr(nanobind, "__version__", "unknown")) print("pybind11", pybind11.__version__) PY } @@ -597,6 +599,7 @@ jobs: deps = [ ("setuptools", "setuptools"), ("wheel", "wheel"), + ("nanobind", "nanobind"), ("numpy", "numpy"), ("ml_dtypes", "ml-dtypes"), ("yaml", "PyYAML"), @@ -704,7 +707,7 @@ jobs: uses: actions/cache/restore@v4 with: path: ${{ env.PTO_DSL_ST_LLVM_DIR }} - key: llvm-build-${{ steps.llvm-cache-key.outputs.sha }}-assert-${{ env.PTO_DSL_ST_PYTHON_TAG }}-v1 + key: llvm-build-${{ steps.llvm-cache-key.outputs.sha }}-assert-${{ env.PTO_DSL_ST_PYTHON_TAG }}-v2 - name: Build PTODSL LLVM/MLIR if: steps.ptodsl-llvm-cache.outputs.cache-hit != 'true' @@ -714,6 +717,21 @@ jobs: cd "${LLVM_ROOT}" export CC=gcc export CXX=g++ + PTODSL_PYTHON_PREFIX="$(dirname "$(dirname "${PTO_DSL_ST_PYTHON_BIN}")")" + export VIRTUAL_ENV="${PTODSL_PYTHON_PREFIX}" + export PATH="${PTODSL_PYTHON_PREFIX}/bin:${PATH}" + PYBIND11_CMAKE_DIR="$("${PTO_DSL_ST_PYTHON_BIN}" -m pybind11 --cmakedir)" + NANOBIND_CMAKE_DIR="$("${PTO_DSL_ST_PYTHON_BIN}" -m nanobind --cmake_dir)" + PYTHON_INCLUDE_DIR="$("${PTO_DSL_ST_PYTHON_BIN}" - <<'PY' + import sysconfig + print(sysconfig.get_path("include")) + PY + )" + PYTHON_LIBRARY="$("${PTO_DSL_ST_PYTHON_BIN}" - <<'PY' + import sysconfig + print(sysconfig.get_config_var("LIBDIR") + "/" + sysconfig.get_config_var("LDLIBRARY")) + PY + )" rm -rf "${PTO_DSL_ST_LLVM_DIR}" cmake -G Ninja -S llvm -B "${PTO_DSL_ST_LLVM_DIR}" \ -DLLVM_ENABLE_PROJECTS="mlir;clang" \ @@ -722,14 +740,23 @@ jobs: -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ -DPython3_EXECUTABLE="${PTO_DSL_ST_PYTHON_BIN}" \ -DPython_EXECUTABLE="${PTO_DSL_ST_PYTHON_BIN}" \ - -DPython3_ROOT_DIR="$(dirname "$(dirname "${PTO_DSL_ST_PYTHON_BIN}")")" \ - -DPython_ROOT_DIR="$(dirname "$(dirname "${PTO_DSL_ST_PYTHON_BIN}")")" \ + -DPython3_ROOT_DIR="${PTODSL_PYTHON_PREFIX}" \ + -DPython_ROOT_DIR="${PTODSL_PYTHON_PREFIX}" \ + -DPython3_INCLUDE_DIR="${PYTHON_INCLUDE_DIR}" \ + -DPython_INCLUDE_DIR="${PYTHON_INCLUDE_DIR}" \ + 
-DPython3_LIBRARY="${PYTHON_LIBRARY}" \ + -DPython_LIBRARY="${PYTHON_LIBRARY}" \ -DPython3_FIND_STRATEGY=LOCATION \ -DPython_FIND_STRATEGY=LOCATION \ -DPython3_FIND_VIRTUALENV=ONLY \ -DPython_FIND_VIRTUALENV=ONLY \ - -Dpybind11_DIR="$("${PTO_DSL_ST_PYTHON_BIN}" -m pybind11 --cmakedir)" \ - -Dnanobind_DIR="$("${PTO_DSL_ST_PYTHON_BIN}" -m nanobind --cmake_dir)" \ + -DPython3_FIND_REGISTRY=NEVER \ + -DPython_FIND_REGISTRY=NEVER \ + -DPython3_FIND_FRAMEWORK=NEVER \ + -DPython_FIND_FRAMEWORK=NEVER \ + -Dpybind11_DIR="${PYBIND11_CMAKE_DIR}" \ + -Dnanobind_DIR="${NANOBIND_CMAKE_DIR}" \ + -DCMAKE_PREFIX_PATH="${PYBIND11_CMAKE_DIR};${NANOBIND_CMAKE_DIR};${PTODSL_PYTHON_PREFIX}" \ -DCMAKE_BUILD_TYPE=Release \ -DLLVM_TARGETS_TO_BUILD="host" @@ -741,7 +768,7 @@ jobs: uses: actions/cache/save@v4 with: path: ${{ env.PTO_DSL_ST_LLVM_DIR }} - key: llvm-build-${{ steps.llvm-cache-key.outputs.sha }}-assert-${{ env.PTO_DSL_ST_PYTHON_TAG }}-v1 + key: llvm-build-${{ steps.llvm-cache-key.outputs.sha }}-assert-${{ env.PTO_DSL_ST_PYTHON_TAG }}-v2 - name: Build PTODSL PTOAS shell: bash diff --git a/lib/PTO/Transforms/PTOViewToMemref.cpp b/lib/PTO/Transforms/PTOViewToMemref.cpp index e7999eaa83..158bd9e42f 100644 --- a/lib/PTO/Transforms/PTOViewToMemref.cpp +++ b/lib/PTO/Transforms/PTOViewToMemref.cpp @@ -1592,6 +1592,19 @@ static LogicalResult lowerTileBufViewLikeOps(func::FuncOp func, MLIRContext *ctx return success(); } +static bool hasStructuredViewSurface(func::FuncOp func) { + bool found = false; + func.walk([&](Operation *op) { + if (isa(op)) { + found = true; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + return found; +} + // ============================================================================= // The Pass Implementation // ============================================================================= @@ -1608,7 +1621,7 @@ struct PTOViewToMemrefPass for (auto func : mod.getOps()) { if (viewOnly) { - if (func.isExternal()) + if 
(func.isExternal() || !hasStructuredViewSurface(func)) continue; } else { // ------------------------------------------------------------------ @@ -2224,14 +2237,15 @@ struct PTOViewToMemrefPass signalPassFailure(); return; } + + if (viewOnly) + continue; + if (failed(reconcileFusionRegionResultTypes(func))) { signalPassFailure(); return; } - if (viewOnly) - continue; - // ------------------------------------------------------------------ // Stage 1.5: Lower pto.get_tensor_view_stride -> strided memref metadata // ------------------------------------------------------------------ diff --git a/lib/PTO/Transforms/VPTONormalizeContainer.cpp b/lib/PTO/Transforms/VPTONormalizeContainer.cpp index d5b22af3d6..db69b3c588 100644 --- a/lib/PTO/Transforms/VPTONormalizeContainer.cpp +++ b/lib/PTO/Transforms/VPTONormalizeContainer.cpp @@ -29,6 +29,20 @@ static bool isVPTOKernelSubmodule(ModuleOp module) { } static LogicalResult verifyNormalizedVPTOContainer(ModuleOp module) { + if (!isVPTOKernelSubmodule(module)) { + bool hasChildModules = false; + for (Operation &op : module.getBodyRegion().front().getOperations()) { + if (isa(op)) { + hasChildModules = true; + break; + } + } + if (!hasChildModules) { + return module.emitError() + << "expected VPTO kernel submodule to carry 'pto.kernel_kind'"; + } + } + bool hasChildModules = false; for (Operation &op : module.getBodyRegion().front().getOperations()) { auto child = dyn_cast(op); diff --git a/tools/ptoas/driver.cpp b/tools/ptoas/driver.cpp index fcb9197388..31a0afb44b 100644 --- a/tools/ptoas/driver.cpp +++ b/tools/ptoas/driver.cpp @@ -280,6 +280,12 @@ static bool isBackendPartitionedContainer(ModuleOp module) { [](Operation &op) { return isa(op); }); } +static bool isDebugIROutputRequested() { + return mlir::pto::emitMlirIR || mlir::pto::emitVPTO || + mlir::pto::emitVPTOLLVMDialect || mlir::pto::ptoPrintSeamIR || + !mlir::pto::ptoSeamIRFile.empty(); +} + static SmallVector collectImportedPeerNames(ModuleOp module) { 
SmallVector names; module.walk([&](pto::ImportReservedBufferOp importOp) { @@ -943,7 +949,9 @@ LogicalResult VPTOBackendJob::run(PTOASContext &context) { op->setAttr("pto.backend", StringAttr::get(op.getContext(), "vpto")); SmallVector children(op.getOps()); - if (children.size() == 1 && isBackendPartitionedContainer(op)) { + if (!isDebugIROutputRequested() && children.size() == 1 && + isBackendPartitionedContainer(op) && + children.front()->hasAttr(mlir::pto::FunctionKernelKindAttr::name)) { FailureOr> jobModuleOr = buildBackendChildCompileUnit(op, children.front()); if (failed(jobModuleOr)) @@ -1052,10 +1060,7 @@ static LogicalResult resolveSingleBackend( } SmallVector children(module.getOps()); - bool debugIROutputRequested = - mlir::pto::emitMlirIR || mlir::pto::emitVPTO || mlir::pto::ptoPrintSeamIR || - !mlir::pto::ptoSeamIRFile.empty(); - if (!debugIROutputRequested && children.size() > 1) { + if (!isDebugIROutputRequested() && children.size() > 1) { if (!isBackendPartitionedContainer(module)) { llvm::errs() << "Error: mixed pto.backend fatobj mode expects either a " "single module or an outer module containing only child " diff --git a/tools/ptoas/ptoas.cpp b/tools/ptoas/ptoas.cpp index a310fe8b31..ae5f23a941 100644 --- a/tools/ptoas/ptoas.cpp +++ b/tools/ptoas/ptoas.cpp @@ -1777,8 +1777,6 @@ static void lowerPTOToVPTOBackend(PassManager &pm, ModuleOp module, int argc, // pto.make_tensor_view/pto.partition_view chains do not leak into // FoldTileBufIntrinsics. 
kernelModulePM.addPass(pto::createPTOViewToMemrefPass(viewOnlyRerunOpts)); - kernelModulePM.addPass(mlir::createCanonicalizerPass()); - kernelModulePM.addPass(mlir::createCSEPass()); kernelModulePM.addNestedPass( pto::createFoldTileBufIntrinsicsPass("shape-only")); if (enableA5VPTOPostLoweringFusionLifecycle) { From 1d12017558da9051936314978ce2ac84efe99a2a Mon Sep 17 00:00:00 2001 From: jimmychou <47636600+jimmychou0@users.noreply.github.com> Date: Thu, 2 Jul 2026 20:28:27 +0800 Subject: [PATCH 10/10] Fix PTODSL VPTO tile view lowering --- lib/PTO/Transforms/ExpandTileOp.cpp | 97 +++++++- lib/PTO/Transforms/FoldTileBufIntrinsics.cpp | 165 ++++++++++++-- lib/PTO/Transforms/PTOViewToMemref.cpp | 85 ++++++- lib/PTO/Transforms/VPTOPtrNormalize.cpp | 210 +++++++++++++++++- test/dsl-st/npu_a5/tcolexpand.py | 5 + test/dsl-st/npu_a5/tcolsum.py | 5 + test/dsl-st/npu_a5/tload_store.py | 3 + test/dsl-st/predicate_pack.py | 49 ++-- .../vpto/predicate_memref_ptr_normalize.pto | 49 ++++ test/vpto/scripts/run_host_vpto_validation.sh | 88 +++++++- tools/ptoas/ptoas.cpp | 14 +- 11 files changed, 704 insertions(+), 66 deletions(-) create mode 100644 test/lit/vpto/predicate_memref_ptr_normalize.pto diff --git a/lib/PTO/Transforms/ExpandTileOp.cpp b/lib/PTO/Transforms/ExpandTileOp.cpp index 06a49b0437..e12acd0e15 100644 --- a/lib/PTO/Transforms/ExpandTileOp.cpp +++ b/lib/PTO/Transforms/ExpandTileOp.cpp @@ -283,6 +283,11 @@ static std::string getMemorySpaceString(MemRefType mrTy) { return msAttr ? stringifyMemorySpace(msAttr.getAddressSpace()) : "gm"; } +static std::string getMemorySpaceString(pto::PtrType ptrTy) { + auto msAttr = ptrTy.getMemorySpace(); + return msAttr ? 
stringifyMemorySpace(msAttr.getAddressSpace()) : "gm"; +} + static std::string getBLayoutString(int32_t blayout) { if (blayout == static_cast(pto::BLayout::ColMajor)) return "col_major"; @@ -320,6 +325,11 @@ static std::optional resolveViewLayout(Value value) { def = value.getDefiningOp(); continue; } + if (auto partition = dyn_cast(def)) { + value = partition.getSource(); + def = value.getDefiningOp(); + continue; + } if (auto cast = dyn_cast(def)) { value = cast.getSource(); def = value.getDefiningOp(); @@ -547,6 +557,14 @@ static void recordStaticSizes(ArrayRef inputs, out.push_back(getStaticIntOrDynamic(ofr)); } +static void recordStaticValues(ValueRange inputs, SmallVectorImpl &out) { + SmallVector folded; + folded.reserve(inputs.size()); + for (Value value : inputs) + folded.push_back(value); + recordStaticSizes(folded, out); +} + static SmallVector combineSubviewStrides(ArrayRef baseStrides, ArrayRef steps) { SmallVector result; @@ -580,6 +598,15 @@ static void populateViewShapeAndStrides(Value value, return; } + if (auto partition = value.getDefiningOp()) { + populateViewShapeAndStrides(partition.getSource(), shape, strides); + SmallVector partitionShape; + recordStaticValues(partition.getSizes(), partitionShape); + if (!partitionShape.empty()) + shape = partitionShape; + return; + } + if (auto reinterpret = value.getDefiningOp()) { if (shape.empty()) { SmallVector reinterpretShape; @@ -592,6 +619,14 @@ static void populateViewShapeAndStrides(Value value, return; } + if (auto makeView = value.getDefiningOp()) { + if (shape.empty()) + recordStaticValues(makeView.getShape(), shape); + if (strides.empty()) + recordStaticValues(makeView.getStrides(), strides); + return; + } + if (auto cast = value.getDefiningOp()) { populateViewShapeAndStrides(cast.getSource(), shape, strides); return; @@ -610,6 +645,44 @@ static void populateViewShapeAndStrides(Value value, } } +static std::string getViewMemorySpaceString(Value value) { + if (!value) + return "gm"; + if (auto 
memrefTy = dyn_cast(value.getType())) + return getMemorySpaceString(memrefTy); + if (auto partition = value.getDefiningOp()) + return getViewMemorySpaceString(partition.getSource()); + if (auto makeView = value.getDefiningOp()) { + Value base = makeView.getPtr(); + if (auto ptrTy = dyn_cast(base.getType())) + return getMemorySpaceString(ptrTy); + } + return "gm"; +} + +static Type getViewElementType(Type ty) { + if (auto memrefTy = dyn_cast(ty)) + return memrefTy.getElementType(); + if (auto tensorViewTy = dyn_cast(ty)) + return tensorViewTy.getElementType(); + if (auto partitionTy = dyn_cast(ty)) + return partitionTy.getElementType(); + return {}; +} + +static void populateViewShapeFromType(Type ty, SmallVectorImpl &shape) { + if (auto memrefTy = dyn_cast(ty)) { + shape.assign(memrefTy.getShape().begin(), memrefTy.getShape().end()); + return; + } + if (auto tensorViewTy = dyn_cast(ty)) { + shape.assign(tensorViewTy.getShape().begin(), tensorViewTy.getShape().end()); + return; + } + if (auto partitionTy = dyn_cast(ty)) + shape.assign(partitionTy.getShape().begin(), partitionTy.getShape().end()); +} + static std::optional buildOperandTypeInfo(Value value) { Type ty = value.getType(); // Tile operand — from TileBufType. @@ -637,25 +710,31 @@ static std::optional buildOperandTypeInfo(Value value) { return info; } - // View operand — from MemRefType (lowered TensorView / PartitionTensorView). - if (auto mrTy = dyn_cast(ty)) { + // View operand — either already lowered to MemRefType, or still in the raw + // PTODSL tensor_view / partition_tensor_view form when tile-buffer helper ABI + // forced PTOViewToMemref to preserve the caller before ExpandTileOp. 
+ if (isa(ty)) { OperandTypeInfo info; info.kind = OperandKind::View; - info.dtype = getDtypeString(mrTy.getElementType()); + info.dtype = getDtypeString(getViewElementType(ty)); if (info.dtype.empty()) return std::nullopt; - info.viewMemorySpace = getMemorySpaceString(mrTy); + info.viewMemorySpace = getViewMemorySpaceString(value); info.viewLayout = resolveViewLayout(value); populateViewShapeAndStrides(value, info.viewShape, info.viewStrides); if (info.viewShape.empty()) - info.viewShape.assign(mrTy.getShape().begin(), mrTy.getShape().end()); + populateViewShapeFromType(ty, info.viewShape); if (info.viewStrides.empty()) { - int64_t offset = ShapedType::kDynamic; - if (succeeded(mlir::pto::getPTOMemRefStridesAndOffset( - mrTy, info.viewStrides, offset))) { - // strides populated — dynamic dims remain ShapedType::kDynamic. + if (auto mrTy = dyn_cast(ty)) { + int64_t offset = ShapedType::kDynamic; + if (succeeded(mlir::pto::getPTOMemRefStridesAndOffset( + mrTy, info.viewStrides, offset))) { + // strides populated — dynamic dims remain ShapedType::kDynamic. 
+ } } } + if (info.viewStrides.empty()) + info.viewStrides.assign(info.viewShape.size(), ShapedType::kDynamic); return info; } diff --git a/lib/PTO/Transforms/FoldTileBufIntrinsics.cpp b/lib/PTO/Transforms/FoldTileBufIntrinsics.cpp index d6ae21050d..646ba66723 100644 --- a/lib/PTO/Transforms/FoldTileBufIntrinsics.cpp +++ b/lib/PTO/Transforms/FoldTileBufIntrinsics.cpp @@ -43,6 +43,7 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/IR/SymbolTable.h" #include "mlir/Pass/Pass.h" +#include "llvm/ADT/SmallPtrSet.h" using namespace mlir; @@ -190,6 +191,11 @@ static std::optional resolveTileHandle(Value tileBuf, materialize.getConfig()}; } + if (auto bind = tileBuf.getDefiningOp()) { + return TileHandleInfo{bind.getSource(), Value(), bind.getValidRow(), + bind.getValidCol(), bind.getConfig()}; + } + if (auto reshape = tileBuf.getDefiningOp()) { auto sourceInfo = resolveTileHandle(reshape.getSrc(), user); if (!sourceInfo) @@ -228,6 +234,50 @@ struct ViewChain { Value baseMemref; }; +struct RawViewInfo { + SmallVector sizes; + SmallVector strides; +}; + +static Value unwrapSingleResultCast(Value value) { + if (auto cast = value.getDefiningOp()) { + if (cast.getNumOperands() == 1 && cast.getNumResults() == 1) + return cast.getOperand(0); + } + return value; +} + +static std::optional traceRawViewInfo(Value tensorView, + Operation *user, + bool requireStrides) { + Value view = unwrapSingleResultCast(tensorView); + if (isa(view.getType())) + return std::nullopt; + + RawViewInfo info; + if (auto partition = view.getDefiningOp()) { + info.sizes.append(partition.getSizes().begin(), partition.getSizes().end()); + Value source = unwrapSingleResultCast(partition.getSource()); + if (auto makeView = source.getDefiningOp()) + info.strides.append(makeView.getStrides().begin(), + makeView.getStrides().end()); + else if (requireStrides) { + user->emitError("FoldTileBufIntrinsics: raw partition_view stride folding " + "requires a defining pto.make_tensor_view source"); + return 
std::nullopt; + } + return info; + } + + if (auto makeView = view.getDefiningOp()) { + info.sizes.append(makeView.getShape().begin(), makeView.getShape().end()); + info.strides.append(makeView.getStrides().begin(), makeView.getStrides().end()); + return info; + } + + return std::nullopt; +} + static std::optional traceViewChain(Value tensorView, Operation *user) { Value view = tensorView; @@ -274,6 +324,46 @@ static std::optional traceViewChain(Value tensorView, return chain; } +static LogicalResult refreshSubviewUsersForSourceReplacement( + Value originalSource, MemRefType replacementSourceType, Operation *anchor) { + SmallVector, 4> worklist; + llvm::SmallPtrSet visited; + worklist.push_back({originalSource, replacementSourceType}); + + while (!worklist.empty()) { + auto [source, sourceType] = worklist.pop_back_val(); + SmallVector subviews; + for (Operation *user : source.getUsers()) { + auto subview = dyn_cast(user); + if (subview && subview.getSource() == source) + subviews.push_back(subview); + } + + for (memref::SubViewOp subview : subviews) { + if (!visited.insert(subview.getOperation()).second) + continue; + + auto oldResultType = cast(subview.getResult().getType()); + auto inferredType = dyn_cast( + memref::SubViewOp::inferRankReducedResultType( + oldResultType.getShape(), sourceType, subview.getStaticOffsets(), + subview.getStaticSizes(), subview.getStaticStrides())); + if (!inferredType) { + anchor->emitError("FoldTileBufIntrinsics: failed to refresh " + "memref.subview result type after tile_buf_addr " + "source replacement"); + return failure(); + } + + if (subview.getResult().getType() != inferredType) + subview.getResult().setType(inferredType); + worklist.push_back({subview.getResult(), inferredType}); + } + } + + return success(); +} + static bool getConstIndexValue(Value v, int64_t &out) { if (auto cOp = v.getDefiningOp()) { out = cOp.value(); @@ -433,13 +523,6 @@ struct FoldTileBufIntrinsicsPass if (!handleInfo) return signalPassFailure(); - 
auto tileTy = dyn_cast(addrOp.getSrc().getType()); - if (!tileTy) { - addrOp.emitError("FoldTileBufIntrinsics: tile_buf_addr source must be " - "!pto.tile_buf"); - return signalPassFailure(); - } - if (auto resultMemrefType = dyn_cast(addrOp.getDst().getType())) { if (handleInfo->sourceMemref) { @@ -452,8 +535,13 @@ struct FoldTileBufIntrinsicsPass // The declared tile_buf_addr result type may differ from the actual // materialized source layout (e.g. plain shape vs. strided layout). - if (srcMemref.getType() != resultMemrefType) - addrOp.getDst().setType(cast(srcMemref.getType())); + if (srcMemref.getType() != resultMemrefType) { + auto srcMemrefType = cast(srcMemref.getType()); + if (failed(refreshSubviewUsersForSourceReplacement( + addrOp.getDst(), srcMemrefType, addrOp))) + return signalPassFailure(); + addrOp.getDst().setType(srcMemrefType); + } addrOp.getDst().replaceAllUsesWith(srcMemref); addrOp.erase(); continue; @@ -466,6 +554,13 @@ struct FoldTileBufIntrinsicsPass return signalPassFailure(); } + auto tileTy = dyn_cast(addrOp.getSrc().getType()); + if (!tileTy) { + addrOp.emitError("FoldTileBufIntrinsics: tile_buf_addr source must " + "be !pto.tile_buf when rebuilding from an addr"); + return signalPassFailure(); + } + builder.setInsertionPoint(addrOp); Value replacement = builder.create( addrOp.getLoc(), resultMemrefType, ValueRange{handleInfo->addr}, @@ -500,6 +595,13 @@ struct FoldTileBufIntrinsicsPass return signalPassFailure(); } + auto tileTy = dyn_cast(addrOp.getSrc().getType()); + if (!tileTy) { + addrOp.emitError("FoldTileBufIntrinsics: tile_buf_addr source must " + "be !pto.tile_buf when rebuilding from an addr"); + return signalPassFailure(); + } + builder.setInsertionPoint(addrOp); auto canonicalMemrefType = getCanonicalMemRefTypeForTileBuf(tileTy); memrefValue = builder.create( @@ -586,10 +688,6 @@ struct FoldTileBufIntrinsicsPass } for (auto dimOp : tvDimOps) { - auto chain = traceViewChain(dimOp.getTensorView(), dimOp); - if (!chain) - 
return signalPassFailure(); - int64_t dimIdx = 0; if (!getConstIndexValue(dimOp.getDimIndex(), dimIdx)) { dimOp.emitError( @@ -598,6 +696,24 @@ struct FoldTileBufIntrinsicsPass return signalPassFailure(); } + if (auto rawInfo = + traceRawViewInfo(dimOp.getTensorView(), dimOp, + /*requireStrides=*/false)) { + if (dimIdx < 0 || dimIdx >= static_cast(rawInfo->sizes.size())) { + dimOp.emitError( + "FoldTileBufIntrinsics: get_tensor_view_dim dim index out of " + "bounds"); + return signalPassFailure(); + } + dimOp.getResult().replaceAllUsesWith(rawInfo->sizes[dimIdx]); + dimOp.erase(); + continue; + } + + auto chain = traceViewChain(dimOp.getTensorView(), dimOp); + if (!chain) + return signalPassFailure(); + auto svTy = cast(chain->subview.getType()); if (dimIdx < 0 || dimIdx >= svTy.getRank()) { dimOp.emitError( @@ -622,10 +738,6 @@ struct FoldTileBufIntrinsicsPass } for (auto strideOp : tvStrideOps) { - auto chain = traceViewChain(strideOp.getTensorView(), strideOp); - if (!chain) - return signalPassFailure(); - int64_t dimIdx = 0; if (!getConstIndexValue(strideOp.getDimIndex(), dimIdx)) { strideOp.emitError( @@ -634,6 +746,25 @@ struct FoldTileBufIntrinsicsPass return signalPassFailure(); } + if (auto rawInfo = + traceRawViewInfo(strideOp.getTensorView(), strideOp, + /*requireStrides=*/true)) { + if (dimIdx < 0 || + dimIdx >= static_cast(rawInfo->strides.size())) { + strideOp.emitError( + "FoldTileBufIntrinsics: get_tensor_view_stride dim index out " + "of bounds"); + return signalPassFailure(); + } + strideOp.getResult().replaceAllUsesWith(rawInfo->strides[dimIdx]); + strideOp.erase(); + continue; + } + + auto chain = traceViewChain(strideOp.getTensorView(), strideOp); + if (!chain) + return signalPassFailure(); + auto svTy = cast(chain->subview.getType()); if (dimIdx < 0 || dimIdx >= svTy.getRank()) { strideOp.emitError( diff --git a/lib/PTO/Transforms/PTOViewToMemref.cpp b/lib/PTO/Transforms/PTOViewToMemref.cpp index 158bd9e42f..ed7d7dbdfc 100644 --- 
a/lib/PTO/Transforms/PTOViewToMemref.cpp +++ b/lib/PTO/Transforms/PTOViewToMemref.cpp @@ -1226,6 +1226,15 @@ static LogicalResult lowerPtrLikeTileBufAddrOps(func::FuncOp func, for (auto op : addrOps) { if (!isa(op.getDst().getType())) continue; + if (llvm::any_of(op.getDst().getUsers(), [](Operation *user) { + return isa(user); + })) + continue; auto targetType = dyn_cast(convertPTOTypeToMemRef(op.getDst().getType())); @@ -1605,6 +1614,57 @@ static bool hasStructuredViewSurface(func::FuncOp func) { return found; } +static bool isPipeInitializeTensorViewUse(OpOperand &use) { + return isa( + use.getOwner()); +} + +static bool hasFrontendPipeInitialize(func::FuncOp func) { + bool found = false; + func.walk([&](Operation *op) { + if (isa( + op)) { + found = true; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + return found; +} + +static bool isTileViewABIType(Type type) { + return isa(type); +} + +static bool hasTileViewABIBoundary(func::FuncOp func) { + FunctionType functionType = func.getFunctionType(); + for (Type type : functionType.getInputs()) + if (isTileViewABIType(type)) + return true; + for (Type type : functionType.getResults()) + if (isTileViewABIType(type)) + return true; + + bool found = false; + func.walk([&](func::CallOp call) { + for (Value operand : call.getArgOperands()) { + if (isTileViewABIType(operand.getType())) { + found = true; + return WalkResult::interrupt(); + } + } + for (Type type : call.getResultTypes()) { + if (isTileViewABIType(type)) { + found = true; + return WalkResult::interrupt(); + } + } + return WalkResult::advance(); + }); + return found; +} + // ============================================================================= // The Pass Implementation // ============================================================================= @@ -1619,7 +1679,13 @@ struct PTOViewToMemrefPass ModuleOp mod = getOperation(); MLIRContext *ctx = &getContext(); - for (auto func : mod.getOps()) { + SmallVector funcs; + 
mod.walk([&](func::FuncOp func) { funcs.push_back(func); }); + + for (auto func : funcs) { + if (hasFrontendPipeInitialize(func) || hasTileViewABIBoundary(func)) + continue; + if (viewOnly) { if (func.isExternal() || !hasStructuredViewSurface(func)) continue; @@ -2125,6 +2191,16 @@ struct PTOViewToMemrefPass func.walk([&](mlir::pto::MakeTensorViewOp op) { makeViews.push_back(op); }); for (auto op : makeViews) { + bool hasMemRefUsers = false; + for (OpOperand &use : op->getUses()) { + if (!isPipeInitializeTensorViewUse(use)) { + hasMemRefUsers = true; + break; + } + } + if (!hasMemRefUsers) + continue; + IRRewriter rewriter(ctx); rewriter.setInsertionPoint(op); Location loc = op.getLoc(); @@ -2188,7 +2264,12 @@ struct PTOViewToMemrefPass rc->setAttr("layout", layoutAttr); } - rewriter.replaceOp(op, rc.getResult()); + op.getResult().replaceUsesWithIf( + rc.getResult(), [](OpOperand &use) { + return !isPipeInitializeTensorViewUse(use); + }); + if (op->use_empty()) + rewriter.eraseOp(op); } if (!viewOnly) { diff --git a/lib/PTO/Transforms/VPTOPtrNormalize.cpp b/lib/PTO/Transforms/VPTOPtrNormalize.cpp index 404c6634aa..9443314163 100644 --- a/lib/PTO/Transforms/VPTOPtrNormalize.cpp +++ b/lib/PTO/Transforms/VPTOPtrNormalize.cpp @@ -82,6 +82,11 @@ static bool hasPtrNormalizeConvertibleType(TypeRange types) { types, [](Type type) { return hasPtrNormalizeConvertibleType(type); }); } +static bool needsReinterpretCastPtrConversion(memref::ReinterpretCastOp op) { + return hasPtrNormalizeConvertibleType(op.getSource().getType()) || + hasPtrNormalizeConvertibleType(op.getType()); +} + static FailureOr> convertTypes(const TypeConverter &typeConverter, TypeRange types) { SmallVector convertedTypes; @@ -183,8 +188,6 @@ static Value materializeScalarAccessPtr(Value source, PatternRewriter &rewriter, Location loc) { if (!source) return {}; - if (isa(source.getType())) - return source; if (auto cast = source.getDefiningOp()) { if (cast->getNumOperands() != 1 || cast->getNumResults() 
!= 1) @@ -195,9 +198,58 @@ static Value materializeScalarAccessPtr(Value source, PatternRewriter &rewriter, return materializeScalarAccessPtr(input, rewriter, loc); } + if (isa(source.getType())) + return source; + if (auto cast = source.getDefiningOp()) return materializeScalarAccessPtr(cast.getSource(), rewriter, loc); + if (auto cast = source.getDefiningOp()) { + auto resultType = dyn_cast(source.getType()); + if (!resultType) + return {}; + + Value basePtr = materializeScalarAccessPtr(cast.getSource(), rewriter, loc); + if (!basePtr) + return {}; + + auto memorySpace = + getPointerMemorySpace(resultType.getMemorySpace(), rewriter.getContext()); + if (!memorySpace) + return {}; + + auto ptrType = pto::PtrType::get(rewriter.getContext(), + resultType.getElementType(), memorySpace); + if (basePtr.getType() != ptrType) + basePtr = rewriter.create(loc, ptrType, basePtr); + + SmallVector offsets = cast.getMixedOffsets(); + if (offsets.empty()) + return basePtr; + if (offsets.size() != 1) + return {}; + + Value offset; + if (auto attr = dyn_cast(offsets.front())) { + auto intAttr = dyn_cast(attr); + if (!intAttr) + return {}; + if (intAttr.getInt() == 0) + return basePtr; + offset = rewriter.create(loc, intAttr.getInt()); + } else { + offset = cast.getOffsets().front(); + if (!offset.getType().isIndex()) { + if (!isa(offset.getType())) + return {}; + offset = + rewriter.create(loc, rewriter.getIndexType(), offset); + } + } + + return rewriter.create(loc, ptrType, basePtr, offset); + } + if (auto subview = source.getDefiningOp()) { if (!needsSubviewPtrConversion(subview)) return {}; @@ -228,6 +280,18 @@ static Value materializeScalarAccessPtr(Value source, PatternRewriter &rewriter, Value addr = pointerCast.getAddrs().front(); if (isa(addr.getType())) return addr; + if (isa(addr.getType())) { + auto memrefType = dyn_cast(pointerCast.getResult().getType()); + if (!memrefType) + return {}; + auto memorySpace = + getPointerMemorySpace(memrefType.getMemorySpace(), 
rewriter.getContext()); + if (!memorySpace) + return {}; + auto ptrType = pto::PtrType::get(rewriter.getContext(), + memrefType.getElementType(), memorySpace); + return rewriter.create(loc, ptrType, addr); + } return materializeScalarAccessPtr(addr, rewriter, loc); } @@ -237,6 +301,44 @@ static Value materializeScalarAccessPtr(Value source, PatternRewriter &rewriter, return {}; } +static Value materializeReinterpretCastPtr(memref::ReinterpretCastOp cast, + Value convertedSource, + pto::PtrType ptrType, + PatternRewriter &rewriter) { + Location loc = cast.getLoc(); + Value basePtr = materializeScalarAccessPtr(convertedSource, rewriter, loc); + if (!basePtr) + return {}; + if (basePtr.getType() != ptrType) + basePtr = rewriter.create(loc, ptrType, basePtr); + + SmallVector offsets = cast.getMixedOffsets(); + if (offsets.empty()) + return basePtr; + if (offsets.size() != 1) + return {}; + + Value offset; + if (auto attr = dyn_cast(offsets.front())) { + auto intAttr = dyn_cast(attr); + if (!intAttr) + return {}; + if (intAttr.getInt() == 0) + return basePtr; + offset = rewriter.create(loc, intAttr.getInt()); + } else { + offset = cast.getOffsets().front(); + if (!offset.getType().isIndex()) { + if (!isa(offset.getType())) + return {}; + offset = + rewriter.create(loc, rewriter.getIndexType(), offset); + } + } + + return rewriter.create(loc, ptrType, basePtr, offset); +} + static Value materializeBoundaryOperandPtr(Value source, PatternRewriter &rewriter, Location loc) { @@ -736,6 +838,65 @@ struct ConvertStoreOperandToPtrPattern } }; +template +struct ConvertPredicateLoadOperandToPtrPattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(LoadOp op, typename LoadOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value source = + materializeScalarAccessPtr(adaptor.getSource(), rewriter, op.getLoc()); + if (!source) + return rewriter.notifyMatchFailure( + op, "failed to 
materialize predicate load source ptr"); + if (!isa(source.getType())) + return rewriter.notifyMatchFailure( + op, "expected ptr-form predicate load source"); + + FailureOr> resultTypes = + convertTypes(*this->getTypeConverter(), op->getResultTypes()); + if (failed(resultTypes)) + return failure(); + + OperationState state(op.getLoc(), op->getName().getStringRef()); + state.addOperands({source, adaptor.getOffset()}); + state.addTypes(*resultTypes); + state.addAttributes(op->getAttrs()); + Operation *newOp = rewriter.create(state); + rewriter.replaceOp(op, newOp->getResults()); + return success(); + } +}; + +template +struct ConvertPredicateStoreOperandToPtrPattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(StoreOp op, typename StoreOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value destination = materializeScalarAccessPtr(adaptor.getDestination(), + rewriter, op.getLoc()); + if (!destination) + return rewriter.notifyMatchFailure( + op, "failed to materialize predicate store destination ptr"); + if (!isa(destination.getType())) + return rewriter.notifyMatchFailure( + op, "expected ptr-form predicate store destination"); + + OperationState state(op.getLoc(), op->getName().getStringRef()); + state.addOperands({adaptor.getValue(), destination, adaptor.getOffset()}); + state.addTypes(op->getResultTypes()); + state.addAttributes(op->getAttrs()); + Operation *newOp = rewriter.create(state); + rewriter.replaceOp(op, newOp->getResults()); + return success(); + } +}; + struct ConvertPtrNormalizeUnrealizedCastOp final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -795,6 +956,32 @@ struct ConvertPtrNormalizeMemRefCastOp final } }; +struct ConvertPtrNormalizeReinterpretCastOp final + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::ReinterpretCastOp op, 
OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!needsReinterpretCastPtrConversion(op)) + return failure(); + + auto ptrType = + dyn_cast(getTypeConverter()->convertType(op.getType())); + if (!ptrType) + return rewriter.notifyMatchFailure(op, "expected ptr result type"); + + Value replacement = + materializeReinterpretCastPtr(op, adaptor.getSource(), ptrType, rewriter); + if (!replacement) + return rewriter.notifyMatchFailure( + op, "failed to materialize reinterpret_cast ptr"); + + rewriter.replaceOp(op, replacement); + return success(); + } +}; + struct VPTOPtrNormalizePass : public pto::impl::VPTOPtrNormalizeBase { using pto::impl::VPTOPtrNormalizeBase< @@ -836,6 +1023,10 @@ struct VPTOPtrNormalizePass return !hasPtrNormalizeConvertibleType(op.getSource().getType()) && !hasPtrNormalizeConvertibleType(op.getType()); }); + target.addDynamicallyLegalOp( + [](memref::ReinterpretCastOp op) { + return !needsReinterpretCastPtrConversion(op); + }); target.addDynamicallyLegalOp([&](pto::TileBufAddrOp op) { return op.getDst().getType() == typeConverter.convertType(op.getDst().getType()); @@ -932,6 +1123,14 @@ struct VPTOPtrNormalizePass [](pto::PTOLoadOp op) { return isa(op.getPtr().getType()); }); target.addDynamicallyLegalOp( [](pto::PTOStoreOp op) { return isa(op.getPtr().getType()); }); + target.addDynamicallyLegalOp( + [](pto::PldsOp op) { return isa(op.getSource().getType()); }); + target.addDynamicallyLegalOp( + [](pto::PldiOp op) { return isa(op.getSource().getType()); }); + target.addDynamicallyLegalOp( + [](pto::PstsOp op) { return isa(op.getDestination().getType()); }); + target.addDynamicallyLegalOp( + [](pto::PstiOp op) { return isa(op.getDestination().getType()); }); target.addDynamicallyLegalOp( [](memref::SubViewOp op) { return !needsSubviewPtrConversion(op); }); @@ -966,8 +1165,13 @@ struct VPTOPtrNormalizePass ConvertMteUbGmOperandPattern, ConvertLoadOperandToPtrPattern, ConvertStoreOperandToPtrPattern, + 
ConvertPredicateLoadOperandToPtrPattern, + ConvertPredicateLoadOperandToPtrPattern, + ConvertPredicateStoreOperandToPtrPattern, + ConvertPredicateStoreOperandToPtrPattern, ConvertPtrNormalizeUnrealizedCastOp, - ConvertPtrNormalizeMemRefCastOp>( + ConvertPtrNormalizeMemRefCastOp, + ConvertPtrNormalizeReinterpretCastOp>( typeConverter, context); if (failed(applyPartialConversion(module, target, std::move(patterns)))) diff --git a/test/dsl-st/npu_a5/tcolexpand.py b/test/dsl-st/npu_a5/tcolexpand.py index 4d30dd5db4..99d3ecaea8 100644 --- a/test/dsl-st/npu_a5/tcolexpand.py +++ b/test/dsl-st/npu_a5/tcolexpand.py @@ -120,8 +120,13 @@ def _tcolexpand_kernel( dst_tile = pto.alloc_tile(shape=[DST_ROWS, COLS], dtype=pto.f32, addr=DST_TILE_ADDR) pto.tile.load(src_part, src_tile) + pto.set_flag(pto.Pipe.MTE2, pto.Pipe.V, event_id=0) + pto.wait_flag(pto.Pipe.MTE2, pto.Pipe.V, event_id=0) pto.tile.colexpand(src_tile, dst_tile) + pto.set_flag(pto.Pipe.V, pto.Pipe.MTE3, event_id=1) + pto.wait_flag(pto.Pipe.V, pto.Pipe.MTE3, event_id=1) pto.tile.store(dst_tile, dst_part) + pto.pipe_barrier(pto.Pipe.ALL) def _make_input(): diff --git a/test/dsl-st/npu_a5/tcolsum.py b/test/dsl-st/npu_a5/tcolsum.py index a6d898e46f..72de4aab78 100644 --- a/test/dsl-st/npu_a5/tcolsum.py +++ b/test/dsl-st/npu_a5/tcolsum.py @@ -119,8 +119,13 @@ def _tcolsum_kernel( dst_tile = pto.alloc_tile(shape=[1, COLS], dtype=pto.f32, addr=DST_TILE_ADDR) pto.tile.load(src_part, src_tile) + pto.set_flag(pto.Pipe.MTE2, pto.Pipe.V, event_id=0) + pto.wait_flag(pto.Pipe.MTE2, pto.Pipe.V, event_id=0) pto.tile.colsum(src_tile, dst_tile) + pto.set_flag(pto.Pipe.V, pto.Pipe.MTE3, event_id=1) + pto.wait_flag(pto.Pipe.V, pto.Pipe.MTE3, event_id=1) pto.tile.store(dst_tile, dst_part) + pto.pipe_barrier(pto.Pipe.ALL) def _make_input(): diff --git a/test/dsl-st/npu_a5/tload_store.py b/test/dsl-st/npu_a5/tload_store.py index 595ef15183..73e1aa8d22 100644 --- a/test/dsl-st/npu_a5/tload_store.py +++ b/test/dsl-st/npu_a5/tload_store.py 
@@ -139,7 +139,10 @@ def _roundtrip_body(src_ptr, dst_ptr, *, rows, cols, view_strides=None, tile_kwa ) pto.tile.load(src_part, tile) + pto.set_flag(pto.Pipe.MTE2, pto.Pipe.MTE3, event_id=0) + pto.wait_flag(pto.Pipe.MTE2, pto.Pipe.MTE3, event_id=0) pto.tile.store(tile, dst_part) + pto.pipe_barrier(pto.Pipe.ALL) _tload_store_kernels = {} diff --git a/test/dsl-st/predicate_pack.py b/test/dsl-st/predicate_pack.py index 01876ce133..9d32b4dfc4 100644 --- a/test/dsl-st/predicate_pack.py +++ b/test/dsl-st/predicate_pack.py @@ -88,31 +88,30 @@ def predicate_pack_part_kernel( pto.set_flag("MTE2", "V", event_id=0) pto.wait_flag("MTE2", "V", event_id=0) - with pto.simd(): - seed = pto.pset_b8(pto.MaskPattern.ALL) - src = pto.vlds(src_tile[0, 0:]) - active_b8 = pto.vcmp(src, src, seed, pto.CmpMode.EQ) - active = pto.pbitcast(active_b8, pto.mask_b32) - - packed_lo_same = pto.ppack(active, "LOWER") - unpacked_lo_same = pto.punpack(packed_lo_same, "LOWER") - packed_lo_b16 = pto.ppack(active, "LOWER", to_type=pto.mask_b16) - unpacked_lo_b32 = pto.punpack(packed_lo_b16, "LOWER", to_type=pto.mask_b32) - - packed_hi_same = pto.ppack(active, "HIGHER") - unpacked_hi_same = pto.punpack(packed_hi_same, "HIGHER") - packed_hi_b16 = pto.ppack(active, "HIGHER", to_type=pto.mask_b16) - unpacked_hi_b32 = pto.punpack(packed_hi_b16, "HIGHER", to_type=pto.mask_b32) - - pto.psts(active, dst_tile.as_ptr(), 0, dist="NORM") - pto.psts(packed_lo_same, dst_tile.as_ptr(), ROW_BYTES, dist="NORM") - pto.psts(unpacked_lo_same, dst_tile.as_ptr(), ROW_BYTES * 2, dist="NORM") - pto.psts(packed_lo_b16, dst_tile.as_ptr(), ROW_BYTES * 3, dist="NORM") - pto.psts(unpacked_lo_b32, dst_tile.as_ptr(), ROW_BYTES * 4, dist="NORM") - pto.psts(packed_hi_same, dst_tile.as_ptr(), ROW_BYTES * 5, dist="NORM") - pto.psts(unpacked_hi_same, dst_tile.as_ptr(), ROW_BYTES * 6, dist="NORM") - pto.psts(packed_hi_b16, dst_tile.as_ptr(), ROW_BYTES * 7, dist="NORM") - pto.psts(unpacked_hi_b32, dst_tile.as_ptr(), ROW_BYTES * 8, 
dist="NORM") + seed = pto.pset_b8(pto.MaskPattern.ALL) + src = pto.vlds(src_tile[0, 0:]) + active_b8 = pto.vcmp(src, src, seed, pto.CmpMode.EQ) + active = pto.pbitcast(active_b8, pto.mask_b32) + + packed_lo_same = pto.ppack(active, "LOWER") + unpacked_lo_same = pto.punpack(packed_lo_same, "LOWER") + packed_lo_b16 = pto.ppack(active, "LOWER", to_type=pto.mask_b16) + unpacked_lo_b32 = pto.punpack(packed_lo_b16, "LOWER", to_type=pto.mask_b32) + + packed_hi_same = pto.ppack(active, "HIGHER") + unpacked_hi_same = pto.punpack(packed_hi_same, "HIGHER") + packed_hi_b16 = pto.ppack(active, "HIGHER", to_type=pto.mask_b16) + unpacked_hi_b32 = pto.punpack(packed_hi_b16, "HIGHER", to_type=pto.mask_b32) + + pto.psts(active, dst_tile.as_ptr(), 0, dist="NORM") + pto.psts(packed_lo_same, dst_tile.as_ptr(), ROW_BYTES, dist="NORM") + pto.psts(unpacked_lo_same, dst_tile.as_ptr(), ROW_BYTES * 2, dist="NORM") + pto.psts(packed_lo_b16, dst_tile.as_ptr(), ROW_BYTES * 3, dist="NORM") + pto.psts(unpacked_lo_b32, dst_tile.as_ptr(), ROW_BYTES * 4, dist="NORM") + pto.psts(packed_hi_same, dst_tile.as_ptr(), ROW_BYTES * 5, dist="NORM") + pto.psts(unpacked_hi_same, dst_tile.as_ptr(), ROW_BYTES * 6, dist="NORM") + pto.psts(packed_hi_b16, dst_tile.as_ptr(), ROW_BYTES * 7, dist="NORM") + pto.psts(unpacked_hi_b32, dst_tile.as_ptr(), ROW_BYTES * 8, dist="NORM") pto.set_flag("V", "MTE3", event_id=0) pto.wait_flag("V", "MTE3", event_id=0) diff --git a/test/lit/vpto/predicate_memref_ptr_normalize.pto b/test/lit/vpto/predicate_memref_ptr_normalize.pto new file mode 100644 index 0000000000..3934c75a3a --- /dev/null +++ b/test/lit/vpto/predicate_memref_ptr_normalize.pto @@ -0,0 +1,49 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. 
You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto %s -o - 2>/dev/null | FileCheck %s + +module attributes {pto.target_arch = "a5"} { + module attributes {pto.backend = "vpto", pto.kernel_kind = #pto.kernel_kind, pto.target_arch = "a5"} { + func.func @predicate_memref_ptr_normalize() attributes {pto.entry} { + %c0_i64 = arith.constant 0 : i64 + %c0 = arith.constant 0 : index + %c32 = arith.constant 32 : index + %tile = pto.pointer_cast(%c0_i64) {config = #pto.tile_buf_config, slayout=#pto.slayout, s_fractal_size=512, pad=#pto.pad_value, compact=#pto.compact_mode>} : memref<1x32xui8, strided<[32, 1], offset: ?>, #pto.address_space> + %bound = pto.bind_tile %tile, %c32, %c32 {config = #pto.tile_buf_config, slayout=#pto.slayout, s_fractal_size=512, pad=#pto.pad_value, compact=#pto.compact_mode>} : memref<1x32xui8, strided<[32, 1], offset: ?>, #pto.address_space> -> memref<1x32xui8, strided<[32, 1], offset: ?>, #pto.address_space> + %flat = memref.reinterpret_cast %bound to offset: [0], sizes: [%c32], strides: [1] : memref<1x32xui8, strided<[32, 1], offset: ?>, #pto.address_space> to memref> + %mask = pto.pset_b8 "PAT_ALL" : !pto.mask + pto.psts %mask, %flat[%c0], "NORM" : !pto.mask, memref>, index + %loaded = pto.plds %flat[%c0], "NORM" : memref>, index -> !pto.mask + pto.psts %loaded, %flat[%c0], "NORM" : !pto.mask, memref>, index + return + } + + func.func @vsts_reinterpret_ptr_normalize() attributes {pto.entry} { + %c0_i64 = arith.constant 0 : i64 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %tile = pto.pointer_cast(%c0_i64) {config = 
#pto.tile_buf_config, slayout=#pto.slayout, s_fractal_size=512, pad=#pto.pad_value, compact=#pto.compact_mode>} : memref<1x64xf32, strided<[64, 1], offset: ?>, #pto.address_space> + %bound = pto.bind_tile %tile, %c1, %c64 {config = #pto.tile_buf_config, slayout=#pto.slayout, s_fractal_size=512, pad=#pto.pad_value, compact=#pto.compact_mode>} : memref<1x64xf32, strided<[64, 1], offset: ?>, #pto.address_space> -> memref<1x64xf32, strided<[64, 1], offset: ?>, #pto.address_space> + %flat = memref.reinterpret_cast %bound to offset: [0], sizes: [%c64], strides: [1] : memref<1x64xf32, strided<[64, 1], offset: ?>, #pto.address_space> to memref> + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %value = arith.constant 0.000000e+00 : f32 + %vec = pto.vdup %value, %mask : f32, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %vec, %flat[%c0], %mask : !pto.vreg<64xf32>, memref>, !pto.mask + return + } + } +} + +// CHECK-LABEL: func.func @predicate_memref_ptr_normalize +// CHECK-NOT: memref.reinterpret_cast +// CHECK: pto.psts {{.*}} : !pto.mask, !pto.ptr, index +// CHECK: pto.plds {{.*}} : !pto.ptr, index -> !pto.mask +// CHECK-LABEL: func.func @vsts_reinterpret_ptr_normalize +// CHECK: pto.vsts {{.*}} : !pto.vreg<64xf32>, !pto.ptr, index, !pto.mask diff --git a/test/vpto/scripts/run_host_vpto_validation.sh b/test/vpto/scripts/run_host_vpto_validation.sh index fa530911af..6ea838b4d0 100755 --- a/test/vpto/scripts/run_host_vpto_validation.sh +++ b/test/vpto/scripts/run_host_vpto_validation.sh @@ -103,6 +103,84 @@ command -v python3 >/dev/null 2>&1 || die "python3 not found" mkdir -p "${WORK_SPACE}" WORK_SPACE="$(cd "${WORK_SPACE}" && pwd)" +has_torch_npu_packages() { + local python_bin="$1" + TORCH_DEVICE_BACKEND_AUTOLOAD=0 "${python_bin}" - <<'PY' >/dev/null 2>&1 +import importlib.util + +missing = [ + name for name in ("torch", "torch_npu") + if importlib.util.find_spec(name) is None +] +raise SystemExit(1 if missing else 0) +PY +} + +add_python_candidate() { + local candidate="$1" + 
[[ -n "${candidate}" ]] || return 0 + + local resolved + if [[ "${candidate}" == */* ]]; then + resolved="${candidate}" + else + resolved="$(command -v "${candidate}" 2>/dev/null || true)" + fi + + [[ -n "${resolved}" && -x "${resolved}" ]] || return 0 + local resolved_dir resolved_base + resolved_dir="$(cd "$(dirname "${resolved}")" && pwd -P)" + resolved_base="$(basename "${resolved}")" + printf '%s/%s\n' "${resolved_dir}" "${resolved_base}" +} + +resolve_ptodsl_runtime_python() { + local -a candidates=() + local candidate resolved + + for candidate in \ + "${PTO_PYTHON_BIN:-}" \ + "${PYTHON_BIN:-}" \ + "${PTO_DSL_ST_PYTHON_BIN:-}" \ + python3 \ + /home/mouliangyu/miniconda3/bin/python3 \ + /home/mouliangyu/miniconda3/bin/python \ + /home/zhoujiaming/miniconda3/bin/python3 \ + /home/zhoujiaming/miniconda3/bin/python + do + resolved="$(add_python_candidate "${candidate}")" + [[ -n "${resolved}" ]] || continue + case ":${candidates[*]}:" in + *":${resolved}:"*) ;; + *) candidates+=("${resolved}") ;; + esac + done + + shopt -s nullglob + for candidate in \ + /home/*/miniconda3/envs/*/bin/python \ + /home/*/anaconda3/envs/*/bin/python \ + /opt/conda/envs/*/bin/python + do + resolved="$(add_python_candidate "${candidate}")" + [[ -n "${resolved}" ]] || continue + case ":${candidates[*]}:" in + *":${resolved}:"*) ;; + *) candidates+=("${resolved}") ;; + esac + done + shopt -u nullglob + + for candidate in "${candidates[@]}"; do + if has_torch_npu_packages "${candidate}"; then + printf '%s\n' "${candidate}" + return 0 + fi + done + + return 1 +} + is_ptodsl_case_dir() { [[ -f "$1/kernel.py" ]] } @@ -313,12 +391,18 @@ run_ptodsl_case() { local out_dir="$3" local case_script case_script="$(ptodsl_case_script "${case_path}")" + local ptodsl_python + ptodsl_python="$(resolve_ptodsl_runtime_python)" || + die "PTODSL source-backed case ${case_name} requires an existing Python runtime with torch and torch_npu" - log "[$case_name] run PTODSL source-backed case" + log 
"[$case_name] run PTODSL source-backed case with ${ptodsl_python}" ( cd "${out_dir}" export PTODSL_CACHE_DIR="${out_dir}/ptodsl-cache" export PATH="$(dirname "${PTOAS_BIN}"):${PATH}" + export PYTHON_BIN="${ptodsl_python}" + export PTO_PYTHON_BIN="${ptodsl_python}" + export TORCH_DEVICE_BACKEND_AUTOLOAD=0 if [[ "${DEVICE}" == "SIM" ]]; then "${ROOT_DIR}/scripts/sim_dsl.sh" \ --soc-version "${PTODSL_SIM_SOC_VERSION}" \ @@ -326,7 +410,7 @@ run_ptodsl_case() { "${case_script}" else export LD_LIBRARY_PATH="${ASCEND_HOME_PATH}/lib64:${LD_LIBRARY_PATH:-}" - python3 "${case_script}" + "${ptodsl_python}" "${case_script}" fi ) log "[$case_name] output dir: ${out_dir}" diff --git a/tools/ptoas/ptoas.cpp b/tools/ptoas/ptoas.cpp index ae5f23a941..3f49afea90 100644 --- a/tools/ptoas/ptoas.cpp +++ b/tools/ptoas/ptoas.cpp @@ -1769,16 +1769,14 @@ static void lowerPTOToVPTOBackend(PassManager &pm, ModuleOp module, int argc, kernelModulePM.addPass(pto::createExpandTileOpPass(expandOpts)); kernelModulePM.addPass(pto::createPTOInlineLibCallPass()); - pto::PTOViewToMemrefOptions viewOnlyRerunOpts; - viewOnlyRerunOpts.viewOnly = true; - // ExpandTileOp materializes fresh TileLang helper IR after the shared - // mainline has already lowered the original module's tensor_view surface. - // Re-run the shared view lowering here so helper-local - // pto.make_tensor_view/pto.partition_view chains do not leak into - // FoldTileBufIntrinsics. - kernelModulePM.addPass(pto::createPTOViewToMemrefPass(viewOnlyRerunOpts)); kernelModulePM.addNestedPass( pto::createFoldTileBufIntrinsicsPass("shape-only")); + // ExpandTileOp + inline removes TileLang tile-op calls after the shared + // mainline. PTODSL subkernel-helper ABI may have forced that mainline to + // preserve the caller's tile/view surface, so rerun the full lowering here + // instead of view-only lowering: raw make_tensor_view bases may still be + // pto.ptr values until this point. 
+ kernelModulePM.addPass(pto::createPTOViewToMemrefPass()); if (enableA5VPTOPostLoweringFusionLifecycle) { kernelModulePM.addPass(pto::createPTOLowLevelLoopFusionPass()); kernelModulePM.addPass(mlir::createCanonicalizerPass());