Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
bc363fa
add MX scale pre-swizzling for gfx1250
matthiasdiener Apr 27, 2026
a6ca3af
switch to mxfp4
matthiasdiener Apr 27, 2026
d1ee5bd
tensile-like implementation
matthiasdiener Apr 28, 2026
d1647ee
Merge remote-tracking branch 'upstream/dev' into mdiener/mxfp8-swizzle
matthiasdiener Apr 29, 2026
1fff6d9
Merge remote-tracking branch 'origin/dev' into mdiener/mxfp8-swizzle
matthiasdiener May 1, 2026
d714038
gfx1250 swizzle_xor changes for FP4
matthiasdiener May 1, 2026
76ca4b1
change line endings to unix, trim trailing whitespace
matthiasdiener May 1, 2026
81a0a27
Merge branch 'mdiener/swizzle_xor-1250' into mdiener/mxfp8-swizzle
matthiasdiener May 1, 2026
2991bcf
fix arch
matthiasdiener May 1, 2026
8ceb89c
[WIP] e2e gemm test, not working yet
matthiasdiener May 1, 2026
167d2eb
fix for gfx1250
matthiasdiener May 3, 2026
5d46537
k-tile
matthiasdiener May 3, 2026
313a6b7
extend tests
matthiasdiener May 3, 2026
2a8eeb5
remove ifdef
matthiasdiener May 3, 2026
c37a781
undo BLK32_UE8M0_32_8_EXT
matthiasdiener May 4, 2026
5d2d38f
Merge remote-tracking branch 'upstream/dev' into mdiener/mxfp8-swizzle
matthiasdiener May 5, 2026
f093f64
Revert "change line endings to unix, trim trailing whitespace"
matthiasdiener May 5, 2026
ecbffea
Revert "gfx1250 swizzle_xor changes for FP4"
matthiasdiener May 5, 2026
33fca6e
Merge remote-tracking branch 'origin/dev' into mdiener/mxfp8-swizzle
matthiasdiener May 13, 2026
b55a538
address review comments
matthiasdiener May 13, 2026
398cc3c
cleanups
matthiasdiener May 13, 2026
384d590
re-add scale swizzle hooks in GEMM paths for gfx1250
matthiasdiener May 13, 2026
5c5a902
cleanups
matthiasdiener May 13, 2026
2c05ec5
arch fixes
matthiasdiener May 14, 2026
5552b09
more test fixes gfx1250
matthiasdiener May 18, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tests/cpp/operator/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@ list(APPEND test_cuda_sources
test_multi_unpadding.cu
test_causal_softmax.cu
test_swap_first_dims.cu
test_swizzle.cu
../test_common.cu)
if(USE_CUDA)
list(APPEND test_cuda_sources
test_cast_float8blockwise.cu
test_swizzle.cu)
test_cast_float8blockwise.cu)
else()
list(APPEND test_cuda_sources
test_cublaslt_gemm.cu
Expand Down
310 changes: 293 additions & 17 deletions tests/cpp/operator/test_cublaslt_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <gtest/gtest.h>
#include <transformer_engine/cast.h>
#include <transformer_engine/gemm.h>
#include <transformer_engine/swizzle.h>
#include <transformer_engine/transformer_engine.h>
#include "../test_common.h"

Expand Down Expand Up @@ -318,6 +319,14 @@ std::pair<double, double> getTestTolerances(const DType type, bool use_fp8, bool
else if (use_fp8) {
atol = 1e-3;
rtol = std::max(rtol, 1e-2);
#ifdef __HIP_PLATFORM_AMD__
// Relax for gfx1250
cudaDeviceProp prop;
(void)cudaGetDeviceProperties(&prop, 0);
if (prop.major >= 12 && type == DType::kBFloat16) {
rtol = std::max(rtol, 5e-2);
}
#endif
}
else if (type == DType::kBFloat16) {
//relax for certain prime number TN gemm
Expand Down Expand Up @@ -496,6 +505,66 @@ void performTest(const TestParams& params) {
#endif
Tensor Workspace("Workspace", TShape{ workspace_size }, DType::kByte);

//perform the reference gemm on GPU (before swizzle, which modifies scales in-place)
Tensor RefD("RefD", TShape{ params.n, params.m }, dtype);
Tensor RefPreGeluOut;

if (params.use_gelu) {
RefPreGeluOut = Tensor("RefPreGeluOut", TShape{ params.n, params.m }, gelu_type);
}

run_reference<A_Type, B_Type, Bias_Type, Gelu_Type, D_Type>(
params,
A,
B,
params.use_bias ? &bias : nullptr,
D,
RefD,
params.use_gelu ? &RefPreGeluOut : nullptr);

#ifdef __HIP_PLATFORM_AMD__
// On gfx1250+, hipBLASLt MXFP8 kernels expect pre-swizzled scales.
if (use_mxfp8 && prop.major >= 12) {
auto swizzle_scales = [](test::Tensor &t, bool rowwise) {
using namespace transformer_engine;
void *scale_ptr = rowwise ? t.rowwise_scale_inv_dptr()
: t.columnwise_scale_inv_dptr();
if (!scale_ptr) return;
const NVTEShape scale_shape = rowwise ? t.rowwise_scale_inv_shape()
: t.columnwise_scale_inv_shape();
const NVTEShape data_shape = rowwise ? t.rowwise_shape()
: t.columnwise_shape();
size_t num_scales = 1;
for (size_t d = 0; d < scale_shape.ndim; d++) num_scales *= scale_shape.data[d];
uint8_t *d_tmp = nullptr;
NVTE_CHECK_CUDA(cudaMalloc(&d_tmp, num_scales));
TensorWrapper input_tw(NVTE_MXFP8_1D_SCALING);
TensorWrapper output_tw(NVTE_MXFP8_1D_SCALING);
output_tw.set_with_gemm_swizzled_scales(true);
if (rowwise) {
input_tw.set_rowwise_data(nullptr, t.dtype(), data_shape);
input_tw.set_rowwise_scale_inv(scale_ptr, DType::kFloat8E8M0, scale_shape);
output_tw.set_rowwise_data(nullptr, t.dtype(), data_shape);
output_tw.set_rowwise_scale_inv(d_tmp, DType::kFloat8E8M0, scale_shape);
} else {
input_tw.set_columnwise_data(nullptr, t.dtype(), data_shape);
input_tw.set_columnwise_scale_inv(scale_ptr, DType::kFloat8E8M0, scale_shape);
output_tw.set_columnwise_data(nullptr, t.dtype(), data_shape);
output_tw.set_columnwise_scale_inv(d_tmp, DType::kFloat8E8M0, scale_shape);
}
nvte_swizzle_scaling_factors(input_tw.data(), output_tw.data(), 0);
NVTE_CHECK_CUDA(cudaDeviceSynchronize());
NVTE_CHECK_CUDA(cudaMemcpy(scale_ptr, d_tmp, num_scales, cudaMemcpyDeviceToDevice));
NVTE_CHECK_CUDA(cudaFree(d_tmp));
};
// Swizzle only the scale directions that actually exist on the tensor.
if (!a_colwise) swizzle_scales(A, true);
if (a_colwise) swizzle_scales(A, false);
if (!b_colwise) swizzle_scales(B, true);
if (b_colwise) swizzle_scales(B, false);
}
#endif

//perform the gemm in GPU
nvte_cublas_gemm(A.data(),
B.data(),
Expand All @@ -517,23 +586,6 @@ void performTest(const TestParams& params) {
pre_gelu_out.to_cpu();
}

//perform the reference gemm on GPU
Tensor RefD("RefD", TShape{ params.n, params.m }, dtype);
Tensor RefPreGeluOut;

if (params.use_gelu) {
RefPreGeluOut = Tensor("RefPreGeluOut", TShape{ params.n, params.m }, gelu_type);
}

run_reference<A_Type, B_Type, Bias_Type, Gelu_Type, D_Type>(
params,
A,
B,
params.use_bias ? &bias : nullptr,
D,
RefD,
params.use_gelu ? &RefPreGeluOut : nullptr);

// check if error message happens in running
(void)cudaDeviceSynchronize();
auto err = cudaGetLastError();
Expand Down Expand Up @@ -793,4 +845,228 @@ TEST(InputGenTest, FillUniform_DoesNotGetOverwrittenByFromCpu) {
}
}

// ============================================================================
// End-to-end MXFP8 GEMM test with pre-swizzled scales
//
// Verifies that the full pipeline works:
// 1. Create MXFP8 FP8 tensors with random data + scales
// 2. Run a reference GEMM (using un-swizzled scales)
// 3. Swizzle the scales via nvte_swizzle_scaling_factors
// 4. Run the actual hipBLASlt GEMM
// 5. Compare results
// ============================================================================

// Helper: swizzle the MXFP8 scale_inv of a test::Tensor in-place.
// Allocates a temp device buffer, swizzles into it, copies back.
static void swizzle_tensor_scales(test::Tensor &t, bool rowwise) {
using namespace transformer_engine;

void *scale_ptr = rowwise ? t.rowwise_scale_inv_dptr()
: t.columnwise_scale_inv_dptr();
if (!scale_ptr)
return;

const NVTEShape scale_shape = rowwise ? t.rowwise_scale_inv_shape()
: t.columnwise_scale_inv_shape();
const NVTEShape data_shape = rowwise ? t.rowwise_shape()
: t.columnwise_shape();

size_t num_scales = 1;
for (size_t d = 0; d < scale_shape.ndim; d++) {
num_scales *= scale_shape.data[d];
}

// Allocate temp buffer for swizzled output
uint8_t *d_tmp = nullptr;
NVTE_CHECK_CUDA(cudaMalloc(&d_tmp, num_scales));

// Build TensorWrapper pair for the swizzle call
TensorWrapper input_tw(NVTE_MXFP8_1D_SCALING);
TensorWrapper output_tw(NVTE_MXFP8_1D_SCALING);
output_tw.set_with_gemm_swizzled_scales(true);

if (rowwise) {
input_tw.set_rowwise_data(nullptr, t.dtype(), data_shape);
input_tw.set_rowwise_scale_inv(scale_ptr, DType::kFloat8E8M0, scale_shape);
output_tw.set_rowwise_data(nullptr, t.dtype(), data_shape);
output_tw.set_rowwise_scale_inv(d_tmp, DType::kFloat8E8M0, scale_shape);
} else {
input_tw.set_columnwise_data(nullptr, t.dtype(), data_shape);
input_tw.set_columnwise_scale_inv(scale_ptr, DType::kFloat8E8M0, scale_shape);
output_tw.set_columnwise_data(nullptr, t.dtype(), data_shape);
output_tw.set_columnwise_scale_inv(d_tmp, DType::kFloat8E8M0, scale_shape);
}

nvte_swizzle_scaling_factors(input_tw.data(), output_tw.data(), 0);
NVTE_CHECK_CUDA(cudaDeviceSynchronize());

// Copy swizzled scales back over the original
NVTE_CHECK_CUDA(cudaMemcpy(scale_ptr, d_tmp, num_scales, cudaMemcpyDeviceToDevice));
NVTE_CHECK_CUDA(cudaFree(d_tmp));

// Mark tensor as having swizzled scales
t.set_with_gemm_swizzled_scales(true);
}

// Simple GPU reference kernel for MXFP8 GEMM: D = A * B^T (TN layout)
// A is [M, K] row-major, B is [N, K] row-major, D is [M, N] column-major
// Scales are E8M0, one per group of 32 elements along K.
__global__ void mxfp8_gemm_ref_kernel(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need a second mxfp8 reference kernel?

const test::fp8e4m3 *a_data, const uint8_t *a_scale, size_t a_scale_ld,
const test::fp8e4m3 *b_data, const uint8_t *b_scale, size_t b_scale_ld,
test::bf16 *d_data,
size_t M, size_t K, size_t N) {
const size_t i = blockIdx.y * blockDim.y + threadIdx.y;
const size_t j = blockIdx.x * blockDim.x + threadIdx.x;

if (i >= M || j >= N)
return;

float acc = 0.0f;

for (size_t kk = 0; kk < K; kk++) {
size_t kc = kk / 32;
float a_sinv = exp2f(static_cast<float>(a_scale[i * a_scale_ld + kc]) - 127.0f);
float b_sinv = exp2f(static_cast<float>(b_scale[j * b_scale_ld + kc]) - 127.0f);
float a_val = static_cast<float>(a_data[i * K + kk]);
float b_val = static_cast<float>(b_data[j * K + kk]);
acc += a_sinv * a_val * b_sinv * b_val;
}

d_data[i + j * M] = static_cast<test::bf16>(acc);
}

struct MxGemmParams {
size_t m, k, n;
};

class MxGemmSwizzleGfx1250TestSuite
: public ::testing::TestWithParam<MxGemmParams> {};

TEST_P(MxGemmSwizzleGfx1250TestSuite, TestMxfp8GemmE2E) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My understanding is we must swizzle scales for gfx1250. I think ideally we would fuse this with the existing mxfp8 GEMM tests -- pre-1250 we don't swizzle, 1250+ we do.

using namespace transformer_engine;
using namespace test;

const auto &p = GetParam();
const size_t M = p.m;
const size_t K = p.k;
const size_t N = p.n;

cudaDeviceProp prop;
NVTE_CHECK_CUDA(cudaGetDeviceProperties(&prop, 0));

// This test validates the MX scale pre-swizzle -> GEMM pipeline on gfx1250+.
// Non-swizzle MXFP8 GEMMs are already covered by GEMMTestSuite.
if (prop.major < 12) {
GTEST_SKIP() << "MX scale pre-swizzle GEMM requires gfx1250+";
}

// TN layout: A is [M, K], B is [N, K]
const bool transa = true;
const bool transb = false;

Tensor A("A", std::vector<size_t>{M, K}, DType::kFloat8E4M3, true, false, NVTE_MXFP8_1D_SCALING);
Tensor B("B", std::vector<size_t>{N, K}, DType::kFloat8E4M3, true, false, NVTE_MXFP8_1D_SCALING);
Tensor D("D", std::vector<size_t>{N, M}, DType::kBFloat16);
Tensor RefD("RefD", std::vector<size_t>{N, M}, DType::kBFloat16);
Tensor bias;
Tensor pre_gelu_out;

fillUniform(&A);
fillUniform(&B);

// Override scales with values in [120,127] so layout errors are detectable.
// Default random [0,127] produces mostly tiny scales (2^(-127)..2^0),
// making the test insensitive to permutation errors.
{
auto fill_discriminating_scales = [](void *scale_ptr, size_t count) {
std::vector<uint8_t> h(count);
std::mt19937 rng(42);
std::uniform_int_distribution<uint8_t> dist(120, 127);
for (size_t i = 0; i < count; i++)
h[i] = dist(rng);
NVTE_CHECK_CUDA(cudaMemcpy(scale_ptr, h.data(), count, cudaMemcpyHostToDevice));
};
auto a_sh = A.rowwise_scale_inv_shape();
auto b_sh = B.rowwise_scale_inv_shape();
fill_discriminating_scales(A.rowwise_scale_inv_dptr(), a_sh.data[0] * a_sh.data[1]);
fill_discriminating_scales(B.rowwise_scale_inv_dptr(), b_sh.data[0] * b_sh.data[1]);
}

// GPU reference with un-swizzled (compact) scales
const auto a_scale_shape = A.rowwise_scale_inv_shape();
const auto b_scale_shape = B.rowwise_scale_inv_shape();

std::cout << " A_scale shape: [" << a_scale_shape.data[0] << ", " << a_scale_shape.data[1]
<< "], B_scale shape: [" << b_scale_shape.data[0] << ", " << b_scale_shape.data[1]
<< "]" << std::endl;

{
dim3 block(16, 16);
dim3 grid((N + block.x - 1) / block.x, (M + block.y - 1) / block.y);
mxfp8_gemm_ref_kernel<<<grid, block>>>(
static_cast<const fp8e4m3 *>(A.rowwise_dptr()),
static_cast<const uint8_t *>(A.rowwise_scale_inv_dptr()),
a_scale_shape.data[1],
static_cast<const fp8e4m3 *>(B.rowwise_dptr()),
static_cast<const uint8_t *>(B.rowwise_scale_inv_dptr()),
b_scale_shape.data[1],
static_cast<bf16 *>(RefD.rowwise_dptr()),
M, K, N);
NVTE_CHECK_CUDA(cudaDeviceSynchronize());
}

// Swizzle scales to K-tiled layout for hipBLASlt BLK32_UE8M0_32_8_EXT on gfx1250.
// Layout: {M, K_scale}.reshape({M, K_scale/4, 4}).permute({1,0,2})
// dst(m,k) = (k/4)*M*4 + m*4 + (k%4)
swizzle_tensor_scales(A, true);
swizzle_tensor_scales(B, true);

// Run actual GEMM
size_t workspace_size = 134217728; // 128MB
Tensor Workspace("Workspace", std::vector<size_t>{workspace_size}, DType::kByte);

nvte_cublas_gemm(A.data(), B.data(), D.data(),
bias.data(), pre_gelu_out.data(),
transa, transb,
/*grad=*/false,
Workspace.data(),
/*accumulate=*/false,
/*use_split_accumulator=*/false,
prop.multiProcessorCount,
0);

NVTE_CHECK_CUDA(cudaDeviceSynchronize());

// Compare
D.to_cpu();
RefD.to_cpu();

// MXFP8 accumulation errors grow with K due to different reduction orders
// between hardware and reference kernels.
const double atol = 5e-2 + K * 2e-4;
const double rtol = 1.5e-2;
compareResults("D", D, RefD.rowwise_cpu_dptr<bf16>(), true, atol, rtol);
}

INSTANTIATE_TEST_SUITE_P(
OperatorTest,
MxGemmSwizzleGfx1250TestSuite,
::testing::Values(
MxGemmParams{32, 128, 16},
MxGemmParams{64, 128, 32},
MxGemmParams{128, 128, 64},
MxGemmParams{64, 256, 32},
MxGemmParams{128, 384, 64},
MxGemmParams{256, 512, 128},
MxGemmParams{512, 1024, 256},
MxGemmParams{1024, 2048, 128},
MxGemmParams{4096, 8192, 64}
),
[](const testing::TestParamInfo<MxGemmSwizzleGfx1250TestSuite::ParamType> &info) {
return "M" + std::to_string(info.param.m) +
"_K" + std::to_string(info.param.k) +
"_N" + std::to_string(info.param.n);
});

#endif // __HIP_PLATFORM_AMD__
Loading
Loading