Optimized rocm specific multicast transpose kernel #586

Open: alextmagro wants to merge 3 commits into dev from multicasttranspose_opt
@alextmagro
Contributor

Optimizes the multi_cast_transpose kernel for ROCm.

Benchmark Results

Qwen has 128 experts; DS has 256 experts. Benchmarked with shapes derived from MBS={1,2,4}.

Balanced Experts

| Benchmark | Base (us) | Base (TiB/s) | Opt (us) | Opt (TiB/s) | % Peak | Speedup |
|---|---|---|---|---|---|---|
| 128exp/4096cols/MBS1 | 668 | 0.75 | 140 | 3.58 | 49.2% | 4.77x |
| 128exp/4096cols/MBS2 | 1147 | 0.86 | 248 | 3.99 | 54.8% | 4.63x |
| 128exp/4096cols/MBS4 | 2388 | 0.82 | 478 | 4.12 | 56.6% | 5.00x |
| 128exp/1536cols/MBS1 | 281 | 0.67 | 54.8 | 3.44 | 47.3% | 5.13x |
| 128exp/1536cols/MBS2 | 592 | 0.63 | 101 | 3.67 | 50.4% | 5.86x |
| 128exp/1536cols/MBS4 | 934 | 0.79 | 191 | 3.86 | 53.1% | 4.89x |
| 128exp/3072cols/MBS1 | 548 | 0.69 | 102 | 3.71 | 51.0% | 5.37x |
| 128exp/3072cols/MBS2 | 904 | 0.82 | 204 | 3.65 | 50.2% | 4.43x |
| 128exp/3072cols/MBS4 | 1763 | 0.84 | 371 | 3.98 | 54.7% | 4.75x |
| 256exp/7168cols/MBS1 | 1621 | 0.56 | 289 | 3.13 | 43.0% | 5.61x |
| 256exp/7168cols/MBS2 | 2251 | 0.78 | 519 | 3.39 | 46.6% | 4.34x |
| 256exp/7168cols/MBS4 | 4289 | 0.81 | 861 | 4.03 | 55.4% | 4.98x |
| 256exp/2048cols/MBS1 | 489 | 0.53 | 83.5 | 3.08 | 42.3% | 5.86x |
| 256exp/2048cols/MBS2 | 765 | 0.66 | 147 | 3.41 | 46.9% | 5.20x |
| 256exp/2048cols/MBS4 | 1251 | 0.79 | 267 | 3.71 | 51.0% | 4.69x |
| 256exp/4096cols/MBS1 | 949 | 0.55 | 168 | 3.07 | 42.2% | 5.65x |
| 256exp/4096cols/MBS2 | 1235 | 0.81 | 315 | 3.20 | 44.0% | 3.92x |
| 256exp/4096cols/MBS4 | 2417 | 0.82 | 533 | 3.72 | 51.1% | 4.53x |

Skewed routing

| Benchmark | Base (us) | Base (TiB/s) | Opt (us) | Opt (TiB/s) | % Peak | Speedup |
|---|---|---|---|---|---|---|
| 128exp/4096cols/MBS1 | 810 | 0.62 | 137 | 3.69 | 50.7% | 5.91x |
| 128exp/4096cols/MBS2 | 1335 | 0.74 | 273 | 3.63 | 49.9% | 4.89x |
| 128exp/4096cols/MBS4 | 2660 | 0.74 | 513 | 3.83 | 52.6% | 5.19x |
| 128exp/1536cols/MBS1 | 349 | 0.54 | 53.2 | 3.55 | 48.8% | 6.56x |
| 128exp/1536cols/MBS2 | 635 | 0.59 | 87.6 | 4.24 | 58.3% | 7.25x |
| 128exp/1536cols/MBS4 | 1108 | 0.67 | 202 | 3.66 | 50.3% | 5.49x |
| 128exp/3072cols/MBS1 | 644 | 0.59 | 94.1 | 4.01 | 55.1% | 6.84x |
| 128exp/3072cols/MBS2 | 1136 | 0.65 | 210 | 3.54 | 48.7% | 5.41x |
| 128exp/3072cols/MBS4 | 2049 | 0.72 | 389 | 3.79 | 52.1% | 5.27x |
| 256exp/7168cols/MBS1 | 1676 | 0.54 | 310 | 2.91 | 40.0% | 5.41x |
| 256exp/7168cols/MBS2 | 2732 | 0.64 | 496 | 3.55 | 48.8% | 5.51x |
| 256exp/7168cols/MBS4 | 4674 | 0.74 | 928 | 3.74 | 51.4% | 5.04x |
| 256exp/2048cols/MBS1 | 635 | 0.41 | 87.9 | 2.93 | 40.3% | 7.22x |
| 256exp/2048cols/MBS2 | 933 | 0.54 | 141 | 3.56 | 48.9% | 6.62x |
| 256exp/2048cols/MBS4 | 1560 | 0.64 | 272 | 3.65 | 50.2% | 5.74x |
| 256exp/4096cols/MBS1 | 1039 | 0.50 | 185 | 2.79 | 38.3% | 5.62x |
| 256exp/4096cols/MBS2 | 1644 | 0.61 | 292 | 3.44 | 47.3% | 5.63x |
| 256exp/4096cols/MBS4 | 2728 | 0.73 | 534 | 3.71 | 51.0% | 5.11x |

Performance Summary

Average speedup (balanced): 5.0x
Average speedup (skewed): 5.8x
Average % peak (balanced): 49.4%
Average % peak (skewed): 49.0%
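
For readers checking the tables, here is a minimal sketch of how the Speedup and % Peak columns appear to be derived. The peak figure is an assumption: the PR does not state it, but the reported percentages are consistent with a peak of 8 TB/s, which is about 7.28 TiB/s. The function names are illustrative and not part of the PR.

```cpp
#include <cassert>

// Assumed peak bandwidth: 8 TB/s converted to TiB/s (8e12 / 2^40, about 7.28).
// This value is inferred from the table rows; it is not stated in the PR.
constexpr double kAssumedPeakTiBps =
    8.0e12 / (1024.0 * 1024.0 * 1024.0 * 1024.0);

// Speedup is the baseline time divided by the optimized time.
inline double speedup(double base_us, double opt_us) {
  assert(opt_us > 0.0);
  return base_us / opt_us;
}

// Percent of peak is achieved bandwidth over the assumed peak bandwidth.
inline double percent_of_peak(double achieved_tib_per_s) {
  return 100.0 * achieved_tib_per_s / kAssumedPeakTiBps;
}
```

For the first balanced row (128exp/4096cols/MBS1), `speedup(668.0, 140.0)` is about 4.77 and `percent_of_peak(3.58)` is about 49.2, which matches the table under this assumption.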

Change Summary

  • 512 threads/block (WPT=16) vs upstream's 128 (WPT=4) — 4x more threads for latency hiding
  • Non-temporal stores for both outputs via NTVec — upstream uses regular Vec::store_to which pollutes L2 (CDNA4 L2 is write-allocate)
  • Packed FP8 intrinsics via rocm_pack_4xfloat8 — 2 v_cvt_pk_fp8_f32 per 4 values vs upstream's scalar OType(scale * x) casts
  • Fused amax into the pack loop with tree reduction — upstream has a separate serial amax pass
  • 128-bit vectorized loads (LOAD_SZ=16 for BF16, NVEC_IN=8) — upstream uses 64-bit (LOAD_SZ=8, NVEC_IN=4)
  • kMCTMaxTensors=256 — single kernel launch for up to 256 experts. Upstream limited to 64 (CUDA 4KB kernarg limit doesn't apply on AMD)
  • Edge-tile bounds checking — handles any row count with pad-16 alignment. Interior tiles run the fast path, edge tiles are predicated per-row
  • Binary search for tensor lookup — O(log N) vs upstream's O(N) linear scan
  • rocm_block_reduce_max with rocm_atomicMaxFloat — uses atomicMax on int-reinterpreted float (single instruction) vs upstream's CAS loop
  • Column-major tile ordering: tile_m = local_bid % tiles_m for L2 input locality
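
The atomicMax item in the list above relies on a property of IEEE-754 floats: for non-negative values, the bit patterns sort in the same order as signed integers, so an integer max on the reinterpreted bits gives the float max. The host-side sketch below only demonstrates that property. The helper names are hypothetical, and how `rocm_atomicMaxFloat` treats negative values or NaN is not shown in this page; amax is an absolute maximum, so it is assumed to be non-negative.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Host-side stand-in for a device bit-cast: copy the float's bits into an int32.
inline int32_t float_bits(float x) {
  int32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  return bits;
}

// Inverse of float_bits.
inline float bits_to_float(int32_t bits) {
  float x;
  std::memcpy(&x, &bits, sizeof(x));
  return x;
}

// Max of two non-negative floats using only an integer comparison.
// On the device, the integer max would be a single atomic operation on the
// reinterpreted value (assumption; this sketch is not atomic).
inline float max_nonneg_via_int(float a, float b) {
  assert(a >= 0.0f && b >= 0.0f);  // ordering trick is only valid for a, b >= 0
  const int32_t ia = float_bits(a);
  const int32_t ib = float_bits(b);
  return bits_to_float(ia > ib ? ia : ib);
}
```

The design point is that a plain integer max needs no retry loop, whereas a compare-and-swap implementation may spin under contention.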
Rejected/skipped optimizations
  • WPT=8 — -25-33%, VGPR pressure from local_t[8][4]
  • WPT=32 — -10% for MBS1, 2 blocks/CU limit
  • Wave64 — -12-24%, 2x smem + narrower stores + 8 iterations
  • Cached stores for output_c — neutral
  • IS_EDGE template split — -20-26%, 2x launch overhead
  • STORE_SZ grouping — -10% Qwen3, more launches hurt
  • Persistent kernel — moot after kMCTMaxTensors=256
  • Inline ASM for FP8 pack — -2%, compiler already optimal
  • ds_read_b64_tr_b8 — 128 LDS calls vs 4 syncthreads
  • Row/column cascade — edge-tile bounds checking sufficient
  • Precompute tensor lookup in smem — binary search already <1%
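
The binary-search tensor lookup mentioned in both lists can be sketched as follows. This assumes a prefix array of per-tensor block offsets with `num_tensors + 1` entries, where tensor `t` owns block IDs in `[block_range[t], block_range[t + 1])`. The function name and exact layout are assumptions for illustration, not the PR's code.

```cpp
#include <cassert>

// Returns the index of the tensor that owns block `bid`.
// Tensors with zero blocks (block_range[t] == block_range[t + 1]) are skipped.
inline int find_tensor(const int *block_range, int num_tensors, int bid) {
  assert(num_tensors > 0);
  assert(bid >= block_range[0] && bid < block_range[num_tensors]);
  int lo = 0;
  int hi = num_tensors;
  // Invariant: block_range[lo] <= bid < block_range[hi].
  while (hi - lo > 1) {
    const int mid = lo + (hi - lo) / 2;
    if (block_range[mid] <= bid) {
      lo = mid;
    } else {
      hi = mid;
    }
  }
  return lo;
}
```

With 256 tensors this takes 8 iterations per block instead of up to 256 for a linear scan. The block's position inside its tensor would then be `bid - block_range[t]`, presumably the `local_bid` used in the tile-ordering item.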

@alextmagro added the ci-level 1 and ci-level 3 labels and removed the ci-level 1 label on May 14, 2026
HIP_CHECK(hipEventCreate(&stop));

nvte_multi_cast_transpose(num_experts, nvte_in.data(), nvte_out.data(), stream);
HIP_CHECK(hipStreamSynchronize(stream));
Collaborator

Is synchronize needed here?

Contributor Author

No, it isn't needed here. I had copied over the cast_transpose benchmark and edited that, but since we don't have RTC, we don't need the pre-call and sync.

@alextmagro requested a review from ipanfilo on May 18, 2026, 22:52
Comment thread transformer_engine/common/transpose/rocm_multi_cast_transpose.cuh
Comment thread transformer_engine/common/transpose/rocm_multi_cast_transpose.cuh
Comment on lines +150 to +232
  } else {
#pragma unroll
    for (int iter = 0; iter < NUM_ITERS; iter++) {
      const int i1 = tidy + iter * WARPS_PER_TILE;
      const int j1 = tidx;
#pragma unroll
      for (int i2 = 0; i2 < NVEC_OUT; i2++) {
        const int row = row_base + i1 * NVEC_OUT + i2;
        const int col = col_base + j1 * NVEC_IN;

        IVec in;
        OVecC out_c;

        if (row < num_rows) {
          in.load(&input[row * row_length + col]);
        } else {
#pragma unroll
          for (int j2 = 0; j2 < NVEC_IN; j2++) in.val[j2] = IType(0);
        }

#ifdef HAS_PACK_4xFLOAT8
        if constexpr (sizeof(OType) == 1) {
#pragma unroll
          for (int j2 = 0; j2 < NVEC_IN; j2 += 4) {
            const float v0 = static_cast<float>(in.val[j2]);
            const float v1 = (j2+1 < NVEC_IN) ? static_cast<float>(in.val[j2+1]) : 0.0f;
            const float v2 = (j2+2 < NVEC_IN) ? static_cast<float>(in.val[j2+2]) : 0.0f;
            const float v3 = (j2+3 < NVEC_IN) ? static_cast<float>(in.val[j2+3]) : 0.0f;
            if (row < num_rows)
              amax = fmaxf(amax, fmaxf(fmaxf(fabsf(v0), fabsf(v1)), fmaxf(fabsf(v2), fabsf(v3))));
            uint32_t packed = rocm_pack_4xfloat8<OType>(
                v0 * scale, v1 * scale, v2 * scale, v3 * scale);
            uint8_t *bytes = reinterpret_cast<uint8_t *>(&packed);
#pragma unroll
            for (int k = 0; k < 4 && j2 + k < NVEC_IN; k++) {
              out_c.val[j2 + k] = reinterpret_cast<OType &>(bytes[k]);
              local_t[j2 + k][iter].val[i2] = out_c.val[j2 + k];
            }
          }
        } else
#endif
        {
#pragma unroll
          for (int j2 = 0; j2 < NVEC_IN; j2++) {
            const float v = static_cast<float>(in.val[j2]);
            if (row < num_rows)
              amax = fmaxf(amax, fabsf(v));
            const OType o = static_cast<OType>(v * scale);
            out_c.val[j2] = o;
            local_t[j2][iter].val[i2] = o;
          }
        }

        if (row < num_rows)
          out_c.nt_store(&output_c[row * row_length + col]);
      }
    }

#pragma unroll
    for (int j2 = 0; j2 < NVEC_IN; j2++) {
#pragma unroll
      for (int iter = 0; iter < NUM_ITERS; iter++) {
        smem[tidx][tidy + iter * WARPS_PER_TILE] = local_t[j2][iter];
      }
      __syncthreads();
#pragma unroll
      for (int iter = 0; iter < NUM_ITERS; iter++) {
        const int i1 = tidx;
        const int j1 = tidy + iter * WARPS_PER_TILE;
        const int row = row_base + i1 * NVEC_OUT;
        const int col = col_base + j1 * NVEC_IN + j2;
        if (row + NVEC_OUT <= num_rows) {
          smem[j1][i1].nt_store(&output_t[col * num_rows + row]);
        } else if (row < num_rows) {
          for (int k = 0; k < NVEC_OUT && row + k < num_rows; k++)
            output_t[col * num_rows + row + k] = smem[j1][i1].val[k];
        }
      }
      if (j2 + 1 < NVEC_IN) {
        __syncthreads();
      }
    }
  }
Contributor

There seems to be a fair amount of duplication between the interior-tile and row-edge-tile paths. Most of the load/cast/pack/amax/local_t logic is identical, with the edge path only adding row predicates and partial-vector stores. Is there some way we can refactor the duplicated parts into device inline helpers?
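
One possible shape for such a helper is sketched below as plain host C++, so it can be compiled without HIP. Everything here is hypothetical: the name `load_scale_amax`, the `kCheckRows` template flag, and the use of `float` in place of the kernel's `IType`/`OType` vectors are illustrative assumptions, and FP8 packing is omitted. The idea is that a compile-time flag lets the interior path drop the row predicate while both paths share one body.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical shared helper for the load / amax / scale step.
// kCheckRows = false: interior tile, the row is known to be in range.
// kCheckRows = true:  edge tile, out-of-range rows read as zero and do not
//                     contribute to amax.
template <bool kCheckRows, int kNVec>
inline void load_scale_amax(const float *input, int row, int num_rows,
                            int row_length, int col, float scale,
                            float (&out)[kNVec], float &amax) {
  // For kCheckRows == false this is a compile-time constant `true`.
  const bool valid = !kCheckRows || row < num_rows;
  for (int j = 0; j < kNVec; ++j) {
    const float v = valid ? input[row * row_length + col + j] : 0.0f;
    if (valid) amax = std::fmax(amax, std::fabs(v));
    out[j] = v * scale;
  }
}
```

In a kernel, the equivalent would be a `__device__ __forceinline__` function called from both branches within a single launch. This differs from the rejected IS_EDGE template split, which the description attributes to extra launch overhead. Whether the compiler produces the same code for the interior path as the hand-duplicated version would need to be measured.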
