Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
f9d5ce2
HipKittens MXFP8 GEMM Support
alextmagro Apr 28, 2026
aac5860
Update HipKittens branch after upstream MXFP8 merge
alextmagro May 5, 2026
c917ed0
Merge remote-tracking branch 'origin/dev' into hipkittens_mxfp8
alextmagro May 5, 2026
3a91321
Update HipKittens commit and address PR comments
alextmagro May 5, 2026
cc719fe
Merge remote-tracking branch 'origin/dev' into hipkittens_mxfp8 with
alextmagro May 5, 2026
fcda154
Resolve conflicts, ensure fp4 workspace changes are harmonious
alextmagro May 5, 2026
70fba6d
min workspace size guaranteed
alextmagro May 5, 2026
455002e
add hipkittens to wheels
alextmagro May 5, 2026
ba60ef5
fix issue with gfx942 for unified build
alextmagro May 6, 2026
f72b7b8
Cleanup and workspace changes
alextmagro May 12, 2026
731640a
Merge remote-tracking branch 'origin/dev' into hipkittens_mxfp8
alextmagro May 12, 2026
1960c06
fix jax import issue
alextmagro May 12, 2026
320152e
Fix autotuning bug
alextmagro May 12, 2026
a280cf7
fix pytorch import
alextmagro May 13, 2026
2a27902
Revert workspace changes to avoid sizing race condition
alextmagro May 13, 2026
3d7aaf9
Revert C++ workspace change to Python
alextmagro May 13, 2026
824841d
Cleanup style and build_tools relics
alextmagro May 14, 2026
f66f77c
Fix whitespaces and comment issues
alextmagro May 14, 2026
0b6e702
Kernel optimizations
alextmagro May 18, 2026
816c752
Add use_hipkittens_mxfp8 bool to test_cublaslt_gemm.cu
alextmagro May 18, 2026
aaa88d7
rocm_gemm.cu cleanup
alextmagro May 18, 2026
e2203c0
Add env check to jax file
alextmagro May 18, 2026
7648594
Simplify Workspace Check
alextmagro May 18, 2026
03f675b
Revert kernel optimizations
alextmagro May 18, 2026
f852c22
Merge remote-tracking branch 'origin/dev' into hipkittens_mxfp8
alextmagro May 19, 2026
3b307bb
Readd dropped test code
alextmagro May 19, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/rocm-wheels-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ jobs:
3rdparty/aotriton \
3rdparty/aiter \
3rdparty/QoLA \
3rdparty/hipify_torch
3rdparty/hipify_torch \
3rdparty/hipkittens

