Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions lib/PTO/Transforms/FoldTileBufIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,7 @@ struct FoldTileBufIntrinsicsPass
SmallVector<pto::TensorViewAddrOp, 8> tvAddrOps;
SmallVector<pto::GetTensorViewDimOp, 8> tvDimOps;
SmallVector<pto::GetTensorViewStrideOp, 8> tvStrideOps;
SmallVector<pto::GetValidShapeOp, 8> getValidShapeOps;

func.walk([&](Operation *op) {
if (auto addr = dyn_cast<pto::TileBufAddrOp>(op))
Expand All @@ -408,9 +409,50 @@ struct FoldTileBufIntrinsicsPass
tvDimOps.push_back(tvDim);
else if (auto tvStride = dyn_cast<pto::GetTensorViewStrideOp>(op))
tvStrideOps.push_back(tvStride);
else if (auto gvs = dyn_cast<pto::GetValidShapeOp>(op))
getValidShapeOps.push_back(gvs);
});

if (shouldFoldAddrFamily(*mode)) {
// Fold pto.get_validshape into the materialized tile handle
// valid_row / valid_col. This must precede tile_buf_addr and
// tile_valid_{rows,cols} folding: set_validshape operands are usually
// produced by get_validshape, so resolving them first lets
// resolveTileHandle observe the overridden valid shape carried by a
// treshape + set_validshape pair.
for (auto gvsOp : getValidShapeOps) {
if (!isa<pto::TileBufType>(gvsOp.getSource().getType()))
continue;

auto handleInfo = resolveTileHandle(gvsOp.getSource(), gvsOp);
if (!handleInfo)
return signalPassFailure();

builder.setInsertionPoint(gvsOp);
auto tileTy = cast<pto::TileBufType>(gvsOp.getSource().getType());
auto validShape = tileTy.getValidShape();

Value rowReplacement = handleInfo->validRow;
if (!validShape.empty() && validShape[0] != ShapedType::kDynamic)
rowReplacement =
builder.create<arith::ConstantIndexOp>(gvsOp.getLoc(), validShape[0]);

Value colReplacement = handleInfo->validCol;
if (validShape.size() >= 2 && validShape[1] != ShapedType::kDynamic)
colReplacement =
builder.create<arith::ConstantIndexOp>(gvsOp.getLoc(), validShape[1]);

if (!rowReplacement || !colReplacement) {
gvsOp.emitError("FoldTileBufIntrinsics: pto.get_validshape could not "
"resolve a concrete valid_row / valid_col");
return signalPassFailure();
}

gvsOp.getValidRow().replaceAllUsesWith(rowReplacement);
gvsOp.getValidCol().replaceAllUsesWith(colReplacement);
gvsOp.erase();
}

// Fold pto.tile_buf_addr by recovering the active materialized tile
// handle contract:
// - pto.materialize_tile → use the source memref directly
Expand Down Expand Up @@ -714,6 +756,38 @@ struct FoldTileBufIntrinsicsPass
op->erase();
}

// Erase pto.set_validshape ops. Every valid-shape reader
// (get_validshape / tile_valid_{rows,cols} / tile_buf_addr) has been
// folded above, so the runtime metadata writes have no remaining
// observer and have no LLVM lowering.
SmallVector<pto::SetValidShapeOp, 8> setValidShapeOps;
func.walk([&](pto::SetValidShapeOp op) { setValidShapeOps.push_back(op); });
for (auto op : llvm::reverse(setValidShapeOps))
op.erase();

// DCE tile-handle view / alloc ops left behind after valid-shape
// folding (treshape / materialize_tile / alloc_tile / bridging casts).
bool tileDceChanged = true;
while (tileDceChanged) {
tileDceChanged = false;
SmallVector<Operation *, 8> deadTileOps;
func.walk([&](Operation *op) {
if (!op->use_empty())
return;
if (isa<pto::TReshapeOp, pto::MaterializeTileOp, pto::AllocTileOp>(op))
deadTileOps.push_back(op);
else if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
if (castOp.getNumOperands() == 1 &&
isa<pto::TileBufType>(castOp.getResult(0).getType()))
deadTileOps.push_back(op);
}
});
for (auto *op : llvm::reverse(deadTileOps)) {
op->erase();
tileDceChanged = true;
}
}
Comment on lines +770 to +789

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The current iterative DCE implementation performs a full function walk (func.walk) inside a while loop for every level of dead tile operations. For a chain of $N$ dead operations, this results in $O(N \times M)$ complexity (where $M$ is the number of operations in the function).

We can optimize this to $O(N + M)$ by using a worklist-based approach. When an operation is erased, we check if its operands have become dead and add them to the worklist if they match the target types. Additionally, we should verify that UnrealizedConversionCastOp has exactly 1 result before accessing getResult(0) to prevent potential out-of-bounds assertions.

    SmallVector<Operation *, 8> worklist;
    func.walk([&](Operation *op) {
      if (!op->use_empty())
        return;
      if (isa<pto::TReshapeOp, pto::MaterializeTileOp, pto::AllocTileOp>(op)) {
        worklist.push_back(op);
      } else if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
        if (castOp.getNumOperands() == 1 && castOp.getNumResults() == 1 &&
            isa<pto::TileBufType>(castOp.getResult(0).getType()))
          worklist.push_back(op);
      }
    });

    while (!worklist.empty()) {
      Operation *op = worklist.pop_back_val();
      SmallVector<Value, 4> operands(op->getOperands());
      op->erase();
      for (Value operand : operands) {
        if (auto *defOp = operand.getDefiningOp()) {
          if (!defOp->use_empty())
            continue;
          if (isa<pto::TReshapeOp, pto::MaterializeTileOp, pto::AllocTileOp>(defOp)) {
            worklist.push_back(defOp);
          } else if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(defOp)) {
            if (castOp.getNumOperands() == 1 && castOp.getNumResults() == 1 &&
                isa<pto::TileBufType>(castOp.getResult(0).getType()))
              worklist.push_back(defOp);
          }
        }
      }
    }


eraseDeadAllocTileOps(func);
}
};
Expand Down
Loading