Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
373 changes: 315 additions & 58 deletions .github/workflows/ci_sim.yml

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions include/PTO/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ createPlanMemoryPass(const PlanMemoryOptions &planMemoryOption = {});

std::unique_ptr<Pass> createPTORemoveRedundantBarrierPass();
std::unique_ptr<Pass> createPTOViewToMemrefPass();
std::unique_ptr<Pass>
createPTOViewToMemrefPass(const PTOViewToMemrefOptions &options);
std::unique_ptr<Pass> createPTOValidateIntToPtrUsesPass();
std::unique_ptr<Pass> createPTOMaterializeTileHandlesPass();
std::unique_ptr<Pass> createInferPTOLayoutPass();
Expand Down
18 changes: 11 additions & 7 deletions include/PTO/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -485,13 +485,13 @@ def FoldTileBufIntrinsics : Pass<"pto-fold-tile-buf-intrinsics", "mlir::func::Fu
- pto.tile_valid_cols → same as above for v_col

tensor_view family:
- pto.tensor_view_addr → traces through unrealized_conversion_cast →
subview → reinterpret_cast, then folds to the base memref or to
pto.castptr/pto.addptr on the base memref
- pto.get_tensor_view_dim → folded to arith.constant for static subview
sizes, or to the subview size SSA operand for dynamic dims
- pto.get_tensor_view_stride → folded to the reinterpret_cast stride
operand, multiplied by the subview stride when needed
- pto.tensor_view_addr → traces through
unrealized_conversion_cast → subview → reinterpret_cast, then folds to
the base memref or to pto.castptr/pto.addptr on the base pointer
- pto.get_tensor_view_dim → folded to arith.constant for static view sizes,
or to the source size SSA operand for dynamic dims
- pto.get_tensor_view_stride → folded to the lowered reinterpret_cast
stride, multiplied by the subview stride when needed

Dead unrealized_conversion_cast, memref.subview, and
memref.reinterpret_cast ops exposed by folding are cleaned up after the
Expand Down Expand Up @@ -624,6 +624,10 @@ def PTOViewToMemref : Pass<"pto-view-to-memref", "ModuleOp"> {
}];

let constructor = "mlir::pto::createPTOViewToMemrefPass()";
let options = [
Option<"viewOnly", "view-only", "bool", /*default=*/"false",
"Only rerun structured tensor_view lowering without rewriting tile or compute surfaces">
];

