diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index 0ebd7fdfe..20d6919cc 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -31,11 +31,11 @@ list(APPEND test_cuda_sources test_multi_unpadding.cu test_causal_softmax.cu test_swap_first_dims.cu + test_swizzle.cu ../test_common.cu) if(USE_CUDA) list(APPEND test_cuda_sources - test_cast_float8blockwise.cu - test_swizzle.cu) + test_cast_float8blockwise.cu) else() list(APPEND test_cuda_sources test_cublaslt_gemm.cu diff --git a/tests/cpp/operator/test_cublaslt_gemm.cu b/tests/cpp/operator/test_cublaslt_gemm.cu index 85f183bf7..5b2d78bf7 100644 --- a/tests/cpp/operator/test_cublaslt_gemm.cu +++ b/tests/cpp/operator/test_cublaslt_gemm.cu @@ -11,6 +11,7 @@ #include #include #include +#include #include #include "../test_common.h" @@ -318,6 +319,14 @@ std::pair getTestTolerances(const DType type, bool use_fp8, bool else if (use_fp8) { atol = 1e-3; rtol = std::max(rtol, 1e-2); +#ifdef __HIP_PLATFORM_AMD__ + // Relax for gfx1250 + cudaDeviceProp prop; + (void)cudaGetDeviceProperties(&prop, 0); + if (prop.major >= 12 && type == DType::kBFloat16) { + rtol = std::max(rtol, 5e-2); + } +#endif } else if (type == DType::kBFloat16) { //relax for certain prime number TN gemm @@ -496,6 +505,66 @@ void performTest(const TestParams& params) { #endif Tensor Workspace("Workspace", TShape{ workspace_size }, DType::kByte); + //perform the reference gemm on GPU (before swizzle, which modifies scales in-place) + Tensor RefD("RefD", TShape{ params.n, params.m }, dtype); + Tensor RefPreGeluOut; + + if (params.use_gelu) { + RefPreGeluOut = Tensor("RefPreGeluOut", TShape{ params.n, params.m }, gelu_type); + } + + run_reference( + params, + A, + B, + params.use_bias ? &bias : nullptr, + D, + RefD, + params.use_gelu ? &RefPreGeluOut : nullptr); + +#ifdef __HIP_PLATFORM_AMD__ + // On gfx1250+, hipBLASLt MXFP8 kernels expect pre-swizzled scales. + if (use_mxfp8 && prop.major >= 12) { + auto swizzle_scales = [](test::Tensor &t, bool rowwise) { + using namespace transformer_engine; + void *scale_ptr = rowwise ? t.rowwise_scale_inv_dptr() + : t.columnwise_scale_inv_dptr(); + if (!scale_ptr) return; + const NVTEShape scale_shape = rowwise ? t.rowwise_scale_inv_shape() + : t.columnwise_scale_inv_shape(); + const NVTEShape data_shape = rowwise ? t.rowwise_shape() + : t.columnwise_shape(); + size_t num_scales = 1; + for (size_t d = 0; d < scale_shape.ndim; d++) num_scales *= scale_shape.data[d]; + uint8_t *d_tmp = nullptr; + NVTE_CHECK_CUDA(cudaMalloc(&d_tmp, num_scales)); + TensorWrapper input_tw(NVTE_MXFP8_1D_SCALING); + TensorWrapper output_tw(NVTE_MXFP8_1D_SCALING); + output_tw.set_with_gemm_swizzled_scales(true); + if (rowwise) { + input_tw.set_rowwise_data(nullptr, t.dtype(), data_shape); + input_tw.set_rowwise_scale_inv(scale_ptr, DType::kFloat8E8M0, scale_shape); + output_tw.set_rowwise_data(nullptr, t.dtype(), data_shape); + output_tw.set_rowwise_scale_inv(d_tmp, DType::kFloat8E8M0, scale_shape); + } else { + input_tw.set_columnwise_data(nullptr, t.dtype(), data_shape); + input_tw.set_columnwise_scale_inv(scale_ptr, DType::kFloat8E8M0, scale_shape); + output_tw.set_columnwise_data(nullptr, t.dtype(), data_shape); + output_tw.set_columnwise_scale_inv(d_tmp, DType::kFloat8E8M0, scale_shape); + } + nvte_swizzle_scaling_factors(input_tw.data(), output_tw.data(), 0); + NVTE_CHECK_CUDA(cudaDeviceSynchronize()); + NVTE_CHECK_CUDA(cudaMemcpy(scale_ptr, d_tmp, num_scales, cudaMemcpyDeviceToDevice)); + NVTE_CHECK_CUDA(cudaFree(d_tmp)); + }; + // Swizzle only the scale directions that actually exist on the tensor. + if (!a_colwise) swizzle_scales(A, true); + if (a_colwise) swizzle_scales(A, false); + if (!b_colwise) swizzle_scales(B, true); + if (b_colwise) swizzle_scales(B, false); + } +#endif + //perform the gemm in GPU nvte_cublas_gemm(A.data(), B.data(), @@ -517,23 +586,6 @@ void performTest(const TestParams& params) { pre_gelu_out.to_cpu(); } - //perform the reference gemm on GPU - Tensor RefD("RefD", TShape{ params.n, params.m }, dtype); - Tensor RefPreGeluOut; - - if (params.use_gelu) { - RefPreGeluOut = Tensor("RefPreGeluOut", TShape{ params.n, params.m }, gelu_type); - } - - run_reference( - params, - A, - B, - params.use_bias ? &bias : nullptr, - D, - RefD, - params.use_gelu ? &RefPreGeluOut : nullptr); - // check if error message happens in running (void)cudaDeviceSynchronize(); auto err = cudaGetLastError(); @@ -793,4 +845,228 @@ TEST(InputGenTest, FillUniform_DoesNotGetOverwrittenByFromCpu) { } } +// ============================================================================ +// End-to-end MXFP8 GEMM test with pre-swizzled scales +// +// Verifies that the full pipeline works: +// 1. Create MXFP8 FP8 tensors with random data + scales +// 2. Run a reference GEMM (using un-swizzled scales) +// 3. Swizzle the scales via nvte_swizzle_scaling_factors +// 4. Run the actual hipBLASlt GEMM +// 5. Compare results +// ============================================================================ + +// Helper: swizzle the MXFP8 scale_inv of a test::Tensor in-place. +// Allocates a temp device buffer, swizzles into it, copies back. +static void swizzle_tensor_scales(test::Tensor &t, bool rowwise) { + using namespace transformer_engine; + + void *scale_ptr = rowwise ? t.rowwise_scale_inv_dptr() + : t.columnwise_scale_inv_dptr(); + if (!scale_ptr) + return; + + const NVTEShape scale_shape = rowwise ? t.rowwise_scale_inv_shape() + : t.columnwise_scale_inv_shape(); + const NVTEShape data_shape = rowwise ? t.rowwise_shape() + : t.columnwise_shape(); + + size_t num_scales = 1; + for (size_t d = 0; d < scale_shape.ndim; d++) { + num_scales *= scale_shape.data[d]; + } + + // Allocate temp buffer for swizzled output + uint8_t *d_tmp = nullptr; + NVTE_CHECK_CUDA(cudaMalloc(&d_tmp, num_scales)); + + // Build TensorWrapper pair for the swizzle call + TensorWrapper input_tw(NVTE_MXFP8_1D_SCALING); + TensorWrapper output_tw(NVTE_MXFP8_1D_SCALING); + output_tw.set_with_gemm_swizzled_scales(true); + + if (rowwise) { + input_tw.set_rowwise_data(nullptr, t.dtype(), data_shape); + input_tw.set_rowwise_scale_inv(scale_ptr, DType::kFloat8E8M0, scale_shape); + output_tw.set_rowwise_data(nullptr, t.dtype(), data_shape); + output_tw.set_rowwise_scale_inv(d_tmp, DType::kFloat8E8M0, scale_shape); + } else { + input_tw.set_columnwise_data(nullptr, t.dtype(), data_shape); + input_tw.set_columnwise_scale_inv(scale_ptr, DType::kFloat8E8M0, scale_shape); + output_tw.set_columnwise_data(nullptr, t.dtype(), data_shape); + output_tw.set_columnwise_scale_inv(d_tmp, DType::kFloat8E8M0, scale_shape); + } + + nvte_swizzle_scaling_factors(input_tw.data(), output_tw.data(), 0); + NVTE_CHECK_CUDA(cudaDeviceSynchronize()); + + // Copy swizzled scales back over the original + NVTE_CHECK_CUDA(cudaMemcpy(scale_ptr, d_tmp, num_scales, cudaMemcpyDeviceToDevice)); + NVTE_CHECK_CUDA(cudaFree(d_tmp)); + + // Mark tensor as having swizzled scales + t.set_with_gemm_swizzled_scales(true); +} + +// Simple GPU reference kernel for MXFP8 GEMM: D = A * B^T (TN layout) +// A is [M, K] row-major, B is [N, K] row-major, D is [M, N] column-major +// Scales are E8M0, one per group of 32 elements along K. +__global__ void mxfp8_gemm_ref_kernel( + const test::fp8e4m3 *a_data, const uint8_t *a_scale, size_t a_scale_ld, + const test::fp8e4m3 *b_data, const uint8_t *b_scale, size_t b_scale_ld, + test::bf16 *d_data, + size_t M, size_t K, size_t N) { + const size_t i = blockIdx.y * blockDim.y + threadIdx.y; + const size_t j = blockIdx.x * blockDim.x + threadIdx.x; + + if (i >= M || j >= N) + return; + + float acc = 0.0f; + + for (size_t kk = 0; kk < K; kk++) { + size_t kc = kk / 32; + float a_sinv = exp2f(static_cast(a_scale[i * a_scale_ld + kc]) - 127.0f); + float b_sinv = exp2f(static_cast(b_scale[j * b_scale_ld + kc]) - 127.0f); + float a_val = static_cast(a_data[i * K + kk]); + float b_val = static_cast(b_data[j * K + kk]); + acc += a_sinv * a_val * b_sinv * b_val; + } + + d_data[i + j * M] = static_cast(acc); +} + +struct MxGemmParams { + size_t m, k, n; +}; + +class MxGemmSwizzleGfx1250TestSuite + : public ::testing::TestWithParam {}; + +TEST_P(MxGemmSwizzleGfx1250TestSuite, TestMxfp8GemmE2E) { + using namespace transformer_engine; + using namespace test; + + const auto &p = GetParam(); + const size_t M = p.m; + const size_t K = p.k; + const size_t N = p.n; + + cudaDeviceProp prop; + NVTE_CHECK_CUDA(cudaGetDeviceProperties(&prop, 0)); + + // This test validates the MX scale pre-swizzle -> GEMM pipeline on gfx1250+. + // Non-swizzle MXFP8 GEMMs are already covered by GEMMTestSuite. + if (prop.major < 12) { + GTEST_SKIP() << "MX scale pre-swizzle GEMM requires gfx1250+"; + } + + // TN layout: A is [M, K], B is [N, K] + const bool transa = true; + const bool transb = false; + + Tensor A("A", std::vector{M, K}, DType::kFloat8E4M3, true, false, NVTE_MXFP8_1D_SCALING); + Tensor B("B", std::vector{N, K}, DType::kFloat8E4M3, true, false, NVTE_MXFP8_1D_SCALING); + Tensor D("D", std::vector{N, M}, DType::kBFloat16); + Tensor RefD("RefD", std::vector{N, M}, DType::kBFloat16); + Tensor bias; + Tensor pre_gelu_out; + + fillUniform(&A); + fillUniform(&B); + + // Override scales with values in [120,127] so layout errors are detectable. + // Default random [0,127] produces mostly tiny scales (2^(-127)..2^0), + // making the test insensitive to permutation errors. + { + auto fill_discriminating_scales = [](void *scale_ptr, size_t count) { + std::vector h(count); + std::mt19937 rng(42); + std::uniform_int_distribution dist(120, 127); + for (size_t i = 0; i < count; i++) + h[i] = dist(rng); + NVTE_CHECK_CUDA(cudaMemcpy(scale_ptr, h.data(), count, cudaMemcpyHostToDevice)); + }; + auto a_sh = A.rowwise_scale_inv_shape(); + auto b_sh = B.rowwise_scale_inv_shape(); + fill_discriminating_scales(A.rowwise_scale_inv_dptr(), a_sh.data[0] * a_sh.data[1]); + fill_discriminating_scales(B.rowwise_scale_inv_dptr(), b_sh.data[0] * b_sh.data[1]); + } + + // GPU reference with un-swizzled (compact) scales + const auto a_scale_shape = A.rowwise_scale_inv_shape(); + const auto b_scale_shape = B.rowwise_scale_inv_shape(); + + std::cout << " A_scale shape: [" << a_scale_shape.data[0] << ", " << a_scale_shape.data[1] + << "], B_scale shape: [" << b_scale_shape.data[0] << ", " << b_scale_shape.data[1] + << "]" << std::endl; + + { + dim3 block(16, 16); + dim3 grid((N + block.x - 1) / block.x, (M + block.y - 1) / block.y); + mxfp8_gemm_ref_kernel<<>>( + static_cast(A.rowwise_dptr()), + static_cast(A.rowwise_scale_inv_dptr()), + a_scale_shape.data[1], + static_cast(B.rowwise_dptr()), + static_cast(B.rowwise_scale_inv_dptr()), + b_scale_shape.data[1], + static_cast(RefD.rowwise_dptr()), + M, K, N); + NVTE_CHECK_CUDA(cudaDeviceSynchronize()); + } + + // Swizzle scales to K-tiled layout for hipBLASlt BLK32_UE8M0_32_8_EXT on gfx1250. + // Layout: {M, K_scale}.reshape({M, K_scale/4, 4}).permute({1,0,2}) + // dst(m,k) = (k/4)*M*4 + m*4 + (k%4) + swizzle_tensor_scales(A, true); + swizzle_tensor_scales(B, true); + + // Run actual GEMM + size_t workspace_size = 134217728; // 128MB + Tensor Workspace("Workspace", std::vector{workspace_size}, DType::kByte); + + nvte_cublas_gemm(A.data(), B.data(), D.data(), + bias.data(), pre_gelu_out.data(), + transa, transb, + /*grad=*/false, + Workspace.data(), + /*accumulate=*/false, + /*use_split_accumulator=*/false, + prop.multiProcessorCount, + 0); + + NVTE_CHECK_CUDA(cudaDeviceSynchronize()); + + // Compare + D.to_cpu(); + RefD.to_cpu(); + + // MXFP8 accumulation errors grow with K due to different reduction orders + // between hardware and reference kernels. + const double atol = 5e-2 + K * 2e-4; + const double rtol = 1.5e-2; + compareResults("D", D, RefD.rowwise_cpu_dptr(), true, atol, rtol); +} + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + MxGemmSwizzleGfx1250TestSuite, + ::testing::Values( + MxGemmParams{32, 128, 16}, + MxGemmParams{64, 128, 32}, + MxGemmParams{128, 128, 64}, + MxGemmParams{64, 256, 32}, + MxGemmParams{128, 384, 64}, + MxGemmParams{256, 512, 128}, + MxGemmParams{512, 1024, 256}, + MxGemmParams{1024, 2048, 128}, + MxGemmParams{4096, 8192, 64} + ), + [](const testing::TestParamInfo &info) { + return "M" + std::to_string(info.param.m) + + "_K" + std::to_string(info.param.k) + + "_N" + std::to_string(info.param.n); + }); + #endif // __HIP_PLATFORM_AMD__ diff --git a/tests/cpp/operator/test_swizzle.cu b/tests/cpp/operator/test_swizzle.cu index 3209d2335..0092a0c62 100644 --- a/tests/cpp/operator/test_swizzle.cu +++ b/tests/cpp/operator/test_swizzle.cu @@ -166,3 +166,183 @@ INSTANTIATE_TEST_SUITE_P( std::to_string(std::get<2>(info.param)); return name; }); + +#ifdef __HIP_PLATFORM_AMD__ + +// MX pre-swizzle test (gfx1250 Tensile 3D layout) +// +// Tensile 3D: {K_scale, M}.reshape({K_scale, padM/4, 4}).permute({1, 0, 2}) +// For source (m, k): dst = (m/4) * (K*4) + k*4 + (m%4) + +// CPU reference for Tensile 3D MX scale pre-swizzle. +// Row-major input [M, K], output is a flat permuted array. +void compute_ref_mx_swizzle_row(const uint8_t *h_input, uint8_t *h_output, + const int M, const int K, + const int orig_M, const int orig_K) { + constexpr int GROUP = 4; + for (int m = 0; m < M; m++) { + for (int k = 0; k < K; k++) { + uint8_t val = 127; // E8M0 identity: 2^0 = 1.0 + if (m < orig_M && k < orig_K) { + val = h_input[m * orig_K + k]; + } + int group = k / GROUP; + int within = k % GROUP; + int dst = group * (M * GROUP) + m * GROUP + within; + h_output[dst] = val; + } + } +} + +void compute_ref_mx_swizzle_col(const uint8_t *h_input, uint8_t *h_output, + const int M, const int K, + const int orig_M, const int orig_K) { + constexpr int GROUP = 4; + for (int m = 0; m < M; m++) { + for (int k = 0; k < K; k++) { + uint8_t val = 127; + if (m < orig_M && k < orig_K) { + val = h_input[k * orig_M + m]; + } + int group = k / GROUP; + int within = k % GROUP; + int dst = group * (M * GROUP) + m * GROUP + within; + h_output[dst] = val; + } + } +} + +static size_t roundup_sz(size_t val, size_t mult) { + return ((val + mult - 1) / mult) * mult; +} + +class MxSwizzleTestSuite + : public ::testing::TestWithParam< + std::tuple, bool>> {}; + +TEST_P(MxSwizzleTestSuite, TestMxSwizzle) { + using namespace transformer_engine; + using namespace test; + + cudaDeviceProp prop; + cudaGetDeviceProperties(&prop, 0); + if (prop.major < 12) { + GTEST_SKIP() << "MXFP8 pre-swizzle is only supported on gfx1250"; + } + + const auto dims = std::get<0>(GetParam()); + const bool rowwise = std::get<1>(GetParam()); + + // Original (unpadded) scale dimensions + const size_t orig_M = dims.first; + const size_t orig_K = dims.second; + + // Padded dimensions: K-tiled layout requires K_scale padded to multiple of 4 + const size_t M = orig_M; + const size_t K = roundup_sz(orig_K, 4); + + // Allocate host input (unpadded) and fill with random data + const size_t input_size = orig_M * orig_K; + std::unique_ptr h_input(new uint8_t[input_size]); + std::mt19937 rng(42); + for (size_t i = 0; i < input_size; i++) { + h_input[i] = static_cast(rng() % 256); + } + + // Allocate device input + uint8_t *d_input = nullptr; + NVTE_CHECK_CUDA(cudaMalloc(&d_input, input_size)); + NVTE_CHECK_CUDA(cudaMemcpy(d_input, h_input.get(), input_size, cudaMemcpyHostToDevice)); + + // Allocate device output (padded size) + const size_t output_size = M * K; + uint8_t *d_output = nullptr; + NVTE_CHECK_CUDA(cudaMalloc(&d_output, output_size)); + NVTE_CHECK_CUDA(cudaMemset(d_output, 0, output_size)); + + // Build TensorWrapper for input and output + TensorWrapper input_tw(NVTE_MXFP8_1D_SCALING); + TensorWrapper output_tw(NVTE_MXFP8_1D_SCALING); + output_tw.set_with_gemm_swizzled_scales(true); + + // Data shape must be consistent with scale shape for validation. + // Scale shapes use padded K; data shapes use unpadded dims + // (kernel derives original_M/K from them). + if (rowwise) { + std::vector data_shape_in = {orig_M, orig_K * 32}; + std::vector data_shape_out = {M, K * 32}; + std::vector scale_shape_in = {M, K}; + std::vector scale_shape_out = {M, K}; + input_tw.set_rowwise_data(nullptr, DType::kFloat8E4M3, data_shape_in); + input_tw.set_rowwise_scale_inv(d_input, DType::kFloat8E8M0, scale_shape_in); + output_tw.set_rowwise_data(nullptr, DType::kFloat8E4M3, data_shape_out); + output_tw.set_rowwise_scale_inv(d_output, DType::kFloat8E8M0, scale_shape_out); + } else { + std::vector data_shape_in = {orig_K * 32, orig_M}; + std::vector data_shape_out = {K * 32, M}; + std::vector scale_shape_in = {K, M}; + std::vector scale_shape_out = {K, M}; + input_tw.set_columnwise_data(nullptr, DType::kFloat8E4M3, data_shape_in); + input_tw.set_columnwise_scale_inv(d_input, DType::kFloat8E8M0, scale_shape_in); + output_tw.set_columnwise_data(nullptr, DType::kFloat8E4M3, data_shape_out); + output_tw.set_columnwise_scale_inv(d_output, DType::kFloat8E8M0, scale_shape_out); + } + + nvte_swizzle_scaling_factors(input_tw.data(), output_tw.data(), 0); + + NVTE_CHECK_CUDA(cudaDeviceSynchronize()); + + // Copy output back to host + std::unique_ptr h_output(new uint8_t[output_size]); + NVTE_CHECK_CUDA(cudaMemcpy(h_output.get(), d_output, output_size, cudaMemcpyDeviceToHost)); + + // Compute reference + std::unique_ptr h_ref(new uint8_t[output_size]); + memset(h_ref.get(), 0, output_size); + if (rowwise) { + compute_ref_mx_swizzle_row(h_input.get(), h_ref.get(), M, K, orig_M, orig_K); + } else { + compute_ref_mx_swizzle_col(h_input.get(), h_ref.get(), M, K, orig_M, orig_K); + } + + // Compare + compareResults("mx_swizzle", h_output.get(), h_ref.get(), output_size); + + cudaFree(d_input); + cudaFree(d_output); +} + +namespace { + +// Scale dimensions (M_scale, K_scale). +// K_scale will be padded to multiple of 4 by the test. +std::vector> mx_scale_dims = { + {4, 4}, // minimal + {8, 4}, // small + {32, 8}, // medium + {64, 16}, // larger + {96, 8}, // non-power-of-2 M + {128, 32}, // big + {256, 64}, // bigger + {512, 128}, // stress inter-tile + {1024, 256}, // large + {4096, 256}, // max stress +}; + +} // namespace + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + MxSwizzleTestSuite, + ::testing::Combine( + ::testing::ValuesIn(mx_scale_dims), + ::testing::Values(true, false) + ), + [](const testing::TestParamInfo& info) { + std::string name = "M" + std::to_string(std::get<0>(info.param).first) + + "_K" + std::to_string(std::get<0>(info.param).second) + + (std::get<1>(info.param) ? "_row" : "_col"); + return name; + }); + +#endif // __HIP_PLATFORM_AMD__ diff --git a/transformer_engine/common/swizzle/swizzle.cu b/transformer_engine/common/swizzle/swizzle.cu index c634c73fb..d95cd49d8 100644 --- a/transformer_engine/common/swizzle/swizzle.cu +++ b/transformer_engine/common/swizzle/swizzle.cu @@ -14,6 +14,7 @@ #include #include "../common.h" +#include "../util/cuda_runtime.h" #include "../util/logging.h" #include "transformer_engine/transformer_engine.h" @@ -347,9 +348,168 @@ __global__ void multi_tensor_swizzle_col_scaling_kernel(MultiSwizzleArgs kernel_ input, output, M, K, original_M, original_K, bid_x, bid_y, grid_dim_x, grid_dim_y); } +#ifdef __HIP_PLATFORM_AMD__ +// ============================================================================ +// MX scale pre-swizzle kernel for gfx1250 — K-tiled 3D layout +// +// hipBLASlt Tensile kernels expect scales in a permuted 3D layout that +// groups K_scale into tiles of 4 (= 128 / MXBlock32): +// Tensor({M, K_scale}).pad(K_scale to mult of 4).reshape({M, K_scale/4, 4}).permute({1, 0, 2}) +// +// For source position (m, k) in the [M, K_scale] scale matrix: +// group = k / 4 +// within = k % 4 +// dst = group * (M * 4) + m * 4 + within +// +// Padding: K_scale to multiple of 4. No M padding required. +// Identity padding value: E8M0 127 = 2^0 = 1.0 +// +// Reference: swizzle_mx_scale() in hipblaslt/clients/common/include/testing_matmul.hpp +// ============================================================================ + +constexpr int MX_PRESWIZZLE_GROUP_SIZE = 4; + +// Unified MX scale pre-swizzle kernel for both row-wise and column-wise. +// Iterates only over valid (non-padded) elements; the caller must pre-fill +// the output buffer with identity (127) to handle padding. +// +// kRowwise=true: input is [orig_M, orig_K] row-major +// kRowwise=false: input is [orig_K, orig_M] row-major (column-wise scales) +template +__global__ void __launch_bounds__(256) + swizzle_scaling_mx_kernel(const uint8_t* __restrict__ input, + uint8_t* __restrict__ output, + const int padded_M, + const int orig_M, const int orig_K) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + const int total = orig_M * orig_K; + if (idx >= total) return; + + const int m = idx / orig_K; + const int k = idx % orig_K; + + uint8_t val; + if constexpr (kRowwise) { + val = input[idx]; // == input[m * orig_K + k] + } else { + val = input[k * orig_M + m]; + } + + const int group = k / MX_PRESWIZZLE_GROUP_SIZE; + const int within = k % MX_PRESWIZZLE_GROUP_SIZE; + const int dst = group * (padded_M * MX_PRESWIZZLE_GROUP_SIZE) + + m * MX_PRESWIZZLE_GROUP_SIZE + within; + + output[dst] = val; +} + } // namespace +void swizzle_scaling_factors_mx(const Tensor* input, Tensor* output, cudaStream_t stream) { + // Check scaling mode + const auto& scaling_mode = input->scaling_mode; + NVTE_CHECK(scaling_mode == NVTE_MXFP8_1D_SCALING, + "MX pre-swizzle only supports MXFP8 scaling mode (got ", + to_string(input->scaling_mode), ")."); + + // Check tensors + CheckInputTensor(*input, "scaling_factor_input"); + CheckInputTensor(*output, "scaling_factor_output"); + NVTE_CHECK(!input->with_gemm_swizzled_scales, + "Expected input tensor with scales in compact format."); + NVTE_CHECK(output->with_gemm_swizzled_scales, + "Expected output tensor with scales in GEMM swizzled format."); + NVTE_CHECK(is_fp8_dtype(input->dtype()), "Input tensor has invalid dtype (expected FP8, got ", + to_string(input->dtype()), ")."); + + // Check if scaling factors are non-trivial + const bool has_rowwise_scale_inv = input->scale_inv.has_data(); + const bool has_columnwise_scale_inv = input->columnwise_scale_inv.has_data(); + NVTE_CHECK(!has_rowwise_scale_inv || !has_columnwise_scale_inv, + "Input tensor has both row-wise and column-wise scaling factors"); + if (!has_rowwise_scale_inv && !has_columnwise_scale_inv) { + return; + } + + // Deduce tensor dims + int m{0}, k{0}; + if (has_rowwise_scale_inv) { + NVTE_CHECK(input->scale_inv.shape.size() == 2, + "Expected 2D scaling factors, got shape=", input->scale_inv.shape, "."); + m = input->scale_inv.shape[0]; + k = input->scale_inv.shape[1]; + } else if (has_columnwise_scale_inv) { + NVTE_CHECK(input->columnwise_scale_inv.shape.size() == 2, + "Expected 2D scaling factors, got shape=", input->columnwise_scale_inv.shape, "."); + m = input->columnwise_scale_inv.shape[1]; + k = input->columnwise_scale_inv.shape[0]; + } + + // Check dims -- K-tiled layout requires K_scale padded to multiple of 4 + NVTE_CHECK(k % MX_PRESWIZZLE_GROUP_SIZE == 0, + "Scale K dimension must be padded to multiple of ", MX_PRESWIZZLE_GROUP_SIZE, + ", got ", k, "."); + + // Validate output dimensions match + if (has_rowwise_scale_inv) { + NVTE_CHECK(output->scale_inv.has_data(), + "Output tensor does not have row-wise scaling factors."); + NVTE_CHECK(m * k == output->scale_inv.numel(), "Expected output tensor to have ", m * k, + " row-wise scaling factors, but got shape=", output->scale_inv.shape, "."); + } + if (has_columnwise_scale_inv) { + NVTE_CHECK(output->columnwise_scale_inv.has_data(), + "Output tensor does not have column-wise scaling factors."); + NVTE_CHECK(m * k == output->columnwise_scale_inv.numel(), + "Expected output tensor to have ", m * k, + " column-wise scaling factors, but got shape=", + output->columnwise_scale_inv.shape, "."); + } + + const int total = m * k; + constexpr int block = 256; + + // Row-wise swizzle + if (has_rowwise_scale_inv) { + const int original_M = input->flat_first_dim(); + const int original_K = input->flat_last_dim() / MXFP8_BLOCK_SIZE; + // Pre-fill output with E8M0 identity (127 = 2^0) to handle padding + NVTE_CHECK_CUDA(cudaMemsetAsync(output->scale_inv.dptr, 127, total, stream)); + const int orig_total = original_M * original_K; + const int grid = (orig_total + block - 1) / block; + swizzle_scaling_mx_kernel<<>>( + reinterpret_cast(input->scale_inv.dptr), + reinterpret_cast(output->scale_inv.dptr), + m, original_M, original_K); + NVTE_CHECK_CUDA(cudaGetLastError()); + } + + // Column-wise swizzle + if (has_columnwise_scale_inv) { + const int original_M = input->flat_last_dim(); + const int original_K = input->flat_first_dim() / MXFP8_BLOCK_SIZE; + // Pre-fill output with E8M0 identity (127 = 2^0) to handle padding + NVTE_CHECK_CUDA(cudaMemsetAsync(output->columnwise_scale_inv.dptr, 127, total, stream)); + const int orig_total = original_M * original_K; + const int grid = (orig_total + block - 1) / block; + swizzle_scaling_mx_kernel<<>>( + reinterpret_cast(input->columnwise_scale_inv.dptr), + reinterpret_cast(output->columnwise_scale_inv.dptr), + m, original_M, original_K); + NVTE_CHECK_CUDA(cudaGetLastError()); + } +} +#endif // __HIP_PLATFORM_AMD__ + void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t stream) { +#ifdef __HIP_PLATFORM_AMD__ + // On gfx1250, MXFP8 uses the MX pre-swizzle layout (K-tiled, grouped by 4). + if (input->scaling_mode == NVTE_MXFP8_1D_SCALING && cuda::sm_arch() >= 125) { + swizzle_scaling_factors_mx(input, output, stream); + return; + } +#endif // __HIP_PLATFORM_AMD__ + // Check scaling mode const auto& scaling_mode = input->scaling_mode; NVTE_CHECK(scaling_mode == NVTE_MXFP8_1D_SCALING || scaling_mode == NVTE_NVFP4_1D_SCALING, @@ -667,6 +827,24 @@ void launch_multi_tensor_swizzle_scaling_factors(MultiSwizzleArgs& kernel_args, void multi_tensor_swizzle_scaling_factors(const std::vector& input, std::vector& output, cudaStream_t stream) { +#ifdef __HIP_PLATFORM_AMD__ + // On gfx1250, MXFP8 uses the MX pre-swizzle layout. + if (cuda::sm_arch() >= 125) { + bool any_mxfp8 = false; + for (size_t i = 0; i < input.size(); i++) { + if (is_mxfp8_scaling(input[i]->scaling_mode)) { + any_mxfp8 = true; + } + } + if (any_mxfp8) { + for (size_t i = 0; i < input.size(); i++) { + swizzle_scaling_factors_mx(input[i], output[i], stream); + } + return; + } + } +#endif // __HIP_PLATFORM_AMD__ + auto num_tensors = input.size(); bool all_has_data = true; bool all_has_columnwise_data = true; diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 40121049a..e32a42b1d 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -80,12 +80,31 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( "Inverse scale factors need to have an 8-bit data type."); } if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) { - // Assume MXFP8 scales are already swizzled if (rowwise) { input.set_rowwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape); } else { input.set_columnwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape); } +#ifdef USE_ROCM + // On gfx1250, pre-swizzle MXFP8 scales for hipBLASLt + if (transformer_engine::cuda::sm_arch() == 125 && swizzle_scale_ptr) { + TensorWrapper output(get_nvte_scaling_mode(scaling_mode)); + if (rowwise) { + output.set_rowwise_data(buffer.untyped_data(), input_dtype, input_shape); + output.set_rowwise_scale_inv(swizzle_scale_ptr, scale_dtype, scale_shape); + } else { + output.set_columnwise_data(buffer.untyped_data(), input_dtype, input_shape); + output.set_columnwise_scale_inv(swizzle_scale_ptr, scale_dtype, scale_shape); + } + output.set_with_gemm_swizzled_scales(true); + nvte_swizzle_scaling_factors(input.data(), output.data(), stream); + if (rowwise) { + input.set_rowwise_scale_inv(swizzle_scale_ptr, scale_dtype, scale_shape); + } else { + input.set_columnwise_scale_inv(swizzle_scale_ptr, scale_dtype, scale_shape); + } + } +#endif input.set_with_gemm_swizzled_scales(true); } else if (is_nvfp4) { // Swizzle for NVFP4 #ifdef USE_ROCM @@ -195,7 +214,12 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i auto workspace_ptr = reinterpret_cast(workspace->untyped_data()); workspace_ptr = move_ptr_to_next_256B_aligned(workspace_ptr); size_t workspace_size = static_cast(workspace->element_count()) - 256; - if (is_nvfp4_scaling(scaling_mode)) { + if (is_nvfp4_scaling(scaling_mode) +#ifdef USE_ROCM + || (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING + && transformer_engine::cuda::sm_arch() == 125) +#endif + ) { auto lhs_scale_size = product(lhs_scale_inv.dimensions()); auto rhs_scale_size = product(rhs_scale_inv.dimensions()); workspace_size = workspace_size - lhs_scale_size - rhs_scale_size; diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index 6898ce387..7a54728c2 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -244,13 +244,10 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans config.set_use_split_accumulator(use_split_accumulator); config.set_sm_count(num_math_sms); -#ifndef USE_ROCM // Keep the swizzled scaling factor tensors alive during the GEMM. std::vector> swizzled_scale_inverses_list; -#endif auto main_stream = at::cuda::getCurrentCUDAStream(); if (A_tensor.numel() != 0 && B_tensor.numel() != 0) { -#ifndef USE_ROCM // Optionally swizzle the scaling factors auto [A_row_scales, A_col_scales] = swizzle_scales_for_gemm(A_tensor, transa, !transa); auto [B_row_scales, B_col_scales] = swizzle_scales_for_gemm(B_tensor, !transb, transb); @@ -259,6 +256,7 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans swizzled_scale_inverses_list.emplace_back(std::move(B_row_scales)); swizzled_scale_inverses_list.emplace_back(std::move(B_col_scales)); +#ifndef USE_ROCM // Emulate the FP8 block scaling recipe with MXFP8 on Blackwell and newer // as it is not natively supported by cublasLt if (fp8_block_scaling && transformer_engine::cuda::sm_arch() >= 100) { @@ -532,7 +530,6 @@ std::optional> te_general_grouped_gemm( te_pre_gelu_out_wrappers.emplace_back(std::move(te_pre_gelu_out)); } -#ifndef USE_ROCM // Keep the swizzled scaling factor tensors alive during the GEMM. std::vector> swizzled_scale_inverses_list; @@ -542,6 +539,7 @@ std::optional> te_general_grouped_gemm( swizzled_scale_inverses_list.emplace_back( multi_tensor_swizzle_scales_for_gemm(te_B_wrappers, !transb, transb)); +#ifndef USE_ROCM // Emulate the FP8 block scaling recipe with MXFP8 on Blackwell and newer // as it is not natively supported by cublasLt if (transformer_engine::cuda::sm_arch() >= 100) { diff --git a/transformer_engine/pytorch/csrc/extensions/swizzle.cpp b/transformer_engine/pytorch/csrc/extensions/swizzle.cpp index 4ad57bbf1..d9929c93e 100644 --- a/transformer_engine/pytorch/csrc/extensions/swizzle.cpp +++ b/transformer_engine/pytorch/csrc/extensions/swizzle.cpp @@ -13,6 +13,7 @@ #include "common.h" #include "common/common.h" +#include "common/util/cuda_runtime.h" #include "extensions.h" #include "pybind.h" #include "util.h" @@ -55,6 +56,13 @@ std::tuple, std::optional> swizzle_scales_ return {std::nullopt, std::nullopt}; } +#ifdef USE_ROCM + // On ROCm, only MXFP8 on gfx1250 needs scale pre-swizzling + if (scaling_mode != NVTE_MXFP8_1D_SCALING || transformer_engine::cuda::sm_arch() != 125) { + return {std::nullopt, std::nullopt}; + } +#endif + // Return early if scales are already swizzled if (tensor.get_with_gemm_swizzled_scales()) { return {std::nullopt, std::nullopt}; @@ -164,6 +172,13 @@ std::optional multi_tensor_swizzle_scales_for_gemm( return std::nullopt; } +#ifdef USE_ROCM + // On ROCm, only MXFP8 on gfx1250 needs scale pre-swizzling + if (scaling_mode != NVTE_MXFP8_1D_SCALING || transformer_engine::cuda::sm_arch() != 125) { + return std::nullopt; + } +#endif + // Filter out tensors that already have swizzled scales std::vector tensors_needing_swizzle; for (auto &tensor : tensors) { diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index bb960406d..f1f6d690a 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -9,6 +9,9 @@ #include #include "common.h" +#ifdef USE_ROCM +#include "common/util/cuda_runtime.h" +#endif #include "pybind.h" #include "torch/torch.h" @@ -1104,6 +1107,17 @@ std::vector MXFP8Quantizer::get_scale_shape(const std::vector& s "MXFP8 requires tensor dims that are divisible by ", MXFP8_BLOCK_SIZE, " (got shape=", shape, ")"); #ifdef USE_ROCM + // gfx1250 MX pre-swizzle (Tensile 3D) layout requires M padded to multiple of 4. + if (transformer_engine::cuda::sm_arch() == 125) { + size_t m_dim = numel / last_dim; + size_t k_scale = last_dim / MXFP8_BLOCK_SIZE; + if (!columnwise) { + return {roundup(m_dim, 4), k_scale}; + } else { + return {k_scale, roundup(m_dim, 4)}; + } + } + return !columnwise ? std::vector{numel / last_dim, last_dim / MXFP8_BLOCK_SIZE} : std::vector{numel / last_dim / MXFP8_BLOCK_SIZE, last_dim}; diff --git a/transformer_engine/pytorch/csrc/util.h b/transformer_engine/pytorch/csrc/util.h index 6588aa6c5..f2310b61f 100644 --- a/transformer_engine/pytorch/csrc/util.h +++ b/transformer_engine/pytorch/csrc/util.h @@ -9,8 +9,6 @@ #ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_ #define TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_ -#ifndef USE_ROCM - #include #include @@ -37,6 +35,7 @@ std::optional multi_tensor_swizzle_scales_for_gemm(std::vector multi_tensor_swizzle_scales_for_gemm(std::vector