Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 70 additions & 65 deletions kernels/sort/topk.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,105 +10,110 @@

// Problem size and tiling parameters shared by every phase in this header.
constexpr int kInputCount = 131072;                 // number of uint16_t input elements
constexpr int kTopK = 2048;                         // the "k" of the top-k selection
constexpr int kTileSize = 256;                      // elements moved per tile load
constexpr int kNumTiles = kInputCount / kTileSize;  // tile loads needed to cover the input
constexpr int kNumBuckets = 256;                    // one histogram bucket per byte value

// The tile loops advance in whole tiles; a remainder would be silently dropped.
static_assert(kInputCount % kTileSize == 0,
              "kInputCount must be a multiple of kTileSize");

// ============================================================================
// Tile type aliases
// ============================================================================

// NOTE(review): TileU32 is aliased twice in this group with different shapes
// (16 x 16 here, 1 x kNumBuckets two lines below). One name cannot alias two
// different types in the same scope — confirm which shape is intended and
// remove the other.
using TileU32 = Tile<Location::Vec, uint32_t, 16, 16, BLayout::RowMajor>;
// Single row of kTileSize 16-bit values; same extent as InputGM below.
using TileU16 = Tile<Location::Vec, uint16_t, 1, kTileSize, BLayout::RowMajor>;
// Single row of kNumBuckets 32-bit counters; same extent as HistGM below.
using TileU32 = Tile<Location::Vec, uint32_t, 1, kNumBuckets, BLayout::RowMajor>;

// GlobalTensor descriptor for one tile's worth (kTileSize) of uint16_t input.
using InputGM = GlobalTensor<uint16_t, Shape<1,1,1,1,kTileSize>, Stride<1,1,1,kTileSize,1>>;
// GlobalTensor descriptor for a kNumBuckets-entry uint32_t histogram.
using HistGM = GlobalTensor<uint32_t, Shape<1,1,1,1,kNumBuckets>, Stride<1,1,1,kNumBuckets,1>>;

// ============================================================================
// Phase 1: SIMT Extract high8 histogram (per-bucket)
//
// Grid: <<<1, 256, 1>>> (1 block, 256 lanes, one lane per bucket 0..255)
// Each lane (bucket = lane_id):
// loops over ALL kInputCount elements from global memory
// counts how many have high8 == bucket
// writes count to dst[lane_id].
// THISTOGRAM comes from the compiler's tileop-api (pulled in via pto_tileop.hpp):
// void THISTOGRAM(dst, src, Idx, ByteId)
// This kernel depends on the upstream fix for the off-by-one operand numbering
// in tileop-api's THISTOGRAM template (template_asm.hpp) — tracked upstream in
// LinxISA/llvm-project. No self-contained copy is kept here on purpose; the
// tileop-api is owned by the compiler.
// ============================================================================

