From 2e1b5bf452446581f064eedf322649ebee70dfe4 Mon Sep 17 00:00:00 2001 From: Alex Magro Date: Thu, 14 May 2026 01:19:08 +0000 Subject: [PATCH 1/2] Optimized rocm specific multicast transpose kernel --- benchmarks/cpp/CMakeLists.txt | 1 + .../cpp/cast/bench_multi_cast_transpose.cpp | 242 +++++++++++++ benchmarks/cpp/run_benchmarks.sh | 1 + .../common/transpose/multi_cast_transpose.cu | 41 +++ .../transpose/rocm_multi_cast_transpose.cuh | 320 ++++++++++++++++++ 5 files changed, 605 insertions(+) create mode 100644 benchmarks/cpp/cast/bench_multi_cast_transpose.cpp create mode 100644 transformer_engine/common/transpose/rocm_multi_cast_transpose.cuh diff --git a/benchmarks/cpp/CMakeLists.txt b/benchmarks/cpp/CMakeLists.txt index 6071f9083..3ba034426 100644 --- a/benchmarks/cpp/CMakeLists.txt +++ b/benchmarks/cpp/CMakeLists.txt @@ -86,3 +86,4 @@ add_te_benchmark(bench_quantize_mxfp8_fused cast/bench_quantize_mxfp8_fused.cpp) add_te_benchmark(bench_dequantize_mxfp8 cast/bench_dequantize_mxfp8.cpp) add_te_benchmark(bench_gated_mxfp8 cast/bench_gated_mxfp8.cpp) add_te_benchmark(bench_casttranspose cast/bench_casttranspose.cpp) +add_te_benchmark(bench_multi_cast_transpose cast/bench_multi_cast_transpose.cpp) diff --git a/benchmarks/cpp/cast/bench_multi_cast_transpose.cpp b/benchmarks/cpp/cast/bench_multi_cast_transpose.cpp new file mode 100644 index 000000000..92a003ac4 --- /dev/null +++ b/benchmarks/cpp/cast/bench_multi_cast_transpose.cpp @@ -0,0 +1,242 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + +#include +#include +#include +#include "amd_detail/hip_float8.h" + +#include "benchmark_utils.h" + +#include "transformer_engine/transpose_hip.h" +#include "transformer_engine/transformer_engine_hip.h" + +#include +#include +#include +#include +#include +#include + +using namespace te_bench; +using namespace transformer_engine; +using fp8_e4m3 = test::fp8e4m3; + +// MoE shapes from Qwen3-235B and DeepSeek-V3 +// Args: {total_tokens, cols, num_experts, top_k, routing_mode} +#define MOE_BALANCED \ + ->Args({4096, 4096, 128, 8, 0}) \ + ->Args({8192, 4096, 128, 8, 0}) \ + ->Args({16384, 4096, 128, 8, 0}) \ + ->Args({4096, 1536, 128, 8, 0}) \ + ->Args({8192, 1536, 128, 8, 0}) \ + ->Args({16384, 1536, 128, 8, 0}) \ + ->Args({4096, 3072, 128, 8, 0}) \ + ->Args({8192, 3072, 128, 8, 0}) \ + ->Args({16384, 3072, 128, 8, 0}) \ + ->Args({4096, 7168, 256, 8, 0}) \ + ->Args({8192, 7168, 256, 8, 0}) \ + ->Args({16384, 7168, 256, 8, 0}) \ + ->Args({4096, 2048, 256, 8, 0}) \ + ->Args({8192, 2048, 256, 8, 0}) \ + ->Args({16384, 2048, 256, 8, 0}) \ + ->Args({4096, 4096, 256, 8, 0}) \ + ->Args({8192, 4096, 256, 8, 0}) \ + ->Args({16384, 4096, 256, 8, 0}) + +#define MOE_SKEWED \ + ->Args({4096, 4096, 128, 8, 1}) \ + ->Args({8192, 4096, 128, 8, 1}) \ + ->Args({16384, 4096, 128, 8, 1}) \ + ->Args({4096, 1536, 128, 8, 1}) \ + ->Args({8192, 1536, 128, 8, 1}) \ + ->Args({16384, 1536, 128, 8, 1}) \ + ->Args({4096, 3072, 128, 8, 1}) \ + ->Args({8192, 3072, 128, 8, 1}) \ + ->Args({16384, 3072, 128, 8, 1}) \ + ->Args({4096, 7168, 256, 8, 1}) \ + ->Args({8192, 7168, 256, 8, 1}) \ + ->Args({16384, 7168, 256, 8, 1}) \ + ->Args({4096, 2048, 256, 8, 1}) \ + ->Args({8192, 2048, 256, 8, 1}) \ + ->Args({16384, 2048, 256, 8, 1}) \ + ->Args({4096, 4096, 256, 8, 1}) \ + ->Args({8192, 4096, 256, 8, 1}) \ + ->Args({16384, 4096, 256, 8, 1}) + +namespace { + +static const uint64_t kRunSeed = std::random_device{}(); +static constexpr size_t kPadMultiple = 16; + +static uint64_t derive_seed(size_t a, size_t b, size_t c, size_t d, size_t e) { + uint64_t h = kRunSeed; + h ^= a; h *= 1099511628211ULL; + h ^= b; h *= 1099511628211ULL; + h ^= c; h *= 1099511628211ULL; + h ^= d; h *= 1099511628211ULL; + h ^= e; h *= 1099511628211ULL; + return h; +} + +static std::vector simulate_topk_balanced( + size_t total_tokens, size_t num_experts, size_t top_k, uint64_t seed) +{ + std::vector counts(num_experts, 0); + std::mt19937_64 gen(seed); + + std::vector experts(num_experts); + std::iota(experts.begin(), experts.end(), 0); + + for (size_t t = 0; t < total_tokens; t++) { + for (size_t k = 0; k < top_k; k++) { + std::uniform_int_distribution dist(k, num_experts - 1); + std::swap(experts[k], experts[dist(gen)]); + counts[experts[k]]++; + } + } + return counts; +} + +static std::vector simulate_topk_skewed( + size_t total_tokens, size_t num_experts, size_t top_k, uint64_t seed) +{ + std::vector counts(num_experts, 0); + std::mt19937_64 gen(seed); + + std::vector weights(num_experts); + for (size_t i = 0; i < num_experts; i++) + weights[i] = 1.0 / std::pow(static_cast(i + 1), 0.7); + + std::shuffle(weights.begin(), weights.end(), gen); + std::discrete_distribution wdist(weights.begin(), weights.end()); + + std::vector used(num_experts, false); + std::vector used_list; + used_list.reserve(top_k); + + for (size_t t = 0; t < total_tokens; t++) { + used_list.clear(); + for (size_t k = 0; k < top_k; k++) { + size_t e; + do { e = wdist(gen); } while (used[e]); + used[e] = true; + used_list.push_back(e); + counts[e]++; + } + for (size_t e : used_list) used[e] = false; + } + return counts; +} + +template +static void BM_MultiCastTranspose(benchmark::State &state) { + const size_t total_tokens = state.range(0); + const size_t cols = state.range(1); + const size_t num_experts = state.range(2); + const size_t top_k = state.range(3); + const size_t routing_mode = state.range(4); + + uint64_t seed = derive_seed(total_tokens, cols, num_experts, top_k, routing_mode); + + auto counts = (routing_mode == 0) + ? simulate_topk_balanced(total_tokens, num_experts, top_k, seed) + : simulate_topk_skewed(total_tokens, num_experts, top_k, seed); + + size_t min_tok = *std::min_element(counts.begin(), counts.end()); + size_t max_tok = *std::max_element(counts.begin(), counts.end()); + size_t sum_tok = std::accumulate(counts.begin(), counts.end(), size_t(0)); + + DType itype = std::is_same_v ? DType::kFloat32 : + std::is_same_v ? DType::kBFloat16 : + DType::kFloat16; + + std::string pfx = "mct_" + std::to_string(total_tokens) + "_" + + std::to_string(cols) + "_" + std::to_string(num_experts) + + "_" + std::to_string(routing_mode); + + std::vector nvte_in(num_experts), nvte_out(num_experts); + + for (size_t e = 0; e < num_experts; e++) { + size_t rows = ((std::max(counts[e], size_t(1)) + kPadMultiple - 1) + / kPadMultiple) * kPadMultiple; + std::string in_name = pfx + "_in_" + std::to_string(e); + std::string out_name = pfx + "_out_" + std::to_string(e); + + auto &input = TensorCache::get_or_create( + in_name, {rows, cols}, itype, + true, false, NVTE_DELAYED_TENSOR_SCALING, true); + + auto &output = TensorCache::get_or_create( + out_name, {rows, cols}, DType::kFloat8E4M3, + true, true, NVTE_DELAYED_TENSOR_SCALING, false); + + output.set_scale(1.0f); + + nvte_in[e] = input.data(); + nvte_out[e] = output.data(); + } + + hipStream_t stream; + HIP_CHECK(hipStreamCreate(&stream)); + + hipEvent_t start, stop; + HIP_CHECK(hipEventCreate(&start)); + HIP_CHECK(hipEventCreate(&stop)); + + nvte_multi_cast_transpose(num_experts, nvte_in.data(), nvte_out.data(), stream); + HIP_CHECK(hipStreamSynchronize(stream)); + warmup_gpu(); + + for (auto _ : state) { + HIP_CHECK(hipEventRecord(start, stream)); + nvte_multi_cast_transpose(num_experts, nvte_in.data(), nvte_out.data(), stream); + HIP_CHECK(hipEventRecord(stop, stream)); + HIP_CHECK(hipEventSynchronize(stop)); + + float ms = 0; + HIP_CHECK(hipEventElapsedTime(&ms, start, stop)); + state.SetIterationTime(ms / 1000.0); + } + + HIP_CHECK(hipEventDestroy(start)); + HIP_CHECK(hipEventDestroy(stop)); + + size_t total_bytes = 0; + for (size_t e = 0; e < num_experts; e++) { + size_t rows = ((std::max(counts[e], size_t(1)) + kPadMultiple - 1) + / kPadMultiple) * kPadMultiple; + total_bytes += rows * cols * sizeof(IType); + total_bytes += rows * cols * sizeof(fp8_e4m3) * 2; + } + set_bytes_processed(state, total_bytes); + + state.counters["experts"] = num_experts; + state.counters["cols"] = cols; + state.counters["avg_tok"] = static_cast(sum_tok) / num_experts; + state.counters["min_tok"] = min_tok; + state.counters["max_tok"] = max_tok; + + HIP_CHECK(hipStreamDestroy(stream)); +} + +} // namespace + +#define REGISTER_MCT(ITYPE, INAME) \ + BENCHMARK_TEMPLATE(BM_MultiCastTranspose, ITYPE) \ + ->Name("BM_MultiCastTranspose/" INAME "_E4M3/moe") \ + MOE_BALANCED \ + ->Unit(benchmark::kMicrosecond) \ + ->UseManualTime(); \ + BENCHMARK_TEMPLATE(BM_MultiCastTranspose, ITYPE) \ + ->Name("BM_MultiCastTranspose/" INAME "_E4M3/moe_skewed") \ + MOE_SKEWED \ + ->Unit(benchmark::kMicrosecond) \ + ->UseManualTime(); + +REGISTER_MCT(hip_bfloat16, "BF16") + +BENCHMARK_MAIN(); diff --git a/benchmarks/cpp/run_benchmarks.sh b/benchmarks/cpp/run_benchmarks.sh index 05f7f853e..d8e164e17 100755 --- a/benchmarks/cpp/run_benchmarks.sh +++ b/benchmarks/cpp/run_benchmarks.sh @@ -27,6 +27,7 @@ main() { "bench_dequantize_mxfp8" "bench_gated_mxfp8" "bench_casttranspose" + "bench_multi_cast_transpose" ) FAILED_BENCHMARKS=() diff --git a/transformer_engine/common/transpose/multi_cast_transpose.cu b/transformer_engine/common/transpose/multi_cast_transpose.cu index 33e1c19d8..8e7dbd0d2 100644 --- a/transformer_engine/common/transpose/multi_cast_transpose.cu +++ b/transformer_engine/common/transpose/multi_cast_transpose.cu @@ -16,6 +16,10 @@ namespace transformer_engine { +#ifdef __HIP_PLATFORM_AMD__ +#include "rocm_multi_cast_transpose.cuh" +#endif // #ifdef __HIP_PLATFORM_AMD__ + namespace { // Parameters to tune @@ -235,6 +239,42 @@ void multi_cast_transpose(const std::vector input_list, std::vector in_ptrs(n); + std::vector out_c_ptrs(n); + std::vector out_t_ptrs(n); + std::vector scale_ptrs(n); + std::vector amax_ptrs(n); + std::vector sinv_ptrs(n); + std::vector rows(n); + std::vector cols(n); + + for (size_t i = 0; i < n; i++) { + in_ptrs[i] = reinterpret_cast(input_list[i]->data.dptr); + out_c_ptrs[i] = reinterpret_cast(output_list[i]->data.dptr); + out_t_ptrs[i] = reinterpret_cast(output_list[i]->columnwise_data.dptr); + scale_ptrs[i] = reinterpret_cast(output_list[i]->scale.dptr); + amax_ptrs[i] = reinterpret_cast(output_list[i]->amax.dptr); + sinv_ptrs[i] = reinterpret_cast(output_list[i]->scale_inv.dptr); + rows[i] = input_list[i]->data.shape[0]; + cols[i] = input_list[i]->data.shape[1]; + } + + rocm_multi_cast_transpose_dispatch(n, in_ptrs.data(), out_c_ptrs.data(), + out_t_ptrs.data(), scale_ptrs.data(), amax_ptrs.data(), sinv_ptrs.data(), rows.data(), + cols.data(), stream); + ); // NOLINT(*) + ); // NOLINT(*) + NVTE_CHECK_CUDA(cudaGetLastError()); + } +#else // Input matrices are divided into tiles // Note: Each tile is a warp_size x warp_size grid of nvec_out x nvec_in subtiles const int tile_dim_m = THREADS_PER_WARP * desired_store_size * 8 / typeToNumBits(otype); @@ -328,6 +368,7 @@ void multi_cast_transpose(const std::vector input_list, std::vector +__global__ void __launch_bounds__(ROCM_CT_WARP_SIZE * WARPS_PER_TILE) +rocm_multi_cast_transpose_kernel(RocmMultiCastTransposeArgs args) { + constexpr int NVEC_IN = LOAD_SIZE / sizeof(IType); + constexpr int NVEC_OUT = STORE_SIZE / sizeof(OType); + constexpr int TILE_COLS = ROCM_CT_WARP_SIZE * NVEC_IN; + constexpr int TILE_ROWS = ROCM_CT_WARP_SIZE * NVEC_OUT; + constexpr int NUM_ITERS = ROCM_CT_WARP_SIZE / WARPS_PER_TILE; + + using IVec = NTVec; + using OVecC = NTVec; + using OVecT = NTVec; + + const int tid = threadIdx.x; + const int tidx = tid % ROCM_CT_WARP_SIZE; + const int tidy = tid / ROCM_CT_WARP_SIZE; + const int bid = blockIdx.x; + + int lo = 0, hi = args.num_tensors - 1; + while (lo < hi) { + int mid = (lo + hi) / 2; + if (args.block_range[mid + 1] <= bid) lo = mid + 1; + else hi = mid; + } + + const int tensor_id = lo; + const int local_bid = bid - args.block_range[tensor_id]; + const int num_rows = args.num_rows_list[tensor_id]; + const int row_length = args.row_length_list[tensor_id]; + + const IType *__restrict__ input = reinterpret_cast(args.input_list[tensor_id]); + OType *__restrict__ output_c = reinterpret_cast(args.output_c_list[tensor_id]); + OType *__restrict__ output_t = reinterpret_cast(args.output_t_list[tensor_id]); + + const float *__restrict__ scale_ptr = reinterpret_cast(args.scale_list[tensor_id]); + float *__restrict__ amax_ptr = reinterpret_cast(args.amax_list[tensor_id]); + float *__restrict__ scale_inv_ptr = reinterpret_cast(args.scale_inv_list[tensor_id]); + + const int tiles_m = (num_rows + TILE_ROWS - 1) / TILE_ROWS; + const int tile_m = local_bid % tiles_m; + const int tile_n = local_bid / tiles_m; + const int row_base = tile_m * TILE_ROWS; + const int col_base = tile_n * TILE_COLS; + + const bool is_edge = (row_base + TILE_ROWS > num_rows); + + const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1.0f; + float amax = 0.0f; + + __shared__ OVecT smem[ROCM_CT_WARP_SIZE][ROCM_CT_WARP_SIZE+1]; + + OVecT local_t[NVEC_IN][NUM_ITERS]; + + if (!is_edge) { +#pragma unroll + for (int iter = 0; iter < NUM_ITERS; iter++) { + const int i1 = tidy + iter * WARPS_PER_TILE; + const int j1 = tidx; +#pragma unroll + for (int i2 = 0; i2 < NVEC_OUT; i2++) { + const int row = row_base + i1 * NVEC_OUT + i2; + const int col = col_base + j1 * NVEC_IN; + + IVec in; + OVecC out_c; + in.load(&input[row * row_length + col]); + +#ifdef HAS_PACK_4xFLOAT8 + if constexpr (sizeof(OType) == 1) { +#pragma unroll + for (int j2 = 0; j2 < NVEC_IN; j2 += 4) { + const float v0 = static_cast(in.val[j2]); + const float v1 = (j2+1 < NVEC_IN) ? static_cast(in.val[j2+1]) : 0.0f; + const float v2 = (j2+2 < NVEC_IN) ? static_cast(in.val[j2+2]) : 0.0f; + const float v3 = (j2+3 < NVEC_IN) ? static_cast(in.val[j2+3]) : 0.0f; + amax = fmaxf(amax, fmaxf(fmaxf(fabsf(v0), fabsf(v1)), + fmaxf(fabsf(v2), fabsf(v3)))); + uint32_t packed = rocm_pack_4xfloat8( + v0 * scale, v1 * scale, v2 * scale, v3 * scale); + uint8_t *bytes = reinterpret_cast(&packed); +#pragma unroll + for (int k = 0; k < 4 && j2 + k < NVEC_IN; k++) { + out_c.val[j2 + k] = reinterpret_cast(bytes[k]); + local_t[j2 + k][iter].val[i2] = out_c.val[j2 + k]; + } + } + } else +#endif + { +#pragma unroll + for (int j2 = 0; j2 < NVEC_IN; j2++) { + const float v = static_cast(in.val[j2]); + amax = fmaxf(amax, fabsf(v)); + const OType o = static_cast(v * scale); + out_c.val[j2] = o; + local_t[j2][iter].val[i2] = o; + } + } + + out_c.nt_store(&output_c[row * row_length + col]); + } + } + +#pragma unroll + for (int j2 = 0; j2 < NVEC_IN; j2++) { +#pragma unroll + for (int iter = 0; iter < NUM_ITERS; iter++) { + smem[tidx][tidy + iter * WARPS_PER_TILE] = local_t[j2][iter]; + } + __syncthreads(); +#pragma unroll + for (int iter = 0; iter < NUM_ITERS; iter++) { + const int i1 = tidx; + const int j1 = tidy + iter * WARPS_PER_TILE; + const int row = row_base + i1 * NVEC_OUT; + const int col = col_base + j1 * NVEC_IN + j2; + smem[j1][i1].nt_store(&output_t[col * num_rows + row]); + } + if (j2 + 1 < NVEC_IN) { + __syncthreads(); + } + } + } else { +#pragma unroll + for (int iter = 0; iter < NUM_ITERS; iter++) { + const int i1 = tidy + iter * WARPS_PER_TILE; + const int j1 = tidx; +#pragma unroll + for (int i2 = 0; i2 < NVEC_OUT; i2++) { + const int row = row_base + i1 * NVEC_OUT + i2; + const int col = col_base + j1 * NVEC_IN; + + IVec in; + OVecC out_c; + + if (row < num_rows) { + in.load(&input[row * row_length + col]); + } else { +#pragma unroll + for (int j2 = 0; j2 < NVEC_IN; j2++) in.val[j2] = IType(0); + } + +#ifdef HAS_PACK_4xFLOAT8 + if constexpr (sizeof(OType) == 1) { +#pragma unroll + for (int j2 = 0; j2 < NVEC_IN; j2 += 4) { + const float v0 = static_cast(in.val[j2]); + const float v1 = (j2+1 < NVEC_IN) ? static_cast(in.val[j2+1]) : 0.0f; + const float v2 = (j2+2 < NVEC_IN) ? static_cast(in.val[j2+2]) : 0.0f; + const float v3 = (j2+3 < NVEC_IN) ? static_cast(in.val[j2+3]) : 0.0f; + if (row < num_rows) + amax = fmaxf(amax, fmaxf(fmaxf(fabsf(v0), fabsf(v1)), fmaxf(fabsf(v2), fabsf(v3)))); + uint32_t packed = rocm_pack_4xfloat8( + v0 * scale, v1 * scale, v2 * scale, v3 * scale); + uint8_t *bytes = reinterpret_cast(&packed); +#pragma unroll + for (int k = 0; k < 4 && j2 + k < NVEC_IN; k++) { + out_c.val[j2 + k] = reinterpret_cast(bytes[k]); + local_t[j2 + k][iter].val[i2] = out_c.val[j2 + k]; + } + } + } else +#endif + { +#pragma unroll + for (int j2 = 0; j2 < NVEC_IN; j2++) { + const float v = static_cast(in.val[j2]); + if (row < num_rows) + amax = fmaxf(amax, fabsf(v)); + const OType o = static_cast(v * scale); + out_c.val[j2] = o; + local_t[j2][iter].val[i2] = o; + } + } + + if (row < num_rows) + out_c.nt_store(&output_c[row * row_length + col]); + } + } + +#pragma unroll + for (int j2 = 0; j2 < NVEC_IN; j2++) { +#pragma unroll + for (int iter = 0; iter < NUM_ITERS; iter++) { + smem[tidx][tidy + iter * WARPS_PER_TILE] = local_t[j2][iter]; + } + __syncthreads(); +#pragma unroll + for (int iter = 0; iter < NUM_ITERS; iter++) { + const int i1 = tidx; + const int j1 = tidy + iter * WARPS_PER_TILE; + const int row = row_base + i1 * NVEC_OUT; + const int col = col_base + j1 * NVEC_IN + j2; + if (row + NVEC_OUT <= num_rows) { + smem[j1][i1].nt_store(&output_t[col * num_rows + row]); + } else if (row < num_rows) { + for (int k = 0; k < NVEC_OUT && row + k < num_rows; k++) + output_t[col * num_rows + row + k] = smem[j1][i1].val[k]; + } + } + if (j2 + 1 < NVEC_IN) { + __syncthreads(); + } + } + } + + if (amax_ptr != nullptr) { + amax = rocm_block_reduce_max(amax, tidy); + if (tid == 0) { + rocm_atomicMaxFloat(amax_ptr, amax); + } + } + + if (local_bid == 0 && tid == 0 && scale_inv_ptr != nullptr) { + *scale_inv_ptr = __frcp_rn(scale); + } +} + +template +void rocm_multi_cast_transpose_dispatch(size_t num_tensors, const IType *const *input_list, OType *const *output_c_list, + OType *const *output_t_list, const float *const *scale_list, float *const *amax_list, + float *const *scale_inv_list, const size_t *num_rows_list, + const size_t *row_length_list, hipStream_t stream) { + constexpr int WPT = 16; + constexpr int BLK = ROCM_CT_WARP_SIZE * WPT; + constexpr int ISZ = sizeof(IType); + constexpr int OSZ = sizeof(OType); + constexpr int LOAD_SZ = 16; + constexpr int STORE_SZ = 8; + constexpr int TILE_COLS = ROCM_CT_WARP_SIZE * (LOAD_SZ / ISZ); + constexpr int TILE_ROWS = ROCM_CT_WARP_SIZE * (STORE_SZ / OSZ); + + size_t i = 0; + + while (i < num_tensors) { + RocmMultiCastTransposeArgs args; + args.block_range[0] = 0; + + int total_blocks = 0; + int packed = 0; + + while (i < num_tensors && packed < kMCTMaxTensors) { + int rows = num_rows_list[i]; + int cols = row_length_list[i]; + + if (cols % TILE_COLS != 0 || rows == 0) { + if (rows > 0 && cols > 0) { + size_t done = rocm_cast_transpose_dispatch(input_list[i], nullptr, output_c_list[i], + output_t_list[i], scale_list[i], + amax_list[i], scale_inv_list[i], cols, + rows, stream); + if (done < static_cast(rows)) { + size_t rem = rows - done; + hipLaunchKernelGGL( + (rocm_cast_transpose_remainder_kernel), + dim3((rem * cols + 255) / 256), dim3(256), 0, stream, + input_list[i] + done * cols, nullptr, + output_c_list[i] + done * cols, output_t_list[i] + done, + scale_list[i], amax_list[i], scale_inv_list[i], + rem, cols, cols, rows); + } + } + i++; + continue; + } + + int tiles_m = (rows + TILE_ROWS - 1) / TILE_ROWS; + int tiles_n = cols / TILE_COLS; + int tiles = tiles_m * tiles_n; + + args.input_list[packed] = reinterpret_cast(input_list[i]); + args.output_c_list[packed] = reinterpret_cast(output_c_list[i]); + args.output_t_list[packed] = reinterpret_cast(output_t_list[i]); + args.scale_list[packed] = reinterpret_cast(scale_list[i]); + args.amax_list[packed] = amax_list[i]; + args.scale_inv_list[packed] = scale_inv_list[i]; + args.num_rows_list[packed] = rows; + args.row_length_list[packed] = cols; + total_blocks += tiles; + args.block_range[packed + 1] = total_blocks; + packed++; + i++; + } + + if (total_blocks > 0) { + args.num_tensors = packed; + hipLaunchKernelGGL( + (rocm_multi_cast_transpose_kernel), + dim3(total_blocks), dim3(BLK), 0, stream, + args); + } + } +} From b129623d8f231e01d41b20646476529ff8f102a6 Mon Sep 17 00:00:00 2001 From: Alex Magro Date: Mon, 18 May 2026 22:51:53 +0000 Subject: [PATCH 2/2] Remove extra sync --- benchmarks/cpp/cast/bench_multi_cast_transpose.cpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/benchmarks/cpp/cast/bench_multi_cast_transpose.cpp b/benchmarks/cpp/cast/bench_multi_cast_transpose.cpp index 92a003ac4..40c6584a4 100644 --- a/benchmarks/cpp/cast/bench_multi_cast_transpose.cpp +++ b/benchmarks/cpp/cast/bench_multi_cast_transpose.cpp @@ -187,8 +187,6 @@ static void BM_MultiCastTranspose(benchmark::State &state) { HIP_CHECK(hipEventCreate(&start)); HIP_CHECK(hipEventCreate(&stop)); - nvte_multi_cast_transpose(num_experts, nvte_in.data(), nvte_out.data(), stream); - HIP_CHECK(hipStreamSynchronize(stream)); warmup_gpu(); for (auto _ : state) { @@ -232,7 +230,7 @@ static void BM_MultiCastTranspose(benchmark::State &state) { ->Unit(benchmark::kMicrosecond) \ ->UseManualTime(); \ BENCHMARK_TEMPLATE(BM_MultiCastTranspose, ITYPE) \ - ->Name("BM_MultiCastTranspose/" INAME "_E4M3/moe_skewed") \ + ->Name("BM_MultiCastTranspose/" INAME "_E4M3/moe_skewed") \ MOE_SKEWED \ ->Unit(benchmark::kMicrosecond) \ ->UseManualTime();