- name: Derive Docker image tag
id: set-tag
Expand Down
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,6 @@
[submodule "3rdparty/QoLA"]
path = 3rdparty/QoLA
url = https://github.com/Micky774/QoLA.git
[submodule "3rdparty/hipkittens"]
path = 3rdparty/hipkittens
url = https://github.com/HazyResearch/HipKittens.git
1 change: 1 addition & 0 deletions 3rdparty/hipkittens
Submodule hipkittens added at 778274
1 change: 1 addition & 0 deletions ci/pytorch.sh
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ run_test_config(){
run 1 test_jit.py
NVTE_ROCM_ENABLE_MXFP8=1 run_default_fa 1 test_multi_tensor.py
run 1 test_numerics.py
NVTE_ROCM_ENABLE_MXFP8=1 run_default_fa_lbl "mxfp8" 1 test_numerics.py -k "recipe0 and 126m and not grouped"
run_default_fa 1 test_permutation.py
run_default_fa 1 test_recipe.py
run 1 test_sanity.py
Expand Down
144 changes: 96 additions & 48 deletions tests/cpp/operator/test_cublaslt_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ std::vector<std::tuple<size_t, size_t, size_t>> test_case_sizes = {

std::vector<std::tuple<size_t, size_t, size_t>> test_case_sizes_mxfp8 = {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

test_case_sizes_mxfp8 is only used for DqGEMMTest, is it intention to add sizes there?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I wanted to add the minimum possible size that hipKittens supports, which is 256x256x256

{32, 128, 16},
{256, 256, 256},
{768, 3072, 4096},
{4096, 16384, 4096},
};

// A, B, Bias, Gelu, D
Expand Down Expand Up @@ -168,6 +170,20 @@ __global__ void compute_ref_kernel(
}


constexpr size_t kMXFP8GroupSize = 32;
constexpr size_t kKTileSize = 128;

static size_t compute_mxfp8_workspace_size(size_t m, size_t k, size_t n, bool transa, bool transb, size_t base_size) {
size_t k_iters = k / kKTileSize;
size_t scale_k = k / kMXFP8GroupSize;
size_t sa_pk = round_up_to_nearest_multiple(k_iters * m * 4, 256);
size_t sb_pk = k_iters * n * 4;
size_t needed = round_up_to_nearest_multiple(sa_pk, 256) + sb_pk;
if (!transa) needed += round_up_to_nearest_multiple(m * k, 256) + round_up_to_nearest_multiple(m * scale_k, 256) + round_up_to_nearest_multiple(sa_pk, 256);
if (transb) needed += round_up_to_nearest_multiple(n * k, 256) + round_up_to_nearest_multiple(n * scale_k, 256) + round_up_to_nearest_multiple(sb_pk, 256);
return std::max(base_size, needed);
}

struct TestParams {
size_t m;
size_t k;
Expand All @@ -177,6 +193,7 @@ struct TestParams {
bool transa;
bool transb;
NVTEScalingMode scaling_mode;
bool force_hipblaslt;
};


Expand Down Expand Up @@ -313,7 +330,7 @@ std::pair<double, double> getTestTolerances(const DType type, bool use_fp8, bool
// relax for certain FP8 gemm with hipblaslt
if (use_mxfp8) {
atol = 5e-4;
rtol = std::max(rtol, 1e-3);
rtol = std::max(rtol, 5e-3);
}
else if (use_fp8) {
atol = 1e-3;
Expand All @@ -340,9 +357,9 @@ void performTest(const TestParams& params) {

const bool has_fp8 = isFp8Type(atype) || isFp8Type(btype);
const bool use_mxfp8 = params.scaling_mode == NVTEScalingMode::NVTE_MXFP8_1D_SCALING;
const bool use_hipkittens_mxfp8 = use_mxfp8 && !params.force_hipblaslt;

if (use_mxfp8)
{
if (use_mxfp8) {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add new const bool use_hipblaslt_fp8 = (!use_mxfp8 || param.force_hipblaslt) - this combination is used below for many skips. And all this should be below, under ifdef HIP_PLATFORM_AMD under has_fp8

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wanted to avoid the skips completely, so split up the test instantiation into non-mxfp8 and mxfp8.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nevertheless, the same condition is used multiple times below. May be you can rather have use_hipkittens_mxfp8 = (use_mxfp8 && !params.force_hiplaslt) for better clarity

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see what you mean now. We can combine some of the checks to make it easier to read when we are using hipkittens.

if (!has_fp8) {
GTEST_SKIP() << "MXFP8 scaling mode requires Float8 types";
}
Expand All @@ -352,6 +369,9 @@ void performTest(const TestParams& params) {
if (params.k % 128) {
GTEST_SKIP() << "MXFP8 requires K to be a multiple of 128";
}
if (use_hipkittens_mxfp8 && (params.m % 256 || params.n % 256 || params.k < 256)) {
GTEST_SKIP() << "HipKittens requires M and N 256-aligned, K >= 256";
}
}

cudaDeviceProp prop;
Expand Down Expand Up @@ -383,26 +403,18 @@ void performTest(const TestParams& params) {

if (has_fp8)
{
bool fp8_supported = (prop.major == 9 && prop.minor >= 4) || prop.major >= 12;
const bool fp8_supported = (prop.major == 9 && prop.minor >= 4) || prop.major >= 12;
if (!fp8_supported) {
GTEST_SKIP() << "FP8 is not supported in current config";
}

if (use_mxfp8)
{
bool mxfp8_supported = (prop.major == 9 && prop.minor >= 5) || prop.major >= 12;
if (!mxfp8_supported) {
GTEST_SKIP() << "MXFP8 is not supported in current config";
}
if (isFp8Type(dtype)){
GTEST_SKIP() << "MXFP8 with float8 output is not supported";
}
if (params.use_bias) {
GTEST_SKIP() << "MXFP8 GEMM with bias is not supported";
}
const bool mxfp8_supported = (prop.major == 9 && prop.minor >= 5) || prop.major >= 12;
if (use_mxfp8 && !mxfp8_supported) {
GTEST_SKIP() << "MXFP8 is not supported in current config";
}

if (params.use_gelu && !fp8_gelu_fusion_config) {
if (!use_hipkittens_mxfp8 && params.use_bias) {
GTEST_SKIP() << "MXFP8 GEMM with bias is not supported by hipBLASLt";
}
if (params.use_gelu && !fp8_gelu_fusion_config && !use_hipkittens_mxfp8) {
GTEST_SKIP() << "FP8 GEMM with GELU is not supported in current config";
}
if (params.use_bias && dtype == DType::kFloat16) {
Expand All @@ -412,29 +424,27 @@ void performTest(const TestParams& params) {

if (prop.major == 9 && prop.minor == 5) //gfx950 specific hipblasLt limitations
{
if (isFp8Type(dtype)){
if (isFp8Type(dtype)) {
GTEST_SKIP() << "GEMM with float8 output is not supported";
}
if (params.use_gelu && dtype == DType::kBFloat16) {
if (params.use_gelu && dtype == DType::kBFloat16 && !use_hipkittens_mxfp8) {
GTEST_SKIP() << "BF16 GEMM with GELU is not supported in current config";
}
if constexpr ((std::is_same<A_Type, bf8>::value || std::is_same<B_Type, bf8>::value) &&
std::is_same<D_Type, fp32>::value)
{
//GEMM with bias and fp32 output is not supported with bf8 A/B
if constexpr ((std::is_same_v<A_Type, bf8> || std::is_same_v<B_Type, bf8>) &&
std::is_same_v<D_Type, fp32>) {
if (params.use_bias) {
GTEST_SKIP() << "FP8 GEMM with bias is not supported in current config";
}
}
}
if (prop.major == 9 && prop.minor == 4) //gfx942 specific hipblasLt limitations
else if (prop.major == 9 && prop.minor == 4) //gfx942 specific hipblasLt limitations
{
#if HIP_VERSION < 70100000
if (params.use_gelu && dtype == DType::kBFloat16 && !params.transa) {
GTEST_SKIP() << "BF16 GEMM with GELU is not supported in current config";
}
#endif
if constexpr (std::is_same<D_Type, fp8>::value && std::is_same<Bias_Type, bf16>::value) {
if constexpr (std::is_same_v<D_Type, fp8> && std::is_same_v<Bias_Type, bf16>) {
if (params.use_bias && !fp8_gelu_fusion_config) {
GTEST_SKIP() << "GEMM with BF16 bias and FP8 output is not supported in current config";
}
Expand Down Expand Up @@ -493,6 +503,11 @@ void performTest(const TestParams& params) {
if ((prop.major == 9 && prop.minor == 5) || prop.major >= 12) {
workspace_size = 67108864;
}
if (use_mxfp8 && !use_hipkittens_mxfp8) {
workspace_size = compute_mxfp8_workspace_size(params.m, params.k, params.n,
Comment thread
wangye805 marked this conversation as resolved.
Comment thread
ipanfilo marked this conversation as resolved.
params.transa, params.transb,
workspace_size);
}
#endif
Tensor Workspace("Workspace", TShape{ workspace_size }, DType::kByte);

Expand Down Expand Up @@ -551,7 +566,7 @@ void performTest(const TestParams& params) {
compareResults("D", D, RefD.rowwise_cpu_dptr<D_Type>(), true, atol, rtol);

if(params.use_gelu){
auto [atol, rtol] = getTestTolerances(gelu_type, false, false);
auto [atol, rtol] = getTestTolerances(gelu_type, has_fp8, use_mxfp8);
RefPreGeluOut.to_cpu();
compareResults("gelu", pre_gelu_out, RefPreGeluOut.rowwise_cpu_dptr<Gelu_Type>(), true, atol, rtol);
}
Expand All @@ -577,10 +592,18 @@ void performDqTest(const TestParams &params) {
cudaDeviceProp prop;
(void)cudaGetDeviceProperties(&prop, 0);

bool mxfp8_supported = (prop.major == 9 && prop.minor >= 5) || prop.major >= 12;
const bool mxfp8_supported = (prop.major == 9 && prop.minor >= 5) || prop.major >= 12;
const bool use_hipkittens_mxfp8 = !params.force_hipblaslt;

if (!mxfp8_supported) {
GTEST_SKIP() << "MXFP8 is not supported in current config";
}
if (params.use_bias || params.use_gelu) {
GTEST_SKIP() << "DqGEMMTestSuite does not yet have reference for bias/gelu epilogues";
}
if (use_hipkittens_mxfp8 && (params.m % 256 || params.n % 256 || params.k % 128 || params.k < 256)) {
GTEST_SKIP() << "HipKittens requires M and N 256-aligned, K >= 256";
}

DType ref_type = dtype;
TShape a_shape = params.transa ? TShape{params.m, params.k} : TShape{params.k, params.m};
Expand Down Expand Up @@ -608,7 +631,9 @@ void performDqTest(const TestParams &params) {
Tensor bias;
Tensor pre_gelu_out;

size_t workspace_size = 67108864;
size_t workspace_size = compute_mxfp8_workspace_size(params.m, params.k, params.n,
params.transa, params.transb,
67108864); // 64 MiB required for hipBLASlt
Tensor Workspace("Workspace", TShape{workspace_size}, DType::kByte);

//perform FP8 gemm and copy the output results from GPU memory to CPU memory
Expand Down Expand Up @@ -638,6 +663,12 @@ void performDqTest(const TestParams &params) {
#endif // __HIP_PLATFORM_AMD__

#define MAKE_TEST_PARAMS(P_) \
bool force_hipblaslt_ = std::get<5>(GetParam()); \
if (force_hipblaslt_) { \
setenv("NVTE_ROCM_USE_HIPBLASLT_MXFP8", "1", 1); \
} else { \
unsetenv("NVTE_ROCM_USE_HIPBLASLT_MXFP8"); \
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Better make it explicit: set to 0, so it always setenv with just different value

} \
TestParams P_ = {.m = std::get<0>(std::get<0>(GetParam())), \
.k = std::get<1>(std::get<0>(GetParam())), \
.n = std::get<2>(std::get<0>(GetParam())), \
Expand All @@ -646,13 +677,14 @@ void performDqTest(const TestParams &params) {
.transa = std::get<3>(GetParam()).first, \
.transb = std::get<3>(GetParam()).second, \
.scaling_mode = std::get<4>(GetParam()) \
? NVTEScalingMode::NVTE_MXFP8_1D_SCALING \
: NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING}
? NVTEScalingMode::NVTE_MXFP8_1D_SCALING \
: NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING,\
.force_hipblaslt = force_hipblaslt_}

// <m, k, n>, use_bias, use_gelu, Layout, fp8_scalinig
// <m, k, n>, use_bias, use_gelu, Layout, fp8_scaling, force_hipblaslt
class GEMMTestSuite
: public ::testing::TestWithParam<
std::tuple<std::tuple<size_t, size_t, size_t>, bool, bool, Layout, NVTEScalingMode>> {};
std::tuple<std::tuple<size_t, size_t, size_t>, bool, bool, Layout, NVTEScalingMode, bool>> {};

#define MAKE_GEMM_TEST(NAME_, A_, B_, BIAS_, GELU_, D_) \
TEST_P(GEMMTestSuite, NAME_) { \
Expand Down Expand Up @@ -713,19 +745,32 @@ static inline auto MKN(const std::tuple<size_t, size_t, size_t>& shape) {
std::to_string(std::get<2>(shape));
}

static std::string GEMMTestName(const testing::TestParamInfo<GEMMTestSuite::ParamType>& info) {
return MKN(std::get<0>(info.param)) + "x" +
std::to_string(std::get<1>(info.param)) + "x" +
std::to_string(std::get<2>(info.param)) + "x" +
TN(std::get<3>(info.param)) + "x" +
(std::get<4>(info.param) ? "M" : "S") + "x" +
(std::get<5>(info.param) ? "HB" : "HK");
}

INSTANTIATE_TEST_SUITE_P(OperatorTest, GEMMTestSuite,
::testing::Combine(::testing::ValuesIn(test_case_sizes),
::testing::Values(false, true), //use bias
::testing::Values(false, true), //use_gelu
::testing::ValuesIn(kLayouts), //transa,transb
::testing::Values(false, true)), //use mxfp8
[](const testing::TestParamInfo<GEMMTestSuite::ParamType>& info) {
return MKN(std::get<0>(info.param)) + "x" +
std::to_string(std::get<1>(info.param)) + "x" +
std::to_string(std::get<2>(info.param)) + "x" +
TN(std::get<3>(info.param)) + "x" +
(std::get<4>(info.param) ? "M" : "S");
});
::testing::Values(false), //use mxfp8
::testing::Values(false)), //force hipblaslt
GEMMTestName);

INSTANTIATE_TEST_SUITE_P(OperatorTestMXFP8, GEMMTestSuite,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With such approach we'll have OperatorTestMXFP8 with Testfp32xfp32xfp32xfp32xfp32. So if you want to use separate suites for more precise tests shaping there rather should be 3 prefixes GEMMTest, FP8GEMMTest and MXFP8GEMMTest where the last two use the same suite. And two sets of MAKE_GEMM_TEST: for FP8 and non-FP8.
Alternatively, the same prefix can be reused and different test suites made with 2 MAKE_GEMM_TEST: for FP8 and non-FP8 where the former has 2 TEST_P: for MXFP8 suite and non-MXFP8 one.

::testing::Combine(::testing::ValuesIn(test_case_sizes),
::testing::Values(false, true), //use bias
::testing::Values(false, true), //use_gelu
::testing::ValuesIn(kLayouts), //transa,transb
::testing::Values(true), //use mxfp8
::testing::Values(false, true)), //force hipblaslt
GEMMTestName);

#ifdef __HIP_PLATFORM_AMD__
class DqGEMMTestSuite: public GEMMTestSuite {};
Expand All @@ -741,14 +786,17 @@ class DqGEMMTestSuite: public GEMMTestSuite {};

MAKE_DQ_GEMM_TEST(Testfp8xfp8xfp16, fp8, fp8, fp16)

INSTANTIATE_TEST_SUITE_P(OperatorTest, DqGEMMTestSuite,
INSTANTIATE_TEST_SUITE_P(OperatorTestMXFP8, DqGEMMTestSuite,
::testing::Combine(::testing::ValuesIn(test_case_sizes_mxfp8),
::testing::Values(false), // bias - unused
::testing::Values(false), // gelu - unused
::testing::ValuesIn(kLayouts), //transa,transb
::testing::Values(true)), //use mxfp8
::testing::Values(false), // use bias
::testing::Values(false), // use gelu
::testing::ValuesIn(kLayouts), // transa,transb
::testing::Values(true), // use mxfp8
::testing::Values(false, true)), // force hipblaslt
[](const testing::TestParamInfo<DqGEMMTestSuite::ParamType>& info) {
return MKN(std::get<0>(info.param)) + "x" + TN(std::get<3>(info.param));
return MKN(std::get<0>(info.param)) + "x" +
TN(std::get<3>(info.param)) + "x" +
(std::get<5>(info.param) ? "HB" : "HK");
});

TEST(InputGenTest, FillUniform_DoesNotGetOverwrittenByFromCpu) {
Expand Down
2 changes: 1 addition & 1 deletion tests/jax/test_custom_call_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
]
TEST_SHAPES = [(64, 32, 64)]
if is_hip_extension():
TEST_SHAPES += [(64, 64, 128), (128, 256, 256)]
TEST_SHAPES += [(64, 64, 128), (128, 256, 256), (256, 256, 256)]
jnp_float8_e4m3_type = get_jnp_float8_e4m3_type()
jnp_float8_e5m2_type = get_jnp_float8_e5m2_type()

Expand Down
5 changes: 3 additions & 2 deletions tests/jax/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,9 @@ def _check_mxfp8_gemm_support(with_jax_gemm, m, n, k, use_bias=False):
pytest.skip(
f"Input shape {(m, k)} x {(k, n)} is not supported by hipblaslt MXFP8 GEMM."
)
if use_bias:
pytest.skip("hipblaslt GEMM does not yet support MXFP8 with bias.")
hipkittens_eligible = (m % 256 == 0) and (n % 256 == 0) and (k >= 256)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same hardcoding 256s...

if use_bias and not hipkittens_eligible:
pytest.skip("hipblaslt GEMM does not support MXFP8 with bias.")
else:
jax_version = version.parse(jax.__version__)
if jax_version < version.parse("0.8.2"):
Expand Down
9 changes: 9 additions & 0 deletions transformer_engine/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ cmake_minimum_required(VERSION 3.21)
option(USE_ROCM "Use ROCm" ON)
option(USE_FUSED_ATTN_AOTRITON "Use aotriton backend" ON)
option(USE_FUSED_ATTN_CK "Use ck backend" ON)
option(USE_HIPKITTENS_GEMM "Use HipKittens MXFP8 GEMM kernels" ON)
set(USE_CUDA OFF)

if (USE_ROCM)
Expand Down Expand Up @@ -453,6 +454,10 @@ else()
add_subdirectory(ck_fused_attn ${CMAKE_CURRENT_BINARY_DIR}/ck_fused_attn)
endif()

if(USE_HIPKITTENS_GEMM)
add_subdirectory(gemm/kittens ${CMAKE_CURRENT_BINARY_DIR}/kittens)
endif()

find_package(hip)
list(APPEND transformer_engine_LINKER_LIBS hip::host hip::device roctx64)
find_package(hiprtc)
Expand All @@ -467,6 +472,10 @@ else()
target_compile_definitions(transformer_engine PUBLIC USE_FUSED_ATTN_CK)
list(APPEND transformer_engine_LINKER_LIBS ck_fused_attn)
endif()
if(USE_HIPKITTENS_GEMM)
target_compile_definitions(transformer_engine PUBLIC USE_HIPKITTENS_GEMM)
list(APPEND transformer_engine_LINKER_LIBS kittens_gemm)
endif()
target_link_libraries(transformer_engine PUBLIC ${transformer_engine_LINKER_LIBS})
endif()

Expand Down
Loading
Loading