-
Notifications
You must be signed in to change notification settings - Fork 29
HipKittens MXFP8 GEMM Support #566
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from all commits
f9d5ce2
aac5860
c917ed0
3a91321
cc719fe
fcda154
70fba6d
455002e
ba60ef5
f72b7b8
731640a
1960c06
320152e
a280cf7
2a27902
3d7aaf9
824841d
f66f77c
0b6e702
816c752
aaa88d7
e2203c0
7648594
03f675b
f852c22
3b307bb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -30,7 +30,9 @@ std::vector<std::tuple<size_t, size_t, size_t>> test_case_sizes = { | |
|
|
||
| std::vector<std::tuple<size_t, size_t, size_t>> test_case_sizes_mxfp8 = { | ||
| {32, 128, 16}, | ||
| {256, 256, 256}, | ||
| {768, 3072, 4096}, | ||
| {4096, 16384, 4096}, | ||
| }; | ||
|
|
||
| // A, B, Bias, Gelu, D | ||
|
|
@@ -168,6 +170,20 @@ __global__ void compute_ref_kernel( | |
| } | ||
|
|
||
|
|
||
| constexpr size_t kMXFP8GroupSize = 32; | ||
| constexpr size_t kKTileSize = 128; | ||
|
|
||
| // Estimates the workspace size to allocate for an MXFP8 GEMM test and returns it, | ||
| // never going below base_size. | ||
| // The estimate is derived from k / kKTileSize K-tiles and k / kMXFP8GroupSize scale | ||
| // groups (both integer divisions). Extra 256-rounded terms are added when A is not | ||
| // transposed and when B is transposed. | ||
| // NOTE(review): sa_pk / sb_pk look like packed scale buffers for A and B - confirm against the kernel. | ||
| static size_t compute_mxfp8_workspace_size(size_t m, size_t k, size_t n, bool transa, bool transb, size_t base_size) { | ||
|   const size_t num_k_tiles = k / kKTileSize; | ||
|   const size_t num_scale_groups = k / kMXFP8GroupSize; | ||
|   const size_t sa_pk = round_up_to_nearest_multiple(num_k_tiles * m * 4, 256); | ||
|   const size_t sb_pk = num_k_tiles * n * 4; | ||
|   // Same rounded A term is used in the base sum and in the !transa branch. | ||
|   const size_t sa_pk_rounded = round_up_to_nearest_multiple(sa_pk, 256); | ||
|   size_t total = sa_pk_rounded + sb_pk; | ||
|   if (!transa) { | ||
|     total += round_up_to_nearest_multiple(m * k, 256); | ||
|     total += round_up_to_nearest_multiple(m * num_scale_groups, 256); | ||
|     total += sa_pk_rounded; | ||
|   } | ||
|   if (transb) { | ||
|     total += round_up_to_nearest_multiple(n * k, 256); | ||
|     total += round_up_to_nearest_multiple(n * num_scale_groups, 256); | ||
|     total += round_up_to_nearest_multiple(sb_pk, 256); | ||
|   } | ||
|   return std::max(base_size, total); | ||
| } | ||
|
|
||
| struct TestParams { | ||
| size_t m; | ||
| size_t k; | ||
|
|
@@ -177,6 +193,7 @@ struct TestParams { | |
| bool transa; | ||
| bool transb; | ||
| NVTEScalingMode scaling_mode; | ||
| bool force_hipblaslt; | ||
| }; | ||
|
|
||
|
|
||
|
|
@@ -313,7 +330,7 @@ std::pair<double, double> getTestTolerances(const DType type, bool use_fp8, bool | |
| // relax for certain FP8 gemm with hipblaslt | ||
| if (use_mxfp8) { | ||
| atol = 5e-4; | ||
| rtol = std::max(rtol, 1e-3); | ||
| rtol = std::max(rtol, 5e-3); | ||
| } | ||
| else if (use_fp8) { | ||
| atol = 1e-3; | ||
|
|
@@ -340,9 +357,9 @@ void performTest(const TestParams& params) { | |
|
|
||
| const bool has_fp8 = isFp8Type(atype) || isFp8Type(btype); | ||
| const bool use_mxfp8 = params.scaling_mode == NVTEScalingMode::NVTE_MXFP8_1D_SCALING; | ||
| const bool use_hipkittens_mxfp8 = use_mxfp8 && !params.force_hipblaslt; | ||
|
|
||
| if (use_mxfp8) | ||
| { | ||
| if (use_mxfp8) { | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Add a new const bool use_hipblaslt_fp8 = (!use_mxfp8 || params.force_hipblaslt) - this combination is used below for many skips. All of this should also be placed further below, under ifdef HIP_PLATFORM_AMD, under the has_fp8 check.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wanted to avoid the skips completely, so split up the test instantiation into non-mxfp8 and mxfp8.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Nevertheless, the same condition is used multiple times below. Maybe you could instead have use_hipkittens_mxfp8 = (use_mxfp8 && !params.force_hipblaslt) for better clarity.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see what you mean now. We can combine some of the checks to make it easier to read when we are using hipkittens. |
||
| if (!has_fp8) { | ||
| GTEST_SKIP() << "MXFP8 scaling mode requires Float8 types"; | ||
| } | ||
|
|
@@ -352,6 +369,9 @@ void performTest(const TestParams& params) { | |
| if (params.k % 128) { | ||
| GTEST_SKIP() << "MXFP8 requires K to be a multiple of 128"; | ||
| } | ||
| if (use_hipkittens_mxfp8 && (params.m % 256 || params.n % 256 || params.k < 256)) { | ||
| GTEST_SKIP() << "HipKittens requires M and N 256-aligned, K >= 256"; | ||
| } | ||
| } | ||
|
|
||
| cudaDeviceProp prop; | ||
|
|
@@ -383,26 +403,18 @@ void performTest(const TestParams& params) { | |
|
|
||
| if (has_fp8) | ||
| { | ||
| bool fp8_supported = (prop.major == 9 && prop.minor >= 4) || prop.major >= 12; | ||
| const bool fp8_supported = (prop.major == 9 && prop.minor >= 4) || prop.major >= 12; | ||
| if (!fp8_supported) { | ||
| GTEST_SKIP() << "FP8 is not supported in current config"; | ||
| } | ||
|
|
||
| if (use_mxfp8) | ||
| { | ||
| bool mxfp8_supported = (prop.major == 9 && prop.minor >= 5) || prop.major >= 12; | ||
| if (!mxfp8_supported) { | ||
| GTEST_SKIP() << "MXFP8 is not supported in current config"; | ||
| } | ||
| if (isFp8Type(dtype)){ | ||
| GTEST_SKIP() << "MXFP8 with float8 output is not supported"; | ||
| } | ||
| if (params.use_bias) { | ||
| GTEST_SKIP() << "MXFP8 GEMM with bias is not supported"; | ||
| } | ||
| const bool mxfp8_supported = (prop.major == 9 && prop.minor >= 5) || prop.major >= 12; | ||
| if (use_mxfp8 && !mxfp8_supported) { | ||
| GTEST_SKIP() << "MXFP8 is not supported in current config"; | ||
| } | ||
|
|
||
| if (params.use_gelu && !fp8_gelu_fusion_config) { | ||
| if (!use_hipkittens_mxfp8 && params.use_bias) { | ||
| GTEST_SKIP() << "MXFP8 GEMM with bias is not supported by hipBLASLt"; | ||
| } | ||
| if (params.use_gelu && !fp8_gelu_fusion_config && !use_hipkittens_mxfp8) { | ||
| GTEST_SKIP() << "FP8 GEMM with GELU is not supported in current config"; | ||
| } | ||
| if (params.use_bias && dtype == DType::kFloat16) { | ||
|
|
@@ -412,29 +424,27 @@ void performTest(const TestParams& params) { | |
|
|
||
| if (prop.major == 9 && prop.minor == 5) //gfx950 specific hipblasLt limitations | ||
| { | ||
| if (isFp8Type(dtype)){ | ||
| if (isFp8Type(dtype)) { | ||
| GTEST_SKIP() << "GEMM with float8 output is not supported"; | ||
| } | ||
| if (params.use_gelu && dtype == DType::kBFloat16) { | ||
| if (params.use_gelu && dtype == DType::kBFloat16 && !use_hipkittens_mxfp8) { | ||
| GTEST_SKIP() << "BF16 GEMM with GELU is not supported in current config"; | ||
| } | ||
| if constexpr ((std::is_same<A_Type, bf8>::value || std::is_same<B_Type, bf8>::value) && | ||
| std::is_same<D_Type, fp32>::value) | ||
| { | ||
| //GEMM with bias and fp32 output is not supported with bf8 A/B | ||
| if constexpr ((std::is_same_v<A_Type, bf8> || std::is_same_v<B_Type, bf8>) && | ||
| std::is_same_v<D_Type, fp32>) { | ||
| if (params.use_bias) { | ||
| GTEST_SKIP() << "FP8 GEMM with bias is not supported in current config"; | ||
| } | ||
| } | ||
| } | ||
| if (prop.major == 9 && prop.minor == 4) //gfx942 specific hipblasLt limitations | ||
| else if (prop.major == 9 && prop.minor == 4) //gfx942 specific hipblasLt limitations | ||
| { | ||
| #if HIP_VERSION < 70100000 | ||
| if (params.use_gelu && dtype == DType::kBFloat16 && !params.transa) { | ||
| GTEST_SKIP() << "BF16 GEMM with GELU is not supported in current config"; | ||
| } | ||
| #endif | ||
| if constexpr (std::is_same<D_Type, fp8>::value && std::is_same<Bias_Type, bf16>::value) { | ||
| if constexpr (std::is_same_v<D_Type, fp8> && std::is_same_v<Bias_Type, bf16>) { | ||
| if (params.use_bias && !fp8_gelu_fusion_config) { | ||
| GTEST_SKIP() << "GEMM with BF16 bias and FP8 output is not supported in current config"; | ||
| } | ||
|
|
@@ -493,6 +503,11 @@ void performTest(const TestParams& params) { | |
| if ((prop.major == 9 && prop.minor == 5) || prop.major >= 12) { | ||
| workspace_size = 67108864; | ||
| } | ||
| if (use_mxfp8 && !use_hipkittens_mxfp8) { | ||
| workspace_size = compute_mxfp8_workspace_size(params.m, params.k, params.n, | ||
|
wangye805 marked this conversation as resolved.
ipanfilo marked this conversation as resolved.
|
||
| params.transa, params.transb, | ||
| workspace_size); | ||
| } | ||
| #endif | ||
| Tensor Workspace("Workspace", TShape{ workspace_size }, DType::kByte); | ||
|
|
||
|
|
@@ -551,7 +566,7 @@ void performTest(const TestParams& params) { | |
| compareResults("D", D, RefD.rowwise_cpu_dptr<D_Type>(), true, atol, rtol); | ||
|
|
||
| if(params.use_gelu){ | ||
| auto [atol, rtol] = getTestTolerances(gelu_type, false, false); | ||
| auto [atol, rtol] = getTestTolerances(gelu_type, has_fp8, use_mxfp8); | ||
| RefPreGeluOut.to_cpu(); | ||
| compareResults("gelu", pre_gelu_out, RefPreGeluOut.rowwise_cpu_dptr<Gelu_Type>(), true, atol, rtol); | ||
| } | ||
|
|
@@ -577,10 +592,18 @@ void performDqTest(const TestParams ¶ms) { | |
| cudaDeviceProp prop; | ||
| (void)cudaGetDeviceProperties(&prop, 0); | ||
|
|
||
| bool mxfp8_supported = (prop.major == 9 && prop.minor >= 5) || prop.major >= 12; | ||
| const bool mxfp8_supported = (prop.major == 9 && prop.minor >= 5) || prop.major >= 12; | ||
| const bool use_hipkittens_mxfp8 = !params.force_hipblaslt; | ||
|
|
||
| if (!mxfp8_supported) { | ||
| GTEST_SKIP() << "MXFP8 is not supported in current config"; | ||
| } | ||
| if (params.use_bias || params.use_gelu) { | ||
| GTEST_SKIP() << "DqGEMMTestSuite does not yet have reference for bias/gelu epilogues"; | ||
| } | ||
| if (use_hipkittens_mxfp8 && (params.m % 256 || params.n % 256 || params.k % 128 || params.k < 256)) { | ||
| GTEST_SKIP() << "HipKittens requires M and N 256-aligned, K >= 256"; | ||
| } | ||
|
|
||
| DType ref_type = dtype; | ||
| TShape a_shape = params.transa ? TShape{params.m, params.k} : TShape{params.k, params.m}; | ||
|
|
@@ -608,7 +631,9 @@ void performDqTest(const TestParams ¶ms) { | |
| Tensor bias; | ||
| Tensor pre_gelu_out; | ||
|
|
||
| size_t workspace_size = 67108864; | ||
| size_t workspace_size = compute_mxfp8_workspace_size(params.m, params.k, params.n, | ||
| params.transa, params.transb, | ||
| 67108864); // 64 MiB required for hipBLASlt | ||
| Tensor Workspace("Workspace", TShape{workspace_size}, DType::kByte); | ||
|
|
||
| //perform FP8 gemm and copy the output results from GPU memory to CPU memory | ||
|
|
@@ -638,6 +663,12 @@ void performDqTest(const TestParams ¶ms) { | |
| #endif // __HIP_PLATFORM_AMD__ | ||
|
|
||
| #define MAKE_TEST_PARAMS(P_) \ | ||
| bool force_hipblaslt_ = std::get<5>(GetParam()); \ | ||
| if (force_hipblaslt_) { \ | ||
| setenv("NVTE_ROCM_USE_HIPBLASLT_MXFP8", "1", 1); \ | ||
| } else { \ | ||
| unsetenv("NVTE_ROCM_USE_HIPBLASLT_MXFP8"); \ | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Better to make it explicit: set it to 0, so that setenv is always called, just with a different value. |
||
| } \ | ||
| TestParams P_ = {.m = std::get<0>(std::get<0>(GetParam())), \ | ||
| .k = std::get<1>(std::get<0>(GetParam())), \ | ||
| .n = std::get<2>(std::get<0>(GetParam())), \ | ||
|
|
@@ -646,13 +677,14 @@ void performDqTest(const TestParams ¶ms) { | |
| .transa = std::get<3>(GetParam()).first, \ | ||
| .transb = std::get<3>(GetParam()).second, \ | ||
| .scaling_mode = std::get<4>(GetParam()) \ | ||
| ? NVTEScalingMode::NVTE_MXFP8_1D_SCALING \ | ||
| : NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING} | ||
| ? NVTEScalingMode::NVTE_MXFP8_1D_SCALING \ | ||
| : NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING,\ | ||
| .force_hipblaslt = force_hipblaslt_} | ||
|
|
||
| // <m, k, n>, use_bias, use_gelu, Layout, fp8_scalinig | ||
| // <m, k, n>, use_bias, use_gelu, Layout, fp8_scaling, force_hipblaslt | ||
| class GEMMTestSuite | ||
| : public ::testing::TestWithParam< | ||
| std::tuple<std::tuple<size_t, size_t, size_t>, bool, bool, Layout, NVTEScalingMode>> {}; | ||
| std::tuple<std::tuple<size_t, size_t, size_t>, bool, bool, Layout, NVTEScalingMode, bool>> {}; | ||
|
|
||
| #define MAKE_GEMM_TEST(NAME_, A_, B_, BIAS_, GELU_, D_) \ | ||
| TEST_P(GEMMTestSuite, NAME_) { \ | ||
|
|
@@ -713,19 +745,32 @@ static inline auto MKN(const std::tuple<size_t, size_t, size_t>& shape) { | |
| std::to_string(std::get<2>(shape)); | ||
| } | ||
|
|
||
| // Builds the display name of one GEMMTestSuite parameter set by joining, with "x": | ||
| // MKN(shape), the two boolean flags printed through std::to_string, TN(layout), | ||
| // "M" or "S" for tuple element 4, and "HB" or "HK" for tuple element 5. | ||
| static std::string GEMMTestName(const testing::TestParamInfo<GEMMTestSuite::ParamType>& info) { | ||
|   const auto& p = info.param; | ||
|   std::string name = MKN(std::get<0>(p)); | ||
|   name += "x"; | ||
|   name += std::to_string(std::get<1>(p)); | ||
|   name += "x"; | ||
|   name += std::to_string(std::get<2>(p)); | ||
|   name += "x"; | ||
|   name += TN(std::get<3>(p)); | ||
|   name += "x"; | ||
|   name += (std::get<4>(p) ? "M" : "S"); | ||
|   name += "x"; | ||
|   name += (std::get<5>(p) ? "HB" : "HK"); | ||
|   return name; | ||
| } | ||
|
|
||
| INSTANTIATE_TEST_SUITE_P(OperatorTest, GEMMTestSuite, | ||
| ::testing::Combine(::testing::ValuesIn(test_case_sizes), | ||
| ::testing::Values(false, true), //use bias | ||
| ::testing::Values(false, true), //use_gelu | ||
| ::testing::ValuesIn(kLayouts), //transa,transb | ||
| ::testing::Values(false, true)), //use mxfp8 | ||
| [](const testing::TestParamInfo<GEMMTestSuite::ParamType>& info) { | ||
| return MKN(std::get<0>(info.param)) + "x" + | ||
| std::to_string(std::get<1>(info.param)) + "x" + | ||
| std::to_string(std::get<2>(info.param)) + "x" + | ||
| TN(std::get<3>(info.param)) + "x" + | ||
| (std::get<4>(info.param) ? "M" : "S"); | ||
| }); | ||
| ::testing::Values(false), //use mxfp8 | ||
| ::testing::Values(false)), //force hipblaslt | ||
| GEMMTestName); | ||
|
|
||
| INSTANTIATE_TEST_SUITE_P(OperatorTestMXFP8, GEMMTestSuite, | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. With this approach we'll have OperatorTestMXFP8 with Testfp32xfp32xfp32xfp32xfp32. So if you want to use separate suites to shape the tests more precisely, there should instead be 3 prefixes: GEMMTest, FP8GEMMTest and MXFP8GEMMTest, where the last two use the same suite. There should also be two sets of MAKE_GEMM_TEST: one for FP8 and one for non-FP8. |
||
| ::testing::Combine(::testing::ValuesIn(test_case_sizes), | ||
| ::testing::Values(false, true), //use bias | ||
| ::testing::Values(false, true), //use_gelu | ||
| ::testing::ValuesIn(kLayouts), //transa,transb | ||
| ::testing::Values(true), //use mxfp8 | ||
| ::testing::Values(false, true)), //force hipblaslt | ||
| GEMMTestName); | ||
|
|
||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| class DqGEMMTestSuite: public GEMMTestSuite {}; | ||
|
|
@@ -741,14 +786,17 @@ class DqGEMMTestSuite: public GEMMTestSuite {}; | |
|
|
||
| MAKE_DQ_GEMM_TEST(Testfp8xfp8xfp16, fp8, fp8, fp16) | ||
|
|
||
| INSTANTIATE_TEST_SUITE_P(OperatorTest, DqGEMMTestSuite, | ||
| INSTANTIATE_TEST_SUITE_P(OperatorTestMXFP8, DqGEMMTestSuite, | ||
| ::testing::Combine(::testing::ValuesIn(test_case_sizes_mxfp8), | ||
| ::testing::Values(false), // bias - unused | ||
| ::testing::Values(false), // gelu - unused | ||
| ::testing::ValuesIn(kLayouts), //transa,transb | ||
| ::testing::Values(true)), //use mxfp8 | ||
| ::testing::Values(false), // use bias | ||
| ::testing::Values(false), // use gelu | ||
| ::testing::ValuesIn(kLayouts), // transa,transb | ||
| ::testing::Values(true), // use mxfp8 | ||
| ::testing::Values(false, true)), // force hipblaslt | ||
| [](const testing::TestParamInfo<DqGEMMTestSuite::ParamType>& info) { | ||
| return MKN(std::get<0>(info.param)) + "x" + TN(std::get<3>(info.param)); | ||
| return MKN(std::get<0>(info.param)) + "x" + | ||
| TN(std::get<3>(info.param)) + "x" + | ||
| (std::get<5>(info.param) ? "HB" : "HK"); | ||
| }); | ||
|
|
||
| TEST(InputGenTest, FillUniform_DoesNotGetOverwrittenByFromCpu) { | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -61,8 +61,9 @@ def _check_mxfp8_gemm_support(with_jax_gemm, m, n, k, use_bias=False): | |
| pytest.skip( | ||
| f"Input shape {(m, k)} x {(k, n)} is not supported by hipblaslt MXFP8 GEMM." | ||
| ) | ||
| if use_bias: | ||
| pytest.skip("hipblaslt GEMM does not yet support MXFP8 with bias.") | ||
| hipkittens_eligible = (m % 256 == 0) and (n % 256 == 0) and (k >= 256) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Same hardcoded 256s here... |
||
| if use_bias and not hipkittens_eligible: | ||
| pytest.skip("hipblaslt GEMM does not support MXFP8 with bias.") | ||
| else: | ||
| jax_version = version.parse(jax.__version__) | ||
| if jax_version < version.parse("0.8.2"): | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
test_case_sizes_mxfp8 is only used for DqGEMMTest; is it intentional to add sizes there?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I wanted to add the minimum possible size that hipKittens supports, which is 256x256x256