diff --git a/include/PTO/Transforms/Passes.h b/include/PTO/Transforms/Passes.h index 85970756c..6d03d5120 100644 --- a/include/PTO/Transforms/Passes.h +++ b/include/PTO/Transforms/Passes.h @@ -105,6 +105,7 @@ LogicalResult validateVPTOEmissionIR(ModuleOp module, llvm::raw_ostream *diagOS = nullptr); std::unique_ptr createPTOValidateVPTOIRPass(); std::unique_ptr createPTOValidateVPTOEmissionIRPass(); +std::unique_ptr createPTOInsertVMemBarPass(); std::unique_ptr createExpandTileOpPass(); std::unique_ptr createExpandTileOpPass(const ExpandTileOpOptions &options); std::unique_ptr createFoldTileBufIntrinsicsPass(); diff --git a/include/PTO/Transforms/Passes.td b/include/PTO/Transforms/Passes.td index bcc165674..669725890 100644 --- a/include/PTO/Transforms/Passes.td +++ b/include/PTO/Transforms/Passes.td @@ -888,4 +888,28 @@ def VPTOPtrCastCleanup "mlir::memref::MemRefDialect"]; } +def PTOInsertVMemBar + : Pass<"pto-insert-v-membar", "mlir::func::FuncOp"> { + let summary = "Insert V-pipeline memory barriers between tile ops with " + "overlapping UB addresses"; + let description = [{ + Runs before `pto-expand-tile-op`, while each tile op is still a structured + `pto` op (e.g. `pto.tsub`) whose tile operands trace back to `pto.alloc_tile` + with a constant UB address. This pass analyzes RAW/WAW/WAR memory + dependencies between V-pipeline tile ops whose `alloc_tile` address + ranges overlap, and inserts `pto.mem_bar` barriers so that inlined + `vsts`/`vlds` do not reorder across the dependency. + + Only V-pipeline tile ops (`getPipe() == PIPE_V`) are considered; Cube + synchronization and cross-pipeline sync remain the responsibility of + `pto-insert-sync` and friends. Only intra-block, intra-iteration + dependencies are handled. 
+ }]; + let constructor = "mlir::pto::createPTOInsertVMemBarPass()"; + let dependentDialects = ["mlir::pto::PTODialect", + "mlir::memref::MemRefDialect", + "mlir::arith::ArithDialect", + "mlir::func::FuncDialect"]; +} + #endif // MLIR_DIALECT_PTO_PASSES diff --git a/lib/PTO/Transforms/CMakeLists.txt b/lib/PTO/Transforms/CMakeLists.txt index a7059674d..a314310c3 100644 --- a/lib/PTO/Transforms/CMakeLists.txt +++ b/lib/PTO/Transforms/CMakeLists.txt @@ -40,6 +40,7 @@ add_mlir_dialect_library(PTOTransforms InsertSync/PTOInsertSync.cpp PTOInjectBarrierAllSync.cpp InsertSync/InsertSyncDebug.cpp + PTOInsertVMemBar.cpp PTOViewToMemref.cpp PTOValidateIntToPtrUses.cpp ExpandTileOp.cpp diff --git a/lib/PTO/Transforms/PTOInsertVMemBar.cpp b/lib/PTO/Transforms/PTOInsertVMemBar.cpp new file mode 100644 index 000000000..79628a1d6 --- /dev/null +++ b/lib/PTO/Transforms/PTOInsertVMemBar.cpp @@ -0,0 +1,321 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +//===- PTOInsertVMemBar.cpp - V-pipeline mem_bar insertion ----------------===// +// +// After tile ops are lowered to the structured `pto.tsub` / `pto.tcolexpandmul` +// form (and before `pto-expand-tile-op` expands them to `func.call`), each +// V-pipeline tile op carries: +// - `OpPipeInterface` → getPipe() == PIPE_V +// - `DestinationStyleOpInterface` → precise read (dpsInput) / write (dpsInit) +// operand sets +// - tile operands that trace back to `pto.alloc_tile addr = ` with a +// constant UB address. +// +// This pass analyzes RAW / WAW / WAR memory dependencies between V-pipeline +// tile ops whose `alloc_tile` address ranges overlap, and inserts `pto.mem_bar` +// barriers before the dependent op so that the later inlined `vsts` / `vlds` +// cannot reorder across the dependency. +// +// Scope: +// - Only V-pipeline tile ops (PIPE_V). Cube / cross-pipe sync stays with +// `pto-insert-sync` and friends. +// - Only intra-block, intra-iteration dependencies (same enclosing block, +// not across `scf.for` iterations). +// - If a `pto.mem_bar` already sits between two dependent ops, no new +// barrier is inserted. +// +//===----------------------------------------------------------------------===// + +#include "PTO/IR/PTO.h" +#include "PTO/Transforms/Passes.h" + +#include +#include + +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Operation.h" +#include "mlir/Interfaces/DestinationStyleOpInterface.h" +#include "mlir/Pass/Pass.h" + +using namespace mlir; +using namespace mlir::pto; + +namespace mlir { +namespace pto { + #define GEN_PASS_DEF_PTOINSERTVMEMBAR + #include "PTO/Transforms/Passes.h.inc" +} // namespace pto +} // namespace mlir + +namespace { + +/// A half-open UB address range [base, base + size) for one tile operand. 
+struct AddrRange {
+  uint64_t base = 0; // first byte of the range
+  uint64_t size = 0; // length of the range in bytes
+  // False when the address or the size could not be resolved to a constant.
+  // An invalid range is treated as overlapping every other range.
+  bool valid = false;
+};
+
+/// Returns true if the two half-open ranges may share at least one byte.
+/// A range that is not `valid` is assumed to alias everything.
+static bool overlaps(const AddrRange &a, const AddrRange &b) {
+  if (!a.valid || !b.valid)
+    return true; // conservative: unknown address => may alias
+  return a.base < b.base + b.size && b.base < a.base + a.size;
+}
+
+/// Resolve the element size in bytes for a tile element type.
+///
+/// Bit widths that are not a multiple of 8 are rounded up to whole bytes.
+/// Truncating instead would give 0 for sub-byte element types, which makes
+/// the tile's range empty and hides real overlaps; rounding up can only
+/// over-approximate the range, which is the safe direction for a hazard
+/// analysis. Element types handled by neither branch fall back to 4 bytes.
+static uint64_t getElementBytes(Type elemTy) {
+  if (auto intTy = dyn_cast(elemTy))
+    return (intTy.getWidth() + 7) / 8;
+  if (auto floatTy = dyn_cast(elemTy))
+    return (floatTy.getWidth() + 7) / 8;
+  // bf16 is a FloatType (width 16) handled above; fall back to 4 bytes.
+  return 4;
+}
+
+/// Compute the byte size covered by a tile buffer from its TileBufType shape.
+///
+/// Returns 0 when any dimension is not strictly positive (e.g. a dynamic
+/// dimension); callers must treat 0 as "size unknown", not as an empty tile.
+static uint64_t getTileByteSize(TileBufType tileTy) {
+  auto shape = tileTy.getShape();
+  uint64_t elems = 1;
+  for (int64_t dim : shape) {
+    if (dim <= 0)
+      return 0; // dynamic => unknown size; caller treats as may-alias
+    elems *= static_cast(dim);
+  }
+  return elems * getElementBytes(tileTy.getElementType());
+}
+
+/// Trace a tile operand back to its `pto.alloc_tile` and resolve its UB
+/// address range.
+///
+/// Returns std::nullopt only when the operand is not a tile_buf. For a
+/// tile_buf operand a range is always returned; it has `valid == false`
+/// (may alias anything) when
+///   - the tile size is unknown (getTileByteSize returned 0),
+///   - the value is not produced directly by the alloc op,
+///   - the alloc op has no address operand, or
+///   - the address is not defined by a constant op.
+static std::optional resolveTileAddrRange(Value tileVal) {
+  auto tileTy = dyn_cast(tileVal.getType());
+  if (!tileTy)
+    return std::nullopt;
+
+  AddrRange range;
+  range.size = getTileByteSize(tileTy);
+  if (range.size == 0) {
+    // Unknown extent. Marking the range valid with size 0 would make
+    // overlaps() report "no overlap" for it, silently dropping the
+    // dependency, so keep it invalid and let it alias everything.
+    range.valid = false;
+    return range;
+  }
+
+  // Trace back through defining op (alloc_tile produces the tile_buf).
+  auto alloc = tileVal.getDefiningOp();
+  if (!alloc) {
+    // Not directly from alloc_tile (e.g. a subview / bind_tile). Be
+    // conservative: treat as unknown address that may alias anything.
+    range.valid = false;
+    return range;
+  }
+
+  Value addrVal = alloc.getAddr();
+  if (!addrVal) {
+    range.valid = false;
+    return range;
+  }
+  if (auto cInt = addrVal.getDefiningOp()) {
+    range.base = static_cast(cInt.value());
+    range.valid = true;
+    return range;
+  }
+  // Address is computed at run time: unknown, may alias.
+  range.valid = false;
+  return range;
+}
+
+/// One memory access of a tile op: an address range plus whether it is a
+/// read or a write.
+struct TileAccess {
+  AddrRange range;
+  bool isWrite;
+};
+
+/// Collect the read and write tile accesses of a V-pipeline tile op via its
+/// PTO_DpsInitOpInterface: the `getDpsInits()` operands are writes, all other
+/// tile_buf operands are reads.
+///
+/// Returns an empty vector when `op` does not implement the interface or has
+/// no tile_buf operands.
+static SmallVector collectTileAccesses(Operation *op) {
+  SmallVector accesses;
+  auto dps = dyn_cast(op);
+  if (!dps)
+    return accesses;
+
+  // Build the set of write (init) operands. Operands are identified by
+  // OpOperand address, so the same Value used both as an input and as an
+  // init yields one read access and one write access.
+  SmallPtrSet writeOperands;
+  MutableOperandRange inits = dps.getDpsInitsMutable();
+  for (OpOperand &initOp : inits) {
+    if (isa(initOp.get().getType()))
+      writeOperands.insert(&initOp);
+  }
+
+  for (OpOperand &operand : op->getOpOperands()) {
+    if (!isa(operand.get().getType()))
+      continue;
+    bool isWrite = writeOperands.count(&operand) > 0;
+    if (auto range = resolveTileAddrRange(operand.get()))
+      accesses.push_back({*range, isWrite});
+  }
+  return accesses;
+}
+
+/// Returns true if `op` is a barrier op as far as this pass is concerned.
+static bool isVMemBarrier(Operation *op) { return isa(op); }
+
+/// Returns true if a barrier op sits strictly between opA and opB, i.e. after
+/// opA and before opB; neither endpoint is inspected.
+///
+/// Precondition: opA and opB are in the same block with opA before opB. The
+/// scan starts right after opA and stops at opB or at the end of opB's block.
+static bool hasBarrierBetween(Block::iterator opA, Operation *opB) {
+  Block *block = opB->getBlock();
+  if (!block)
+    return false;
+  for (auto it = std::next(opA); it != block->end(); ++it) {
+    if (&*it == opB)
+      return false;
+    if (isVMemBarrier(&*it))
+      return true;
+  }
+  return false;
+}
+
+/// Pick the MemBarKind for a dependency between opA and opB.
+///   RAW (opA write -> opB read)  : VST_VLD
+///   WAW (opA write -> opB write) : VST_VST
+///   WAR (opA read  -> opB write) : VLD_VST
+///
+/// `aAcc` are the accesses of the earlier op, `bAcc` those of the later op.
+/// Exactly one kind is returned even when several hazard types are present
+/// at once (priority RAW > WAW > WAR). VV_ALL is returned when no hazard is
+/// found.
+/// NOTE(review): when RAW and WAR both exist only VST_VLD is emitted; this
+/// assumes that barrier kind also orders the load->store pair — confirm
+/// against the mem_bar semantics.
+static MemBarKind classifyDep(const SmallVector &aAcc,
+                              const SmallVector &bAcc) {
+  bool raw = false, waw = false, war = false;
+  // Writes of A against every access of B: write/write => WAW,
+  // write/read => RAW.
+  for (const auto &wa : aAcc) {
+    if (!wa.isWrite)
+      continue;
+    for (const auto &rb : bAcc) {
+      if (!overlaps(wa.range, rb.range))
+        continue;
+      if (rb.isWrite)
+        waw = true;
+      else
+        raw = true;
+    }
+  }
+  // Reads of A against writes of B => WAR.
+  for (const auto &ra : aAcc) {
+    if (ra.isWrite)
+      continue;
+    for (const auto &wb : bAcc) {
+      if (!wb.isWrite)
+        continue;
+      if (overlaps(ra.range, wb.range))
+        war = true;
+    }
+  }
+  // Prefer VST_VLD (RAW) first, then VST_VST (WAW), then VLD_VST (WAR).
+  if (raw)
+    return MemBarKind::VST_VLD;
+  if (waw)
+    return MemBarKind::VST_VST;
+  if (war)
+    return MemBarKind::VLD_VST;
+  return MemBarKind::VV_ALL;
+}
+
+/// Returns true if any access of the first op and any access of the second
+/// op overlap and at least one of the two is a write (RAW, WAW or WAR).
+static bool hasOverlapDep(const SmallVector &aAcc,
+                          const SmallVector &bAcc) {
+  for (const auto &a : aAcc)
+    for (const auto &b : bAcc) {
+      if (!overlaps(a.range, b.range))
+        continue;
+      // any cross-op pair that is not (read,read) is a dependency
+      if (a.isWrite || b.isWrite)
+        return true;
+    }
+  return false;
+}
+
+/// Function pass that schedules and then creates barrier ops in front of
+/// V-pipeline tile ops that depend on an earlier tile op in the same block.
+struct PTOInsertVMemBarPass
+    : public mlir::pto::impl::PTOInsertVMemBarBase {
+  void runOnOperation() override {
+    func::FuncOp func = getOperation();
+
+    // Collect insertion points first, then build, so we don't mutate IR while
+    // iterating. Each entry pairs the op to insert before with the barrier
+    // kind to emit there.
+    SmallVector, 16> insertions;
+
+    // Gather V-pipeline tile ops across the function, grouped by their
+    // enclosing block. Dependencies are only intra-block / intra-iteration.
+    struct VTileOp {
+      Operation *op;
+      SmallVector accesses;
+    };
+    // Keyed by the enclosing block. The walk visits ops in program order, so
+    // each per-block list is ordered by position within that block.
+    llvm::DenseMap> byBlock;
+    func.walk([&](Operation *op) {
+      // Type filter guarding the cast below.
+      if (!isa(op))
+        return;
+      auto pipeIface = cast(op);
+      // Only ops reporting PIPE_V are analyzed.
+      if (pipeIface.getPipe() != pto::PIPE::PIPE_V)
+        return;
+      // Second type filter before the operands are inspected.
+      if (!isa(op))
+        return;
+      auto accesses = collectTileAccesses(op);
+      if (accesses.empty())
+        return;
+      Block *block = op->getBlock();
+      if (!block)
+        return;
+      byBlock[block].push_back({op, std::move(accesses)});
+    });
+
+    // Blocks are processed independently; every insertion is positioned
+    // relative to its own op, so the map's iteration order does not affect
+    // the resulting IR.
+    for (auto &kv : byBlock) {
+      SmallVector &vOps = kv.second;
+      if (vOps.size() < 2)
+        continue;
+
+      // For each pair (i < j) with an overlapping dependency and no existing
+      // barrier between them, schedule a mem_bar before vOps[j].
+      // At most one barrier is scheduled per vOps[j]. Barriers scheduled here
+      // are not yet in the IR, so hasBarrierBetween only sees barriers that
+      // existed before this pass ran.
+      for (unsigned j = 1; j < vOps.size(); ++j) {
+        MemBarKind kind = MemBarKind::VV_ALL;
+        bool anyDep = false;
+        for (unsigned i = 0; i < j; ++i) {
+          if (!hasOverlapDep(vOps[i].accesses, vOps[j].accesses))
+            continue;
+          // Only the nearest preceding op with a dependency matters for
+          // ordering: if a barrier already exists between vOps[i] and
+          // vOps[j], later dependences on vOps[j] are still possible from
+          // NOTE(review): the next line looks truncated — it ends in what
+          // appears to be the tail of a `for` header (`... - 1; i >= 0; --i) {`),
+          // which suggests a backwards scan from j-1 was intended. Confirm
+          // against the original file before relying on the loop direction;
+          // `anyDep` is also never read in the visible code.
+          // ops after vOps[i]; but a barrier between any i(j) - 1; i >= 0; --i) {
+          if (!hasOverlapDep(vOps[i].accesses, vOps[j].accesses))
+            continue;
+          if (hasBarrierBetween(Block::iterator(vOps[i].op), vOps[j].op))
+            break; // an existing barrier already guards this opB from prior
+          kind = classifyDep(vOps[i].accesses, vOps[j].accesses);
+          insertions.emplace_back(vOps[j].op, kind);
+          break;
+        }
+      }
+    }
+
+    // Materialize the scheduled barriers: each is created immediately before
+    // its dependent op and reuses that op's location.
+    OpBuilder builder(func.getContext());
+    for (auto [opB, kind] : insertions) {
+      builder.setInsertionPoint(opB);
+      auto attr = pto::MemBarAttr::get(func.getContext(), kind);
+      builder.create(opB->getLoc(), attr);
+    }
+  }
+};
+
+} // namespace
+
+/// Factory for the pass; referenced by the `constructor` field of the pass
+/// declaration.
+std::unique_ptr mlir::pto::createPTOInsertVMemBarPass() {
+  return std::make_unique();
+}
diff --git a/test/lit/vpto/vmembar_raw_waw_overlap.pto b/test/lit/vpto/vmembar_raw_waw_overlap.pto
new file mode 100644
index 000000000..d7ece7d7f
--- 
/dev/null +++ b/test/lit/vpto/vmembar_raw_waw_overlap.pto @@ -0,0 +1,65 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ( ptoas --pto-arch=a5 --pto-backend=vpto --pto-level=level3 --enable-v-membar \ +// RUN: --mlir-print-ir-after=pto-insert-v-membar %s -o %t 2>&1 || true ) | FileCheck %s +// +// Two `pto.tcolexpandmul` write UB tiles at addr 1536 (t) and 1792 (%0); a +// following `pto.tsub` reads both (1536, 1792) and writes 1536 (rot_lo, which +// reuses t's address). This creates RAW (1536, 1792) and WAW (1536) hazards. +// With --enable-v-membar the pass must insert a `pto.mem_bar "VST_VLD"` before +// pto.tsub so the inlined vsts/vlds cannot reorder. 
+ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmembar_raw_waw_overlap(%arg0: !pto.ptr, %arg1: !pto.ptr, %arg2: index) attributes {pto.kernel_kind = #pto.kernel_kind} { + %c0_i64 = arith.constant 0 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + %c768_i64 = arith.constant 768 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c1280_i64 = arith.constant 1280 : i64 + %c1536_i64 = arith.constant 1536 : i64 + %c1792_i64 = arith.constant 1792 : i64 + %c0_index = arith.constant 0 : index + %c1_index = arith.constant 1 : index + %c64_index = arith.constant 64 : index + %c1024_index = arith.constant 1024 : index + %c128_index = arith.constant 128 : index + %c16_index = arith.constant 16 : index + + %k_view = pto.make_tensor_view %arg0, shape = [%c16_index, %c1024_index], strides = [%c1024_index, %c1_index] {layout = #pto.layout} : !pto.tensor_view + %cos_view = pto.make_tensor_view %arg1, shape = [%c1_index, %c64_index], strides = [%c128_index, %c1_index] {layout = #pto.layout} : !pto.tensor_view + + %k_lo__tile = pto.alloc_tile addr = %c1024_i64 valid_row = %c1_index valid_col = %c64_index : !pto.tile_buf + %k_lo_pview = pto.partition_view %k_view, offsets = [%arg2, %c0_index], sizes = [%c1_index, %c64_index] : !pto.tensor_view -> !pto.partition_tensor_view<1x64xf32> + pto.tload ins(%k_lo_pview : !pto.partition_tensor_view<1x64xf32>) outs(%k_lo__tile : !pto.tile_buf) + %k_hi__tile = pto.alloc_tile addr = %c1280_i64 valid_row = %c1_index valid_col = %c64_index : !pto.tile_buf + %k_hi_pview = pto.partition_view %k_view, offsets = [%arg2, %c64_index], sizes = [%c1_index, %c64_index] : !pto.tensor_view -> !pto.partition_tensor_view<1x64xf32> + pto.tload ins(%k_hi_pview : !pto.partition_tensor_view<1x64xf32>) outs(%k_hi__tile : !pto.tile_buf) + %cos_lo__tile = pto.alloc_tile addr = %c0_i64 valid_row = %c1_index valid_col = %c64_index : !pto.tile_buf + %cos_lo_pview = pto.partition_view 
%cos_view, offsets = [%c0_index, %c0_index], sizes = [%c1_index, %c64_index] : !pto.tensor_view -> !pto.partition_tensor_view<1x64xf32> + pto.tload ins(%cos_lo_pview : !pto.partition_tensor_view<1x64xf32>) outs(%cos_lo__tile : !pto.tile_buf) + %sin_lo__tile = pto.alloc_tile addr = %c256_i64 valid_row = %c1_index valid_col = %c64_index : !pto.tile_buf + %sin_lo_pview = pto.partition_view %cos_view, offsets = [%c0_index, %c64_index], sizes = [%c1_index, %c64_index] : !pto.tensor_view -> !pto.partition_tensor_view<1x64xf32> + pto.tload ins(%sin_lo_pview : !pto.partition_tensor_view<1x64xf32>) outs(%sin_lo__tile : !pto.tile_buf) + + %t__tile = pto.alloc_tile addr = %c1536_i64 valid_row = %c1_index valid_col = %c64_index : !pto.tile_buf + pto.tcolexpandmul ins(%k_lo__tile, %cos_lo__tile : !pto.tile_buf, !pto.tile_buf) outs(%t__tile : !pto.tile_buf) + %0 = pto.alloc_tile addr = %c1792_i64 valid_row = %c1_index valid_col = %c64_index : !pto.tile_buf + pto.tcolexpandmul ins(%k_hi__tile, %sin_lo__tile : !pto.tile_buf, !pto.tile_buf) outs(%0 : !pto.tile_buf) + %rot_lo__tile = pto.alloc_tile addr = %c1536_i64 valid_row = %c1_index valid_col = %c64_index : !pto.tile_buf + pto.tsub ins(%t__tile, %0 : !pto.tile_buf, !pto.tile_buf) outs(%rot_lo__tile : !pto.tile_buf) + return + } +} + +// CHECK-LABEL: IR Dump After PTOInsertVMemBar +// CHECK: pto.tcolexpandmul +// CHECK: pto.tcolexpandmul +// CHECK: pto.mem_bar "VST_VLD" +// CHECK: pto.tsub diff --git a/tools/ptoas/ptoas.cpp b/tools/ptoas/ptoas.cpp index 6443f1ca7..2a26cad27 100644 --- a/tools/ptoas/ptoas.cpp +++ b/tools/ptoas/ptoas.cpp @@ -317,6 +317,12 @@ static llvm::cl::opt enableInsertSync("enable-insert-sync", llvm::cl::desc("Enable automatic synchronization insertion pass"), llvm::cl::init(false)); +static llvm::cl::opt enableVMembar( + "enable-v-membar", + llvm::cl::desc("Insert V-pipeline mem_bar between tile ops with " + "overlapping UB addresses (VST/VLD RAW/WAW/WAR)"), + llvm::cl::init(false)); + static llvm::cl::opt 
enableBufidSync( "enable-bufid_sync", llvm::cl::desc("Enable A5 buffer-id synchronization insertion pass"), @@ -1558,6 +1564,13 @@ static void lowerPTOToVPTOBackend(PassManager &pm, ModuleOp module, int argc, enableOpFusion && moduleArchAttr && moduleArchAttr.getValue() == "a5"; pto::ExpandTileOpOptions expandOpts = resolveExpandTileOpOptions(argc, argv); + // Insert V-pipeline mem_bar while tile ops are still in their structured + // `pto.tsub` / `pto.tcolexpandmul` form (before pto-expand-tile-op turns + // them into func.call), so OpPipeInterface + DestinationStyleOpInterface + + // constant alloc_tile addresses are all available. + if (enableVMembar) + kernelModulePM.addNestedPass( + pto::createPTOInsertVMemBarPass()); kernelModulePM.addPass(pto::createExpandTileOpPass(expandOpts)); kernelModulePM.addPass(pto::createPTOInlineLibCallPass());