HipKittens MXFP8 GEMM Support #566

Open
alextmagro wants to merge 26 commits into dev from hipkittens_mxfp8

@alextmagro
Contributor

@alextmagro alextmagro commented Apr 28, 2026

Creates an MXFP8 GEMM with HipKittens that outperforms hipBLASlt and offers additional epilogues such as BIAS and GELU AUX.

Requires a workspace sized relative to the model. The workspace is often larger than the one hipBLASlt needs, but comes with significant performance improvements. Only builds for gfx950, and requires M and N to be multiples of 256.

Adds hipKittens header library as a submodule.

Comment thread tests/cpp/operator/test_cublaslt_gemm.cu Outdated
Comment thread tests/cpp/operator/test_cublaslt_gemm.cu Outdated
Comment thread tests/cpp/operator/test_cublaslt_gemm.cu Outdated
Comment thread tests/cpp/operator/test_cublaslt_gemm.cu
Comment thread tests/jax/utils.py
)
if use_bias:
    pytest.skip("hipblaslt GEMM does not yet support MXFP8 with bias.")
hipkittens_eligible = (m % 256 == 0) and (n % 256 == 0) and (k >= 256)
Collaborator

Same hardcoded 256s here...
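
One way to address the repeated literal is to name the tile size once and derive the eligibility check from it. The sketch below is only an illustration, assuming the condition shown in the test excerpt above (M and N multiples of 256, K at least 256); `HIPKITTENS_TILE` and `is_hipkittens_eligible` are hypothetical names, not identifiers from this PR.

```python
# Hypothetical constant: the 256 granularity used by the checks in the test excerpt.
HIPKITTENS_TILE = 256


def is_hipkittens_eligible(m: int, n: int, k: int, tile: int = HIPKITTENS_TILE) -> bool:
    """Mirror the test's condition: M and N divisible by the tile, K at least one tile."""
    return (m % tile == 0) and (n % tile == 0) and (k >= tile)


if __name__ == "__main__":
    # Print the decision for a few illustrative shapes.
    for shape in [(256, 256, 256), (512, 384, 256), (256, 256, 128)]:
        print(shape, is_hipkittens_eligible(*shape))
```

With a single named constant, the test skip logic and any workspace sizing code can share one definition instead of repeating the literal.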

Comment thread tests/cpp/operator/test_cublaslt_gemm.cu Outdated
Comment thread transformer_engine/common/gemm/kittens/mxfp8_gemm.hip Outdated
Comment thread transformer_engine/common/gemm/kittens/mxfp8_gemm.cpp
Comment thread transformer_engine/pytorch/cpp_extensions/gemm.py Outdated
Comment thread transformer_engine/pytorch/cpp_extensions/gemm.py
@alextmagro alextmagro requested a review from wangye805 May 5, 2026 20:26
Comment thread transformer_engine/common/gemm/kittens/mxfp8_gemm.cpp
Comment thread transformer_engine/common/gemm/kittens/mxfp8_gemm.cpp
Comment thread transformer_engine/common/gemm/kittens/mxfp8_gemm.cpp
Comment thread transformer_engine/common/gemm/kittens/mxfp8_gemm.cpp
if (!use_mxfp8 && params.force_hipblaslt) {
GTEST_SKIP() << "force_hipblaslt only relevant for MXFP8";
}
if (use_mxfp8) {
Collaborator

Add a new const bool use_hipblaslt_fp8 = (!use_mxfp8 || params.force_hipblaslt); this combination is used below for many skips. All of this should also be moved further down, under ifdef HIP_PLATFORM_AMD, under has_fp8.

Contributor Author

I wanted to avoid the skips completely, so split up the test instantiation into non-mxfp8 and mxfp8.

Collaborator

Nevertheless, the same condition is used multiple times below. Maybe you could instead have use_hipkittens_mxfp8 = (use_mxfp8 && !params.force_hipblaslt) for better clarity.

Contributor Author

I see what you mean now. We can combine some of the checks to make it easier to read when we are using hipkittens.
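
The two flags proposed in this thread are logical complements of each other (De Morgan), so either one is enough to drive the skips. The sketch below checks that over all four input combinations. It is written in Python for brevity although the test itself is C++, and `backend_flags` is a hypothetical helper, not code from this PR.

```python
from itertools import product
from typing import Tuple


def backend_flags(use_mxfp8: bool, force_hipblaslt: bool) -> Tuple[bool, bool]:
    """Return (use_hipkittens_mxfp8, use_hipblaslt_fp8) as named in the review thread."""
    use_hipkittens_mxfp8 = use_mxfp8 and not force_hipblaslt
    use_hipblaslt_fp8 = (not use_mxfp8) or force_hipblaslt
    return use_hipkittens_mxfp8, use_hipblaslt_fp8


if __name__ == "__main__":
    # Exhaustive truth table: exactly one of the two flags is set in every case.
    for use_mxfp8, force_hipblaslt in product([False, True], repeat=2):
        hipkittens, hipblaslt = backend_flags(use_mxfp8, force_hipblaslt)
        assert hipkittens != hipblaslt
        print(use_mxfp8, force_hipblaslt, hipkittens, hipblaslt)
```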

Comment thread tests/cpp/operator/test_cublaslt_gemm.cu
Comment thread tests/cpp/operator/test_cublaslt_gemm.cu Outdated
[](const testing::TestParamInfo<DqGEMMTestSuite::ParamType>& info) {
return MKN(std::get<0>(info.param)) + "x" + TN(std::get<3>(info.param));
return MKN(std::get<0>(info.param)) + "x" +
std::to_string(std::get<1>(info.param)) + "x" +
Collaborator

What is the point? They are only ever set to false.

GTEST_SKIP() << "MXFP8 is not supported in current config";
}
if (params.use_bias || params.use_gelu) {
if (params.force_hipblaslt) {
Collaborator

It is skipped below anyway. If you are adding it for future use, move it after the more generic check.

Contributor Author

Sorry, this and the Dq test name changes are artifacts from my attempt to enable bias and gelu for this test. I ran into issues with gelu for the non-fp8 GEMM in hipBLASlt, and decided to just focus on the non-Dq tests. I have reverted things.

#include <hip/hip_runtime.h>
#include <cstddef>

enum KittensDType {
Collaborator

Is it copied from some hipKittens enum? If so, add a comment.

Contributor Author

These values come from the NVTE values. I have added a comment to that effect.

Collaborator

And where are they used?

Contributor Author

They are used at L735-750 in mxfp8_gemm.cpp. I have updated those functions to be a bit more defensive, too.

Comment thread transformer_engine/common/gemm/kittens/mxfp8_gemm.cpp
Comment thread transformer_engine/common/CMakeLists.txt Outdated

return torch.empty(get_cublas_workspace_size_bytes(), dtype=torch.uint8, device=device)
key = (device, ub, grouped_gemm)
ws = _workspace_cache.get(key)
Collaborator

Why don't we rely on torch's memory caching?

Contributor Author

I have made this change. I will need to do an E2E run to make sure that performance isn't affected, but it should be OK given my understanding of torch.empty().

@alextmagro alextmagro requested review from aris134 and ipanfilo May 12, 2026 13:24
Comment thread transformer_engine/common/gemm/kittens/mxfp8_gemm.cpp
size_t sa_tr_bytes = align_up((size_t)M * scale_K, 256);
size_t sb_tr_bytes = align_up((size_t)N * scale_K, 256);
size_t sa_pk_bytes = align_up((size_t)k_iters * M * sizeof(uint32_t), 256);
size_t sb_pk_bytes = (size_t)k_iters * N * sizeof(uint32_t);
Contributor

For my own understanding, can you explain why sb_pk_bytes does not require 256-alignment like the others?

Contributor Author

@alextmagro alextmagro May 13, 2026

Here, we are aligning the end of each buffer so that the next buffer's start address is 256-byte aligned, not the current one. Since the sb_pk buffer is the last one, it doesn't need padding.
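
The padding scheme described in this reply can be sketched as follows. This is an illustration under stated assumptions: the four buffers are placed back-to-back in the order listed in the excerpt, the workspace base address is itself 256-byte aligned, and `scale_k` and `k_iters` are passed in because their definitions are not shown here. `align_up` follows the name in the excerpt; `scale_workspace_layout` is a hypothetical helper.

```python
from typing import Dict, Tuple

UINT32_BYTES = 4  # sizeof(uint32_t)


def align_up(value: int, alignment: int) -> int:
    """Round value up to the next multiple of alignment."""
    return (value + alignment - 1) // alignment * alignment


def scale_workspace_layout(
    m: int, n: int, scale_k: int, k_iters: int, alignment: int = 256
) -> Tuple[Dict[str, int], int]:
    """Return (byte offset of each buffer, total bytes) for the assumed layout."""
    sizes = [
        ("sa_tr", align_up(m * scale_k, alignment)),
        ("sb_tr", align_up(n * scale_k, alignment)),
        ("sa_pk", align_up(k_iters * m * UINT32_BYTES, alignment)),
        # Last buffer: nothing follows it, so its size is left unpadded.
        ("sb_pk", k_iters * n * UINT32_BYTES),
    ]
    offsets: Dict[str, int] = {}
    cursor = 0
    for name, nbytes in sizes:
        offsets[name] = cursor  # start of this buffer
        cursor += nbytes        # padded size keeps the next start aligned
    return offsets, cursor


if __name__ == "__main__":
    # Deliberately awkward sizes so that the padding is visible.
    offsets, total = scale_workspace_layout(m=100, n=60, scale_k=3, k_iters=1)
    print(offsets, total)
```

Padding every size except the last keeps each start offset a multiple of 256 while avoiding unnecessary bytes at the tail of the workspace.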

Comment thread transformer_engine/common/gemm/kittens/mxfp8_gemm.cpp
Comment thread transformer_engine/common/gemm/kittens/mxfp8_gemm.cpp Outdated
Comment on lines 15 to 17

namespace transformer_engine {
namespace jax {
Contributor

Nit: there are a few whitespace-only changes in these files, not sure if they are necessary.

Contributor Author

I have removed this, thanks

Contributor

@aris134 aris134 left a comment

LGTM!

Comment thread build_tools/jax.py Outdated
Path(__file__).resolve().parent.parent
/ "3rdparty" / "hipkittens" / "include" / "kittens.cuh"
)
if "gfx950" in rocm_archs and hipkittens_header.exists():
Collaborator

PyTorch/JAX extensions do not contain any GPU code; they delegate all of it to TE core, and kittens is added to TE common too.
Why is this build-time setting needed?

Contributor Author

This is an artifact from when I was running into issues with CI not finding pybinded functions from hipKittens. The issue was elsewhere, and I forgot to remove this. I will remove it, thanks!

@alextmagro alextmagro requested a review from ipanfilo May 14, 2026 17:18
@alextmagro alextmagro added the ci-level 3 (CI test level 3) label and removed the ci-level 1 (CI test level 1) label May 14, 2026
Comment thread transformer_engine/jax/cpp_extensions/gemm.py
NVTE_CHECK((k % 128) == 0, "GEMM K dimension must be multiple of 128 for MXFP8 scaling (got K=", k, ")");
NVTE_CHECK((m % 16) == 0, "GEMM M dimension must be multiple of 16 for MXFP8 scaling (got M=", m, ")");
NVTE_CHECK((n % 16) == 0, "GEMM N dimension must be multiple of 16 for MXFP8 scaling (got N=", n, ")");
NVTE_CHECK((m % 16) == 0, "GEMM M dimension must be multiple of 16 for MXFP8 scaling (got M=", m, ")");
Collaborator

This looks like just a spacing change. Please revert it if that is the case.

transb, grad, workspace, workspaceSize, alpha, beta, use_split_accumulator,
math_sm_count, use_service_stream ? ss_ctl.stream : stream, handle);
#ifdef USE_HIPKITTENS_GEMM
bool is_mxfp8 = inputA->scaling_mode == NVTE_MXFP8_1D_SCALING
Collaborator

Move it out of the ifdef and use it in the ifs that currently check the same condition.

NVTE_CHECK((n % 16) == 0, "GEMM N dimension must be multiple of 16 for MXFP8 scaling (got N=", n, ")");
NVTE_CHECK((m % 16) == 0, "GEMM M dimension must be multiple of 16 for MXFP8 scaling (got M=", m, ")");
NVTE_CHECK((n % 16) == 0, "GEMM N dimension must be multiple of 16 for MXFP8 scaling (got N=", n, ")");
#ifndef USE_HIPKITTENS_GEMM
Collaborator

It is checked below in the else branch of the hipkittens condition.

if (use_hipkittens) {
auto param = CanonicalizeGemmInput(*inputA, transa, *inputB, transb, m, n, k);

hipStream_t s = use_service_stream ? ss_ctl.stream : stream;
Collaborator

Same as with is_mxfp8: there is no point in having it defined for one branch only.

}

auto [atol, rtol] = getTestTolerances(dtype, has_fp8, use_mxfp8);
size_t mismatch_limit = use_mxfp8 ? std::max((size_t)1, params.m * params.n / 1'000'000) : 0;
Collaborator

Unused variable

@@ -743,12 +786,15 @@ MAKE_DQ_GEMM_TEST(Testfp8xfp8xfp16, fp8, fp8, fp16)

INSTANTIATE_TEST_SUITE_P(OperatorTest, DqGEMMTestSuite,
Collaborator

If you end up having a separate prefix for MXFP8, it has to be used for this suite too, for consistency.

Contributor Author

Done.

@@ -30,7 +30,9 @@ std::vector<std::tuple<size_t, size_t, size_t>> test_case_sizes = {

std::vector<std::tuple<size_t, size_t, size_t>> test_case_sizes_mxfp8 = {
Collaborator

test_case_sizes_mxfp8 is only used for DqGEMMTest. Is it intentional to add sizes there?

Contributor Author

Yes, I wanted to add the minimum possible size that hipKittens supports, which is 256x256x256

num_cublas_streams = get_num_compute_streams()


def _hipkittens_workspace_bytes(m: int, n: int, k: int, layout: str) -> int:
Collaborator

Should it check the environment to figure out whether hipKittens is enabled?

Contributor Author

Yes, I had the check for pytorch but forgot it for JAX, that is fixed now.


is_mxfp8 = isinstance(A, MXFP8TensorStorage) or isinstance(B, MXFP8TensorStorage)
if is_mxfp8 and _use_hipkittens():
    a_size = A.size() if hasattr(A, "size") and callable(A.size) else A.shape
Collaborator

MXFP8TensorStorage has a callable size(). What other object could be here that would require this condition?

Contributor Author

I was considering a scenario where A or B was not MXFP8, but we always have them both as MXFP8 so I think it is ok to simplify the logic

@alextmagro alextmagro requested a review from ipanfilo May 18, 2026 20:43

Labels

ci-level 3 CI test level 3

5 participants