let dependentDialects = [
"mlir::pto::PTODialect",
Expand Down
62 changes: 53 additions & 9 deletions lib/PTO/Transforms/ExpandTileOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,10 @@ namespace {
// Four kinds of operands:
// Tile — from TileBufType. dtype + shape + memorySpace + config
// all participate in the specialization key (SpecKey).
// View — from MemRefType (lowered PartitionTensorViewType). The element
// dtype and optional explicit layout participate in SpecKey;
// shape/strides/memorySpace remain JSON-only metadata for Python
// constraint checking and must not perturb C++ codegen caching.
// View — from MemRefType (lowered TensorView/PartitionTensorView).
// dtype, shape, strides, memorySpace, and optional explicit layout
// participate in SpecKey because they affect template selection and
// generated DMA parameters for tload/tstore.
// Vector — from builtin VectorType. The element dtype and vector shape
// participate in SpecKey so helper-side schema filtering can
// distinguish auxiliary vector operands such as tmrgsort's
Expand All @@ -106,7 +106,7 @@ struct OperandTypeInfo {
int32_t fractal = 0;
uint64_t pad = 0;

// --- View-only (MemRefType) — for JSON / constraint checking only ---
// --- View-only ---
SmallVector<int64_t> viewShape;
SmallVector<int64_t> viewStrides;
std::string viewMemorySpace; // "gm" or "ub"
Expand All @@ -132,8 +132,8 @@ struct OperandTypeInfo {
return vectorShape == rhs.vectorShape;
if (kind == OperandKind::Scalar)
return scalarValue == rhs.scalarValue;
// View: dtype + explicit layout are sufficient for template caching.
return viewLayout == rhs.viewLayout;
return viewShape == rhs.viewShape && viewStrides == rhs.viewStrides &&
viewMemorySpace == rhs.viewMemorySpace && viewLayout == rhs.viewLayout;
}
};

Expand Down Expand Up @@ -177,7 +177,11 @@ struct SpecKeyInfo : public llvm::DenseMapInfo<SpecKey> {
h = llvm::hash_combine(h, *op.scalarValue);
}
if (op.kind == OperandKind::View) {
h = llvm::hash_combine(h, op.viewLayout.has_value());
h = llvm::hash_combine(h, op.viewMemorySpace, op.viewLayout.has_value());
for (int64_t d : op.viewShape)
h = llvm::hash_combine(h, d);
for (int64_t d : op.viewStrides)
h = llvm::hash_combine(h, d);
if (op.viewLayout)
h = llvm::hash_combine(h, static_cast<int>(*op.viewLayout));
}
Expand Down Expand Up @@ -630,7 +634,7 @@ static std::optional<OperandTypeInfo> buildOperandTypeInfo(Value value) {
return info;
}

// View operand — from MemRefType (lowered PartitionTensorViewType).
// View operand — from MemRefType (lowered TensorView / PartitionTensorView).
if (auto mrTy = dyn_cast<MemRefType>(ty)) {
OperandTypeInfo info;
info.kind = OperandKind::View;
Expand Down Expand Up @@ -839,6 +843,11 @@ static std::string buildUniqueFunctionBaseName(const SpecKey &key) {
uniqueName += "_fr" + std::to_string(op.fractal);
uniqueName += "_pd" + llvm::utohexstr(op.pad, /*LowerCase=*/false);
} else if (op.kind == OperandKind::View) {
for (int64_t d : op.viewShape)
uniqueName += "_s" + std::to_string(d);
for (int64_t d : op.viewStrides)
uniqueName += "_st" + std::to_string(d);
uniqueName += "_ms_" + op.viewMemorySpace;
if (op.viewLayout)
uniqueName += "_vl_" + stringifyLayout(*op.viewLayout).str();
} else if (op.kind == OperandKind::Vector) {
Expand Down Expand Up @@ -869,6 +878,39 @@ static std::string buildContextAttrsJson(const SpecKey &key) {
return json;
}

static bool isViewLikeType(Type type) {
return isa<pto::TensorViewType, pto::PartitionTensorViewType, MemRefType>(type);
}

static void specializeTemplateEntryArgumentTypes(func::FuncOp fn,
Operation *tileOp) {
if (!fn || fn.isExternal())
return;

FunctionType fnTy = fn.getFunctionType();
SmallVector<Type> inputs(fnTy.getInputs().begin(), fnTy.getInputs().end());
bool changed = false;
unsigned operandCount = std::min<unsigned>(tileOp->getNumOperands(),
inputs.size());
for (unsigned i = 0; i < operandCount; ++i) {
Type callerTy = tileOp->getOperand(i).getType();
Type calleeTy = inputs[i];
if (callerTy == calleeTy)
continue;
if (!isViewLikeType(callerTy) || !isViewLikeType(calleeTy))
continue;
inputs[i] = callerTy;
fn.getArgument(i).setType(callerTy);
changed = true;
}

if (!changed)
return;

fn.setFunctionType(FunctionType::get(fn.getContext(), inputs,
fnTy.getResults()));
}

// ============================================================================
// Invoke Python DSL daemon RPC to generate a specialized template function.
// ============================================================================
Expand Down Expand Up @@ -1028,6 +1070,7 @@ func::FuncOp ExpandState::invokeTilelangDaemon(const SpecKey &key,
}

auto cloned = clonedFuncs.front();
specializeTemplateEntryArgumentTypes(cloned, tileOp);
if (!cloned->hasAttr("pto.tilelang.instance")) {
llvm::errs() << "ExpandTileOp: warning: daemon output function @"
<< cloned.getSymName()
Expand Down Expand Up @@ -1227,6 +1270,7 @@ func::FuncOp ExpandState::invokeTilelangDSL(const SpecKey &key,
}

auto cloned = clonedFuncs.front();
specializeTemplateEntryArgumentTypes(cloned, tileOp);
// The pto.tilelang.instance attribute should already be set by the
// TileLang DSL frontend in the generated MLIR. Verify it exists.
if (!cloned->hasAttr("pto.tilelang.instance")) {
Expand Down
85 changes: 49 additions & 36 deletions lib/PTO/Transforms/FoldTileBufIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,10 @@
// For tile_buf intrinsics, the active VPTO path folds against materialized tile
// handles produced by the shared tile-handle bridge (`pto.alloc_tile` or
// `pto.materialize_tile`).
// For tensor_view intrinsics, the pass traces through the full
// unrealized_conversion_cast → memref.subview → memref.reinterpret_cast
// chain to fold directly to constants or SSA operands from the
// reinterpret_cast, without generating intermediate memref.dim /
// memref.extract_strided_metadata ops.
// For tensor_view intrinsics, the pass traces through the lowered
// unrealized_conversion_cast → memref.subview → memref.reinterpret_cast chain
// to fold directly to constants or SSA operands, without generating
// intermediate memref.dim / memref.extract_strided_metadata ops.
//
//===----------------------------------------------------------------------===//

Expand All @@ -42,6 +41,7 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;
Expand Down Expand Up @@ -90,6 +90,19 @@ static void eraseDeadAllocTileOps(func::FuncOp func) {
alloc.erase();
}

static bool isDeadPTODSLSubkernelHelper(func::FuncOp func) {
if (!func->hasAttr("pto.ptodsl.subkernel_helper"))
return false;

auto module = func->getParentOfType<ModuleOp>();
if (!module)
return false;

SymbolTable symbolTable(module);
auto uses = symbolTable.getSymbolUses(func, module);
return uses && uses->empty();
}
Comment on lines +93 to +104

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Constructing a SymbolTable on the parent ModuleOp and calling getSymbolUses inside a FuncOp pass violates MLIR's pass nesting and concurrency model. Since FuncOp passes can be scheduled to run concurrently on different functions, traversing the parent module to find symbol uses can lead to data races, undefined behavior, or crashes in multi-threaded mode.

To resolve this, consider either:

  1. Changing FoldTileBufIntrinsics to a ModulePass so it can safely perform module-wide symbol analysis.
  2. Avoiding the use-check entirely in this pass (e.g., by relying on a subsequent dead-symbol-elimination pass to clean up dead functions, or simply skipping any function with the pto.ptodsl.subkernel_helper attribute if they are always intended to be inlined).


struct TileHandleInfo {
Value sourceMemref;
Value addr;
Expand Down Expand Up @@ -210,36 +223,29 @@ static MemRefType getCanonicalMemRefTypeForTileBuf(pto::TileBufType tileTy) {
}

struct ViewChain {
UnrealizedConversionCastOp cast;
memref::SubViewOp subview;
memref::ReinterpretCastOp reinterpretCast;
Value baseMemref;
};

static std::optional<ViewChain> traceViewChain(Value tensorView,
Operation *user) {
Value memrefVal;
UnrealizedConversionCastOp castOp;

if (isa<MemRefType>(tensorView.getType())) {
memrefVal = tensorView;
} else {
castOp = tensorView.getDefiningOp<UnrealizedConversionCastOp>();
if (!castOp || castOp.getNumOperands() != 1) {
user->emitError(
"FoldTileBufIntrinsics: expected tensor_view to be defined by a "
"single-operand builtin.unrealized_conversion_cast");
return std::nullopt;
}
memrefVal = castOp.getOperand(0);
if (!isa<MemRefType>(memrefVal.getType())) {
user->emitError(
"FoldTileBufIntrinsics: expected cast operand to be a memref, got ")
<< memrefVal.getType();
return std::nullopt;
}
Value view = tensorView;

if (auto cast = view.getDefiningOp<UnrealizedConversionCastOp>()) {
if (cast.getNumOperands() == 1 && cast.getNumResults() == 1)
view = cast.getOperand(0);
}

if (!isa<MemRefType>(view.getType())) {
user->emitError("FoldTileBufIntrinsics: expected tensor_view to be lowered "
"to a memref.subview chain before folding, got ")
<< (view.getDefiningOp() ? view.getDefiningOp()->getName().getStringRef()
: StringRef("block argument"));
return std::nullopt;
}

Value memrefVal = view;
auto subviewOp = memrefVal.getDefiningOp<memref::SubViewOp>();
if (!subviewOp) {
user->emitError("FoldTileBufIntrinsics: expected memref to be defined by "
Expand All @@ -261,7 +267,11 @@ static std::optional<ViewChain> traceViewChain(Value tensorView,
return std::nullopt;
}

return ViewChain{castOp, subviewOp, rcOp, rcOp.getSource()};
ViewChain chain;
chain.subview = subviewOp;
chain.reinterpretCast = rcOp;
chain.baseMemref = rcOp.getSource();
return chain;
}

static bool getConstIndexValue(Value v, int64_t &out) {
Expand Down Expand Up @@ -380,12 +390,13 @@ struct FoldTileBufIntrinsicsPass
return signalPassFailure();
}

// Leftover TileLang template instances (private, uncalled after
// PTOInlineLibCall) still contain pto.tile_buf_addr / tile_valid_*
// ops on tile_buf function arguments — they have no materialized tile
// handle anchor to fold against and will be removed by later DCE. Skip
// them.
if (func->hasAttr("pto.tilelang.instance"))
// Leftover TileLang template instances and already-inlined PTODSL
// subkernel helpers may still contain structured-view intrinsics on
// function arguments. Those formal arguments have no materialized
// call-site handle to fold against; the live caller body has already been
// inlined and folded separately.
if (func->hasAttr("pto.tilelang.instance") ||
isDeadPTODSLSubkernelHelper(func))
return;

SmallVector<pto::TileBufAddrOp, 8> addrOps;
Expand Down Expand Up @@ -667,14 +678,16 @@ struct FoldTileBufIntrinsicsPass
return signalPassFailure();
}

Value linearOffset =
Value linearOffset;
Value basePtr;
linearOffset =
computeLinearOffset(builder, addrOp.getLoc(),
chain->reinterpretCast.getMixedOffsets(),
chain->subview.getMixedOffsets(),
chain->reinterpretCast.getMixedStrides());

Value basePtr = builder.create<pto::CastPtrOp>(
basePtr = builder.create<pto::CastPtrOp>(
addrOp.getLoc(), resultPtrType, chain->baseMemref);

Value replacement =
linearOffset
? builder.create<pto::AddPtrOp>(addrOp.getLoc(), resultPtrType,
Expand Down
18 changes: 17 additions & 1 deletion lib/PTO/Transforms/PTOInstantiateAndInlineOpLib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,22 @@ static void eraseDeadMatchingPrivateFuncs(ModuleOp module,
}
}

static void eraseDeadPTODSLSubkernelHelpers(ModuleOp module) {
for (ModuleOp funcModule : collectFuncModules(module)) {
SymbolTable symbolTable(funcModule);
SmallVector<func::FuncOp, 8> deadFuncs;
for (func::FuncOp func : funcModule.getOps<func::FuncOp>()) {
if (!isPTODSLSubkernelHelperFunc(func))
continue;
auto uses = symbolTable.getSymbolUses(func, funcModule);
if (uses && uses->empty())
deadFuncs.push_back(func);
}
for (func::FuncOp func : deadFuncs)
func.erase();
}
}

struct PTOInlineBackendHelpersPass
: public pto::impl::PTOInlineBackendHelpersBase<
PTOInlineBackendHelpersPass> {
Expand All @@ -371,7 +387,7 @@ struct PTOInlineBackendHelpersPass
<< " call(s)\n";
}

eraseDeadMatchingPrivateFuncs(module, isInlineableBackendHelperFunc);
eraseDeadPTODSLSubkernelHelpers(module);
}
};

Expand Down
Loading
Loading