-
Notifications
You must be signed in to change notification settings - Fork 29
add MXFP8 pre-swizzling for gfx1250 GEMM #568
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
matthiasdiener
wants to merge
25
commits into
dev
Choose a base branch
from
mdiener/mxfp8-swizzle
base: dev
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
25 commits
Select commit
Hold shift + click to select a range
bc363fa
add MX scale pre-swizzling for gfx1250
matthiasdiener a6ca3af
switch to mxfp4
matthiasdiener d1ee5bd
tensile-like implementation
matthiasdiener d1647ee
Merge remote-tracking branch 'upstream/dev' into mdiener/mxfp8-swizzle
matthiasdiener 1fff6d9
Merge remote-tracking branch 'origin/dev' into mdiener/mxfp8-swizzle
matthiasdiener d714038
gfx1250 swizzle_xor changes for FP4
matthiasdiener 76ca4b1
change line endings to unix, trim trailing whitespace
matthiasdiener 81a0a27
Merge branch 'mdiener/swizzle_xor-1250' into mdiener/mxfp8-swizzle
matthiasdiener 2991bcf
fix arch
matthiasdiener 8ceb89c
[WIP] e2e gemm test, not working yet
matthiasdiener 167d2eb
fix for gfx1250
matthiasdiener 5d46537
k-tile
matthiasdiener 313a6b7
extend tests
matthiasdiener 2a8eeb5
remove ifdef
matthiasdiener c37a781
undo BLK32_UE8M0_32_8_EXT
matthiasdiener 5d2d38f
Merge remote-tracking branch 'upstream/dev' into mdiener/mxfp8-swizzle
matthiasdiener f093f64
Revert "change line endings to unix, trim trailing whitespace"
matthiasdiener ecbffea
Revert "gfx1250 swizzle_xor changes for FP4"
matthiasdiener 33fca6e
Merge remote-tracking branch 'origin/dev' into mdiener/mxfp8-swizzle
matthiasdiener b55a538
address review comments
matthiasdiener 398cc3c
cleanups
matthiasdiener 384d590
re-add scale swizzle hooks in GEMM paths for gfx1250
matthiasdiener 5c5a902
cleanups
matthiasdiener 2c05ec5
arch fixes
matthiasdiener 5552b09
more test fixes gfx1250
matthiasdiener File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,6 +11,7 @@ | |
| #include <gtest/gtest.h> | ||
| #include <transformer_engine/cast.h> | ||
| #include <transformer_engine/gemm.h> | ||
| #include <transformer_engine/swizzle.h> | ||
| #include <transformer_engine/transformer_engine.h> | ||
| #include "../test_common.h" | ||
|
|
||
|
|
@@ -318,6 +319,14 @@ std::pair<double, double> getTestTolerances(const DType type, bool use_fp8, bool | |
| else if (use_fp8) { | ||
| atol = 1e-3; | ||
| rtol = std::max(rtol, 1e-2); | ||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| // Relax for gfx1250 | ||
| cudaDeviceProp prop; | ||
| (void)cudaGetDeviceProperties(&prop, 0); | ||
| if (prop.major >= 12 && type == DType::kBFloat16) { | ||
| rtol = std::max(rtol, 5e-2); | ||
| } | ||
| #endif | ||
| } | ||
| else if (type == DType::kBFloat16) { | ||
| //relax for certain prime number TN gemm | ||
|
|
@@ -496,6 +505,66 @@ void performTest(const TestParams& params) { | |
| #endif | ||
| Tensor Workspace("Workspace", TShape{ workspace_size }, DType::kByte); | ||
|
|
||
| //perform the reference gemm on GPU (before swizzle, which modifies scales in-place) | ||
| Tensor RefD("RefD", TShape{ params.n, params.m }, dtype); | ||
| Tensor RefPreGeluOut; | ||
|
|
||
| if (params.use_gelu) { | ||
| RefPreGeluOut = Tensor("RefPreGeluOut", TShape{ params.n, params.m }, gelu_type); | ||
| } | ||
|
|
||
| run_reference<A_Type, B_Type, Bias_Type, Gelu_Type, D_Type>( | ||
| params, | ||
| A, | ||
| B, | ||
| params.use_bias ? &bias : nullptr, | ||
| D, | ||
| RefD, | ||
| params.use_gelu ? &RefPreGeluOut : nullptr); | ||
|
|
||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| // On gfx1250+, hipBLASLt MXFP8 kernels expect pre-swizzled scales. | ||
| if (use_mxfp8 && prop.major >= 12) { | ||
| auto swizzle_scales = [](test::Tensor &t, bool rowwise) { | ||
| using namespace transformer_engine; | ||
| void *scale_ptr = rowwise ? t.rowwise_scale_inv_dptr() | ||
| : t.columnwise_scale_inv_dptr(); | ||
| if (!scale_ptr) return; | ||
| const NVTEShape scale_shape = rowwise ? t.rowwise_scale_inv_shape() | ||
| : t.columnwise_scale_inv_shape(); | ||
| const NVTEShape data_shape = rowwise ? t.rowwise_shape() | ||
| : t.columnwise_shape(); | ||
| size_t num_scales = 1; | ||
| for (size_t d = 0; d < scale_shape.ndim; d++) num_scales *= scale_shape.data[d]; | ||
| uint8_t *d_tmp = nullptr; | ||
| NVTE_CHECK_CUDA(cudaMalloc(&d_tmp, num_scales)); | ||
| TensorWrapper input_tw(NVTE_MXFP8_1D_SCALING); | ||
| TensorWrapper output_tw(NVTE_MXFP8_1D_SCALING); | ||
| output_tw.set_with_gemm_swizzled_scales(true); | ||
| if (rowwise) { | ||
| input_tw.set_rowwise_data(nullptr, t.dtype(), data_shape); | ||
| input_tw.set_rowwise_scale_inv(scale_ptr, DType::kFloat8E8M0, scale_shape); | ||
| output_tw.set_rowwise_data(nullptr, t.dtype(), data_shape); | ||
| output_tw.set_rowwise_scale_inv(d_tmp, DType::kFloat8E8M0, scale_shape); | ||
| } else { | ||
| input_tw.set_columnwise_data(nullptr, t.dtype(), data_shape); | ||
| input_tw.set_columnwise_scale_inv(scale_ptr, DType::kFloat8E8M0, scale_shape); | ||
| output_tw.set_columnwise_data(nullptr, t.dtype(), data_shape); | ||
| output_tw.set_columnwise_scale_inv(d_tmp, DType::kFloat8E8M0, scale_shape); | ||
| } | ||
| nvte_swizzle_scaling_factors(input_tw.data(), output_tw.data(), 0); | ||
| NVTE_CHECK_CUDA(cudaDeviceSynchronize()); | ||
| NVTE_CHECK_CUDA(cudaMemcpy(scale_ptr, d_tmp, num_scales, cudaMemcpyDeviceToDevice)); | ||
| NVTE_CHECK_CUDA(cudaFree(d_tmp)); | ||
| }; | ||
| // Swizzle only the scale directions that actually exist on the tensor. | ||
| if (!a_colwise) swizzle_scales(A, true); | ||
| if (a_colwise) swizzle_scales(A, false); | ||
| if (!b_colwise) swizzle_scales(B, true); | ||
| if (b_colwise) swizzle_scales(B, false); | ||
| } | ||
| #endif | ||
|
|
||
| //perform the gemm in GPU | ||
| nvte_cublas_gemm(A.data(), | ||
| B.data(), | ||
|
|
@@ -517,23 +586,6 @@ void performTest(const TestParams& params) { | |
| pre_gelu_out.to_cpu(); | ||
| } | ||
|
|
||
| //perform the reference gemm on GPU | ||
| Tensor RefD("RefD", TShape{ params.n, params.m }, dtype); | ||
| Tensor RefPreGeluOut; | ||
|
|
||
| if (params.use_gelu) { | ||
| RefPreGeluOut = Tensor("RefPreGeluOut", TShape{ params.n, params.m }, gelu_type); | ||
| } | ||
|
|
||
| run_reference<A_Type, B_Type, Bias_Type, Gelu_Type, D_Type>( | ||
| params, | ||
| A, | ||
| B, | ||
| params.use_bias ? &bias : nullptr, | ||
| D, | ||
| RefD, | ||
| params.use_gelu ? &RefPreGeluOut : nullptr); | ||
|
|
||
| // check if error message happens in running | ||
| (void)cudaDeviceSynchronize(); | ||
| auto err = cudaGetLastError(); | ||
|
|
@@ -793,4 +845,228 @@ TEST(InputGenTest, FillUniform_DoesNotGetOverwrittenByFromCpu) { | |
| } | ||
| } | ||
|
|
||
| // ============================================================================ | ||
| // End-to-end MXFP8 GEMM test with pre-swizzled scales | ||
| // | ||
| // Verifies that the full pipeline works: | ||
| // 1. Create MXFP8 FP8 tensors with random data + scales | ||
| // 2. Run a reference GEMM (using un-swizzled scales) | ||
| // 3. Swizzle the scales via nvte_swizzle_scaling_factors | ||
| // 4. Run the actual hipBLASlt GEMM | ||
| // 5. Compare results | ||
| // ============================================================================ | ||
|
|
||
| // Helper: swizzle the MXFP8 scale_inv of a test::Tensor in-place. | ||
| // Allocates a temp device buffer, swizzles into it, copies back. | ||
| static void swizzle_tensor_scales(test::Tensor &t, bool rowwise) { | ||
| using namespace transformer_engine; | ||
|
|
||
| void *scale_ptr = rowwise ? t.rowwise_scale_inv_dptr() | ||
| : t.columnwise_scale_inv_dptr(); | ||
| if (!scale_ptr) | ||
| return; | ||
|
|
||
| const NVTEShape scale_shape = rowwise ? t.rowwise_scale_inv_shape() | ||
| : t.columnwise_scale_inv_shape(); | ||
| const NVTEShape data_shape = rowwise ? t.rowwise_shape() | ||
| : t.columnwise_shape(); | ||
|
|
||
| size_t num_scales = 1; | ||
| for (size_t d = 0; d < scale_shape.ndim; d++) { | ||
| num_scales *= scale_shape.data[d]; | ||
| } | ||
|
|
||
| // Allocate temp buffer for swizzled output | ||
| uint8_t *d_tmp = nullptr; | ||
| NVTE_CHECK_CUDA(cudaMalloc(&d_tmp, num_scales)); | ||
|
|
||
| // Build TensorWrapper pair for the swizzle call | ||
| TensorWrapper input_tw(NVTE_MXFP8_1D_SCALING); | ||
| TensorWrapper output_tw(NVTE_MXFP8_1D_SCALING); | ||
| output_tw.set_with_gemm_swizzled_scales(true); | ||
|
|
||
| if (rowwise) { | ||
| input_tw.set_rowwise_data(nullptr, t.dtype(), data_shape); | ||
| input_tw.set_rowwise_scale_inv(scale_ptr, DType::kFloat8E8M0, scale_shape); | ||
| output_tw.set_rowwise_data(nullptr, t.dtype(), data_shape); | ||
| output_tw.set_rowwise_scale_inv(d_tmp, DType::kFloat8E8M0, scale_shape); | ||
| } else { | ||
| input_tw.set_columnwise_data(nullptr, t.dtype(), data_shape); | ||
| input_tw.set_columnwise_scale_inv(scale_ptr, DType::kFloat8E8M0, scale_shape); | ||
| output_tw.set_columnwise_data(nullptr, t.dtype(), data_shape); | ||
| output_tw.set_columnwise_scale_inv(d_tmp, DType::kFloat8E8M0, scale_shape); | ||
| } | ||
|
|
||
| nvte_swizzle_scaling_factors(input_tw.data(), output_tw.data(), 0); | ||
| NVTE_CHECK_CUDA(cudaDeviceSynchronize()); | ||
|
|
||
| // Copy swizzled scales back over the original | ||
| NVTE_CHECK_CUDA(cudaMemcpy(scale_ptr, d_tmp, num_scales, cudaMemcpyDeviceToDevice)); | ||
| NVTE_CHECK_CUDA(cudaFree(d_tmp)); | ||
|
|
||
| // Mark tensor as having swizzled scales | ||
| t.set_with_gemm_swizzled_scales(true); | ||
| } | ||
|
|
||
| // Simple GPU reference kernel for MXFP8 GEMM: D = A * B^T (TN layout) | ||
| // A is [M, K] row-major, B is [N, K] row-major, D is [M, N] column-major | ||
| // Scales are E8M0, one per group of 32 elements along K. | ||
| __global__ void mxfp8_gemm_ref_kernel( | ||
| const test::fp8e4m3 *a_data, const uint8_t *a_scale, size_t a_scale_ld, | ||
| const test::fp8e4m3 *b_data, const uint8_t *b_scale, size_t b_scale_ld, | ||
| test::bf16 *d_data, | ||
| size_t M, size_t K, size_t N) { | ||
| const size_t i = blockIdx.y * blockDim.y + threadIdx.y; | ||
| const size_t j = blockIdx.x * blockDim.x + threadIdx.x; | ||
|
|
||
| if (i >= M || j >= N) | ||
| return; | ||
|
|
||
| float acc = 0.0f; | ||
|
|
||
| for (size_t kk = 0; kk < K; kk++) { | ||
| size_t kc = kk / 32; | ||
| float a_sinv = exp2f(static_cast<float>(a_scale[i * a_scale_ld + kc]) - 127.0f); | ||
| float b_sinv = exp2f(static_cast<float>(b_scale[j * b_scale_ld + kc]) - 127.0f); | ||
| float a_val = static_cast<float>(a_data[i * K + kk]); | ||
| float b_val = static_cast<float>(b_data[j * K + kk]); | ||
| acc += a_sinv * a_val * b_sinv * b_val; | ||
| } | ||
|
|
||
| d_data[i + j * M] = static_cast<test::bf16>(acc); | ||
| } | ||
|
|
||
| struct MxGemmParams { | ||
| size_t m, k, n; | ||
| }; | ||
|
|
||
| class MxGemmSwizzleGfx1250TestSuite | ||
| : public ::testing::TestWithParam<MxGemmParams> {}; | ||
|
|
||
| TEST_P(MxGemmSwizzleGfx1250TestSuite, TestMxfp8GemmE2E) { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. My understanding is we must swizzle scales for gfx1250. I think ideally we would fuse this with the existing mxfp8 GEMM tests -- pre-1250 we don't swizzle, 1250+ we do. |
||
| using namespace transformer_engine; | ||
| using namespace test; | ||
|
|
||
| const auto &p = GetParam(); | ||
| const size_t M = p.m; | ||
| const size_t K = p.k; | ||
| const size_t N = p.n; | ||
|
|
||
| cudaDeviceProp prop; | ||
| NVTE_CHECK_CUDA(cudaGetDeviceProperties(&prop, 0)); | ||
|
|
||
| // This test validates the MX scale pre-swizzle -> GEMM pipeline on gfx1250+. | ||
| // Non-swizzle MXFP8 GEMMs are already covered by GEMMTestSuite. | ||
| if (prop.major < 12) { | ||
| GTEST_SKIP() << "MX scale pre-swizzle GEMM requires gfx1250+"; | ||
| } | ||
|
|
||
| // TN layout: A is [M, K], B is [N, K] | ||
| const bool transa = true; | ||
| const bool transb = false; | ||
|
|
||
| Tensor A("A", std::vector<size_t>{M, K}, DType::kFloat8E4M3, true, false, NVTE_MXFP8_1D_SCALING); | ||
| Tensor B("B", std::vector<size_t>{N, K}, DType::kFloat8E4M3, true, false, NVTE_MXFP8_1D_SCALING); | ||
| Tensor D("D", std::vector<size_t>{N, M}, DType::kBFloat16); | ||
| Tensor RefD("RefD", std::vector<size_t>{N, M}, DType::kBFloat16); | ||
| Tensor bias; | ||
| Tensor pre_gelu_out; | ||
|
|
||
| fillUniform(&A); | ||
| fillUniform(&B); | ||
|
|
||
| // Override scales with values in [120,127] so layout errors are detectable. | ||
| // Default random [0,127] produces mostly tiny scales (2^(-127)..2^0), | ||
| // making the test insensitive to permutation errors. | ||
| { | ||
| auto fill_discriminating_scales = [](void *scale_ptr, size_t count) { | ||
| std::vector<uint8_t> h(count); | ||
| std::mt19937 rng(42); | ||
| std::uniform_int_distribution<uint8_t> dist(120, 127); | ||
| for (size_t i = 0; i < count; i++) | ||
| h[i] = dist(rng); | ||
| NVTE_CHECK_CUDA(cudaMemcpy(scale_ptr, h.data(), count, cudaMemcpyHostToDevice)); | ||
| }; | ||
| auto a_sh = A.rowwise_scale_inv_shape(); | ||
| auto b_sh = B.rowwise_scale_inv_shape(); | ||
| fill_discriminating_scales(A.rowwise_scale_inv_dptr(), a_sh.data[0] * a_sh.data[1]); | ||
| fill_discriminating_scales(B.rowwise_scale_inv_dptr(), b_sh.data[0] * b_sh.data[1]); | ||
| } | ||
|
|
||
| // GPU reference with un-swizzled (compact) scales | ||
| const auto a_scale_shape = A.rowwise_scale_inv_shape(); | ||
| const auto b_scale_shape = B.rowwise_scale_inv_shape(); | ||
|
|
||
| std::cout << " A_scale shape: [" << a_scale_shape.data[0] << ", " << a_scale_shape.data[1] | ||
| << "], B_scale shape: [" << b_scale_shape.data[0] << ", " << b_scale_shape.data[1] | ||
| << "]" << std::endl; | ||
|
|
||
| { | ||
| dim3 block(16, 16); | ||
| dim3 grid((N + block.x - 1) / block.x, (M + block.y - 1) / block.y); | ||
| mxfp8_gemm_ref_kernel<<<grid, block>>>( | ||
| static_cast<const fp8e4m3 *>(A.rowwise_dptr()), | ||
| static_cast<const uint8_t *>(A.rowwise_scale_inv_dptr()), | ||
| a_scale_shape.data[1], | ||
| static_cast<const fp8e4m3 *>(B.rowwise_dptr()), | ||
| static_cast<const uint8_t *>(B.rowwise_scale_inv_dptr()), | ||
| b_scale_shape.data[1], | ||
| static_cast<bf16 *>(RefD.rowwise_dptr()), | ||
| M, K, N); | ||
| NVTE_CHECK_CUDA(cudaDeviceSynchronize()); | ||
| } | ||
|
|
||
| // Swizzle scales to K-tiled layout for hipBLASlt BLK32_UE8M0_32_8_EXT on gfx1250. | ||
| // Layout: {M, K_scale}.reshape({M, K_scale/4, 4}).permute({1,0,2}) | ||
| // dst(m,k) = (k/4)*M*4 + m*4 + (k%4) | ||
| swizzle_tensor_scales(A, true); | ||
| swizzle_tensor_scales(B, true); | ||
|
|
||
| // Run actual GEMM | ||
| size_t workspace_size = 134217728; // 128MB | ||
| Tensor Workspace("Workspace", std::vector<size_t>{workspace_size}, DType::kByte); | ||
|
|
||
| nvte_cublas_gemm(A.data(), B.data(), D.data(), | ||
| bias.data(), pre_gelu_out.data(), | ||
| transa, transb, | ||
| /*grad=*/false, | ||
| Workspace.data(), | ||
| /*accumulate=*/false, | ||
| /*use_split_accumulator=*/false, | ||
| prop.multiProcessorCount, | ||
| 0); | ||
|
|
||
| NVTE_CHECK_CUDA(cudaDeviceSynchronize()); | ||
|
|
||
| // Compare | ||
| D.to_cpu(); | ||
| RefD.to_cpu(); | ||
|
|
||
| // MXFP8 accumulation errors grow with K due to different reduction orders | ||
| // between hardware and reference kernels. | ||
| const double atol = 5e-2 + K * 2e-4; | ||
| const double rtol = 1.5e-2; | ||
| compareResults("D", D, RefD.rowwise_cpu_dptr<bf16>(), true, atol, rtol); | ||
| } | ||
|
|
||
| INSTANTIATE_TEST_SUITE_P( | ||
| OperatorTest, | ||
| MxGemmSwizzleGfx1250TestSuite, | ||
| ::testing::Values( | ||
| MxGemmParams{32, 128, 16}, | ||
| MxGemmParams{64, 128, 32}, | ||
| MxGemmParams{128, 128, 64}, | ||
| MxGemmParams{64, 256, 32}, | ||
| MxGemmParams{128, 384, 64}, | ||
| MxGemmParams{256, 512, 128}, | ||
| MxGemmParams{512, 1024, 256}, | ||
| MxGemmParams{1024, 2048, 128}, | ||
| MxGemmParams{4096, 8192, 64} | ||
| ), | ||
| [](const testing::TestParamInfo<MxGemmSwizzleGfx1250TestSuite::ParamType> &info) { | ||
| return "M" + std::to_string(info.param.m) + | ||
| "_K" + std::to_string(info.param.k) + | ||
| "_N" + std::to_string(info.param.n); | ||
| }); | ||
|
|
||
| #endif // __HIP_PLATFORM_AMD__ | ||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we need a second mxfp8 reference kernel?