template <typename tile_shape_out>
void __vec__ ExtractHigh8Hist_Vec_RowMajor(
typename tile_shape_out::TileDType __out__ dst,
const uint16_t* __in__ src)
{
size_t bucket = blkv_get_index_y();
typename tile_shape_out::DType count = 0;
// ============================================================================
// Phase 1: Build high8 cumulative histogram via THISTOGRAM (Byte1, no filter)
// ============================================================================

for (unsigned int i = 0; i < kInputCount; i++) {
uint16_t val = src[i];
uint8_t high8 = static_cast<uint8_t>(val >> 8);
if (high8 == bucket) {
count += 1;
}
inline void build_high8_histogram(uint16_t* input, uint32_t hist[256]) {
TileU16 inputTile;
TileU16 dummyIdx;
TileU32 histTile, accumTile;

TEXPANDSCALAR(accumTile, (uint32_t)0);
TEXPANDSCALAR(dummyIdx, (uint16_t)0);

for (int t = 0; t < kNumTiles; t++) {
InputGM gm(input + t * kTileSize);
TCOPYIN(inputTile, gm);
THISTOGRAM(histTile, inputTile, dummyIdx, 1);
TADD(accumTile, accumTile, histTile);
}

blkv_get_tile_ptr(dst)[bucket] = count;
HistGM histGM(hist);
TCOPYOUT(histGM, accumTile);
}

// ============================================================================
// Phase 3: SIMT Extract low8 histogram for kth_bin elements (per-bucket)
// Phase 3: Build low8 cumulative histogram for kth_bin via THISTOGRAM (Byte0, filtered)
// ============================================================================

template <typename tile_shape_out>
void __vec__ ExtractLow8HistForKthBin_Vec_RowMajor(
typename tile_shape_out::TileDType __out__ dst,
const uint16_t* __in__ src,
uint16_t kth_bin)
{
size_t bucket = blkv_get_index_y();
typename tile_shape_out::DType count = 0;

for (unsigned int i = 0; i < kInputCount; i++) {
uint16_t val = src[i];
uint8_t high8 = static_cast<uint8_t>(val >> 8);
if (high8 == kth_bin) {
uint8_t low8 = static_cast<uint8_t>(val & 0xFF);
if (low8 == bucket) {
count += 1;
}
}
inline void build_low8_histogram(uint16_t* input, uint16_t kth_bin, uint32_t hist[256]) {
TileU16 inputTile, idxTile;
TileU32 histTile, accumTile;

TEXPANDSCALAR(accumTile, (uint32_t)0);
TEXPANDSCALAR(idxTile, kth_bin);

for (int t = 0; t < kNumTiles; t++) {
InputGM gm(input + t * kTileSize);
TCOPYIN(inputTile, gm);
THISTOGRAM(histTile, inputTile, idxTile, 0);
TADD(accumTile, accumTile, histTile);
}

blkv_get_tile_ptr(dst)[bucket] = count;
HistGM histGM(hist);
TCOPYOUT(histGM, accumTile);
}

// ============================================================================
// Wrapper launch helpers
// Phase 5: Masked select (scalar — TCMP doesn't support u16, TCAST crashes
// the compiler's ClockHands pass. The histogram phases above are the
// compute-intensive part and use THISTOGRAM tileblock ops.)
// ============================================================================

// Launch helper for ExtractHigh8Hist_Vec_RowMajor.
// Forwards dst.data() and src to the kernel using the launch configuration
// <<<1, 256, 1>>>, which this file's Phase 1 notes describe as one block of
// 256 lanes, one lane per bucket 0..255.
template <typename tile_shape_out>
void ExtractHigh8Hist_Impl(tile_shape_out& dst, const uint16_t* src) {
ExtractHigh8Hist_Vec_RowMajor<tile_shape_out>
<<<1, 256, 1>>>(dst.data(), src);
}

template <typename tile_shape_out>
void ExtractLow8HistForKthBin_Impl(tile_shape_out& dst, const uint16_t* src,
uint16_t kth_bin) {
ExtractLow8HistForKthBin_Vec_RowMajor<tile_shape_out>
<<<1, 256, 1>>>(dst.data(), src, kth_bin);
inline void masked_select(const uint16_t* input, int kth_bin, int low8_boundary, uint16_t* output) {
for (int i = 0; i < kInputCount; i++) {
uint16_t val = input[i];
uint8_t high8 = static_cast<uint8_t>(val >> 8);
int low8 = static_cast<int>(val & 0xFF);
int include = (high8 > kth_bin) ||
((high8 == kth_bin) & (low8 >= low8_boundary));
output[i] = include ? val : 0;
}
}

// ============================================================================
// Scalar helper: prefix scan to find kth_bin and remaining count
// Scalar: find kth_bin from cumulative histogram
// C[k] = count of elements with byte value 0..k (output of THISTOGRAM)
// Scans from bin 255 down to find where cumulative-from-top first reaches k
// ============================================================================

// Locate the bin holding the k-th largest element from a per-bin histogram.
//
// hist[b] is the number of elements in bin b. Bins are accumulated from 255
// downwards; the first bin at which the running sum reaches k is returned.
//
// hist           per-bin counts (256 entries).
// k              how many of the largest elements are wanted.
// need_from_kth  out: how many elements must still be taken from the returned
//                bin after every higher bin has been taken in full.
// Returns the bin index, or 0 with need_from_kth = 0 when the histogram holds
// fewer than k elements in total.
static int find_kth_bin(const uint32_t hist[256], int k, int& need_from_kth) {
    // Scan from highest bin down — we want the largest k elements.
    uint64_t cumsum = 0;
    for (int b = 255; b >= 0; b--) {
        cumsum += hist[b];
        if (cumsum >= static_cast<uint64_t>(k)) {
            // cumsum - hist[b] is the population of the bins strictly above b.
            need_from_kth = k - static_cast<int>(cumsum - hist[b]);
            return b;
        }
    }
    need_from_kth = 0;
    return 0;
}

// Same search, driven by a cumulative histogram.
//
// C[b] is the number of elements whose bin is <= b, so C[255] is the total and
// total - C[b-1] is the number of elements in bins >= b.
//
// C              cumulative counts (256 entries, non-decreasing).
// k              how many of the largest elements are wanted.
// kth_bin        out: highest bin b whose "bins >= b" population reaches k.
// need_from_kth  out: how many elements must still be taken from kth_bin after
//                every higher bin has been taken in full.
// When fewer than k elements exist, kth_bin is set to 0 and need_from_kth to k.
static void find_kth_bin_from_cumulative(const uint32_t C[256], int k,
                                         int& kth_bin, int& need_from_kth) {
    const uint32_t total = C[255];
    for (int b = 255; b >= 0; b--) {
        const uint32_t count_ge = (b > 0) ? (total - C[b - 1]) : total;
        if (count_ge >= static_cast<uint32_t>(k)) {
            kth_bin = b;
            // total - C[b] is the population of the bins strictly above b.
            need_from_kth = k - static_cast<int>(total - C[b]);
            return;
        }
    }
    kth_bin = 0;
    need_from_kth = k;
}

#endif // TOPK_HPP
Loading