HipKittens MXFP8 GEMM Support #566
| ) | ||
| if use_bias: | ||
| pytest.skip("hipblaslt GEMM does not yet support MXFP8 with bias.") | ||
| hipkittens_eligible = (m % 256 == 0) and (n % 256 == 0) and (k >= 256) |
| if (!use_mxfp8 && params.force_hipblaslt) { | ||
| GTEST_SKIP() << "force_hipblaslt only relevant for MXFP8"; | ||
| } | ||
| if (use_mxfp8) { |
There was a problem hiding this comment.
Add a new const bool use_hipblaslt_fp8 = (!use_mxfp8 || params.force_hipblaslt); this combination is used for many of the skips below. All of this should also move further down, under the HIP_PLATFORM_AMD ifdef, under has_fp8.
There was a problem hiding this comment.
I wanted to avoid the skips completely, so split up the test instantiation into non-mxfp8 and mxfp8.
There was a problem hiding this comment.
Nevertheless, the same condition is used multiple times below. Maybe you could instead define use_hipkittens_mxfp8 = (use_mxfp8 && !params.force_hipblaslt) for better clarity.
There was a problem hiding this comment.
I see what you mean now. We can combine some of the checks to make it easier to read when we are using hipkittens.
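The two flags proposed in this thread are logical complements of each other (De Morgan), so either one is enough to drive the skips. A minimal Python sketch of that relationship follows; the tests themselves are C++, the flag names mirror the ones in the comments above, and the helper `backend_flags` is hypothetical and only exists for illustration.

```python
from itertools import product


def backend_flags(use_mxfp8: bool, force_hipblaslt: bool) -> tuple[bool, bool]:
    """Return (use_hipblaslt_fp8, use_hipkittens_mxfp8) as suggested in the review.

    Hypothetical helper: it only restates the two boolean expressions
    from the comments; it is not code from the PR.
    """
    use_hipblaslt_fp8 = (not use_mxfp8) or force_hipblaslt
    use_hipkittens_mxfp8 = use_mxfp8 and (not force_hipblaslt)
    return use_hipblaslt_fp8, use_hipkittens_mxfp8


# Exactly one of the two flags is set for every input combination.
for use_mxfp8, force_hipblaslt in product([False, True], repeat=2):
    hipblaslt, hipkittens = backend_flags(use_mxfp8, force_hipblaslt)
    assert hipblaslt != hipkittens
```

Because the flags are exact negations, keeping only use_hipkittens_mxfp8 and writing the hipBLASlt-side checks as its negation would avoid two names drifting apart.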
| [](const testing::TestParamInfo<DqGEMMTestSuite::ParamType>& info) { | ||
| return MKN(std::get<0>(info.param)) + "x" + TN(std::get<3>(info.param)); | ||
| return MKN(std::get<0>(info.param)) + "x" + | ||
| std::to_string(std::get<1>(info.param)) + "x" + |
There was a problem hiding this comment.
What is the point of these? They are only ever set to false.
| GTEST_SKIP() << "MXFP8 is not supported in current config"; | ||
| } | ||
| if (params.use_bias || params.use_gelu) { | ||
| if (params.force_hipblaslt) { |
There was a problem hiding this comment.
It is skipped below anyway. If you are adding it for future use, move it after the more generic one.
There was a problem hiding this comment.
Sorry, this and the Dq test name changes are artifacts from my attempt to enable bias and gelu for this test. I ran into issues with gelu for the non-fp8 GEMM in hipBLASlt, and decided to just focus on the non-Dq tests. I have reverted things.
| #include <hip/hip_runtime.h> | ||
| #include <cstddef> | ||
|
|
||
| enum KittensDType { |
There was a problem hiding this comment.
Is it copied from some hipKittens enum? If so, add a comment saying that.
There was a problem hiding this comment.
These values come from the NVTE values. I have added a comment to that effect.
There was a problem hiding this comment.
And where are they used?
There was a problem hiding this comment.
They are used at L735-750 in mxfp8_gemm.cpp. I have updated those functions to be a bit more defensive, too.
|
|
||
| return torch.empty(get_cublas_workspace_size_bytes(), dtype=torch.uint8, device=device) | ||
| key = (device, ub, grouped_gemm) | ||
| ws = _workspace_cache.get(key) |
There was a problem hiding this comment.
Why don't we rely on torch memory caching?
There was a problem hiding this comment.
I have made this change. I will need to do an E2E run to make sure that performance isn't affected, but it should be ok given my understanding of torch.empty().
| size_t sa_tr_bytes = align_up((size_t)M * scale_K, 256); | ||
| size_t sb_tr_bytes = align_up((size_t)N * scale_K, 256); | ||
| size_t sa_pk_bytes = align_up((size_t)k_iters * M * sizeof(uint32_t), 256); | ||
| size_t sb_pk_bytes = (size_t)k_iters * N * sizeof(uint32_t); |
There was a problem hiding this comment.
For my own understanding, can you explain why sb_pk_bytes does not require 256-alignment like the others?
There was a problem hiding this comment.
Here, we are rounding up the size of each buffer so that the next buffer starts at a 256-aligned address; the padding is for the following address, not the current one. Since sb_pk is the last buffer, nothing follows it, so we don't need to pad sb_pk_bytes.
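The reply above can be sketched as a small offset calculation. This is a Python illustration only, assuming the four buffers are packed back to back in the order shown in the diff and that the workspace base is itself 256-aligned; `workspace_layout` and the example values for scale_k and k_iters are hypothetical, and the factor 4 stands in for sizeof(uint32_t).

```python
def align_up(value: int, alignment: int) -> int:
    """Round value up to the next multiple of alignment."""
    return (value + alignment - 1) // alignment * alignment


def workspace_layout(m: int, n: int, scale_k: int, k_iters: int, alignment: int = 256):
    """Return (offsets, total_bytes) for the four scale buffers.

    Hypothetical helper: every size except the last is padded so that
    the buffer placed after it starts on an aligned offset.
    """
    sizes = [
        ("sa_tr", align_up(m * scale_k, alignment)),
        ("sb_tr", align_up(n * scale_k, alignment)),
        ("sa_pk", align_up(k_iters * m * 4, alignment)),
        ("sb_pk", k_iters * n * 4),  # last buffer: nothing follows, no padding
    ]
    offsets = {}
    cursor = 0
    for name, size in sizes:
        offsets[name] = cursor
        cursor += size
    return offsets, cursor


offsets, total = workspace_layout(m=300, n=200, scale_k=8, k_iters=2)
# Every buffer, including the unpadded last one, starts 256-aligned.
assert all(offset % 256 == 0 for offset in offsets.values())
```

With these example values the total size is not a multiple of 256, which is fine as long as nothing is placed after sb_pk in the same workspace.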
|
|
||
| namespace transformer_engine { | ||
| namespace jax { |
There was a problem hiding this comment.
Nit: there are a few whitespace-only changes in these files, not sure if they are necessary.
There was a problem hiding this comment.
I have removed this, thanks
| Path(__file__).resolve().parent.parent | ||
| / "3rdparty" / "hipkittens" / "include" / "kittens.cuh" | ||
| ) | ||
| if "gfx950" in rocm_archs and hipkittens_header.exists(): |
There was a problem hiding this comment.
The PyTorch/JAX extensions do not contain any GPU code; they delegate all of this to TE core. The kittens are added to TE common too.
Why is this build-time setting needed?
There was a problem hiding this comment.
This is an artifact from when I was running into issues with CI not finding pybinded functions from hipKittens. The issue was elsewhere, and I forgot to remove this. I will remove it, thanks!
| NVTE_CHECK((k % 128) == 0, "GEMM K dimension must be multiple of 128 for MXFP8 scaling (got K=", k, ")"); | ||
| NVTE_CHECK((m % 16) == 0, "GEMM M dimension must be multiple of 16 for MXFP8 scaling (got M=", m, ")"); | ||
| NVTE_CHECK((n % 16) == 0, "GEMM N dimension must be multiple of 16 for MXFP8 scaling (got N=", n, ")"); | ||
| NVTE_CHECK((m % 16) == 0, "GEMM M dimension must be multiple of 16 for MXFP8 scaling (got M=", m, ")"); |
There was a problem hiding this comment.
This looks like just a spacing change. Please revert it if that is the case.
| transb, grad, workspace, workspaceSize, alpha, beta, use_split_accumulator, | ||
| math_sm_count, use_service_stream ? ss_ctl.stream : stream, handle); | ||
| #ifdef USE_HIPKITTENS_GEMM | ||
| bool is_mxfp8 = inputA->scaling_mode == NVTE_MXFP8_1D_SCALING |
There was a problem hiding this comment.
Move it out of the ifdef and use it in the ifs that currently check the same condition.
| NVTE_CHECK((n % 16) == 0, "GEMM N dimension must be multiple of 16 for MXFP8 scaling (got N=", n, ")"); | ||
| NVTE_CHECK((m % 16) == 0, "GEMM M dimension must be multiple of 16 for MXFP8 scaling (got M=", m, ")"); | ||
| NVTE_CHECK((n % 16) == 0, "GEMM N dimension must be multiple of 16 for MXFP8 scaling (got N=", n, ")"); | ||
| #ifndef USE_HIPKITTENS_GEMM |
There was a problem hiding this comment.
This is checked below, in the else branch of the hipkittens condition.
| if (use_hipkittens) { | ||
| auto param = CanonicalizeGemmInput(*inputA, transa, *inputB, transb, m, n, k); | ||
|
|
||
| hipStream_t s = use_service_stream ? ss_ctl.stream : stream; |
There was a problem hiding this comment.
Same as with is_mxfp8: there is no point in defining it for one branch only.
| } | ||
|
|
||
| auto [atol, rtol] = getTestTolerances(dtype, has_fp8, use_mxfp8); | ||
| size_t mismatch_limit = use_mxfp8 ? std::max((size_t)1, params.m * params.n / 1'000'000) : 0; |
| @@ -743,12 +786,15 @@ MAKE_DQ_GEMM_TEST(Testfp8xfp8xfp16, fp8, fp8, fp16) | |||
|
|
|||
| INSTANTIATE_TEST_SUITE_P(OperatorTest, DqGEMMTestSuite, | |||
There was a problem hiding this comment.
If you end up having a separate prefix for MXFP8, it has to be used for this suite too, for consistency.
| @@ -30,7 +30,9 @@ std::vector<std::tuple<size_t, size_t, size_t>> test_case_sizes = { | |||
|
|
|||
| std::vector<std::tuple<size_t, size_t, size_t>> test_case_sizes_mxfp8 = { | |||
There was a problem hiding this comment.
test_case_sizes_mxfp8 is only used for DqGEMMTest. Is it intentional to add sizes there?
There was a problem hiding this comment.
Yes, I wanted to add the minimum possible size that hipKittens supports, which is 256x256x256
| num_cublas_streams = get_num_compute_streams() | ||
|
|
||
|
|
||
| def _hipkittens_workspace_bytes(m: int, n: int, k: int, layout: str) -> int: |
There was a problem hiding this comment.
Should it check the env to figure out whether hipKittens is enabled?
There was a problem hiding this comment.
Yes, I had the check for PyTorch but forgot it for JAX. That is fixed now.
|
|
||
| is_mxfp8 = isinstance(A, MXFP8TensorStorage) or isinstance(B, MXFP8TensorStorage) | ||
| if is_mxfp8 and _use_hipkittens(): | ||
| a_size = A.size() if hasattr(A, "size") and callable(A.size) else A.shape |
There was a problem hiding this comment.
MXFP8TensorStorage has a callable size(). What other object could appear here that requires this condition?
There was a problem hiding this comment.
I was considering a scenario where A or B was not MXFP8, but we always have them both as MXFP8, so I think it is ok to simplify the logic.
Adds an MXFP8 GEMM built on HipKittens that outperforms hipBLASlt and offers additional epilogues such as BIAS and GELU AUX.
Requires a workspace sized relative to the model; this is often larger than the hipBLASlt workspace, but comes with significant performance improvements. Only builds for gfx950, and requires M and N to be multiples of 256.
Adds the hipKittens header library as a submodule.
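The shape constraint can be written as a small predicate. The sketch below only restates the hipkittens_eligible expression from the test snippet quoted earlier in this conversation; the function form is an assumption for illustration, and the dispatch logic inside the library may apply further checks (for example the K % 128 check quoted above, or the gfx950 build requirement).

```python
def hipkittens_eligible(m: int, n: int, k: int) -> bool:
    """Shape check mirroring the test expression:
    (m % 256 == 0) and (n % 256 == 0) and (k >= 256).

    Illustrative only; not the library's actual dispatch code.
    """
    return m % 256 == 0 and n % 256 == 0 and k >= 256


# 256x256x256 is described in the review as the minimum supported size.
print(hipkittens_eligible(256, 256, 256))
# N is not a multiple of 256, so this shape would not qualify.
print(hipkittens_eligible(256, 128, 256))
```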