Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,7 @@
[submodule "3rdparty/QoLA"]
path = 3rdparty/QoLA
url = https://github.com/Micky774/QoLA.git
[submodule "rocm_libraries"]
path = 3rdparty/rocm_libraries
url = https://github.com/ROCm/rocm-libraries.git
branch = users/jia/ck/fix_grouped_gemm_quant_mxtype
1 change: 1 addition & 0 deletions 3rdparty/rocm_libraries
Submodule rocm_libraries added at 66b1d1
2 changes: 1 addition & 1 deletion transformer_engine/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ set_property(
PROPERTY
COMPILE_OPTIONS "-g0;-dopt=on")
else()
set(CK_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/aiter/3rdparty/composable_kernel)
set(CK_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/rocm_libraries/projects/composablekernel)
target_include_directories(transformer_engine PRIVATE ${CK_ROOT}/include)
endif() #USE_CUDA

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,19 +57,19 @@ bool ck_tile_grouped_gemm(const NVTETensor* A,
// FP8 special handling.
//
// A_use/B_use and transA_use/transB_use have already gone through the
// upstream-style grouped GEMM normalization above. This block only rewrites
// that normalized presentation into the CK FP8 preferred NT presentation by selecting
// `columnwise_data` when needed.
// upstream-style grouped GEMM normalization above. CK FP8 grouped GEMM is
// compiled only for the preferred NT presentation:
//
// CK FP8 target presentation:
// A_use: N
// B_use: T
// transA_use = false
// transB_use = true
//
// The outer condition checks whether this NT presentation is possible:
// - A_use is already N, or can be made N using columnwise_data
// - B_use is already T, or can be made T using columnwise_data
// This block rewrites the normalized presentation into that NT form by
// selecting columnwise_data when needed. If the required columnwise_data view
// is unavailable, this CK FP8 backend cannot represent the GEMM in its
// supported layout form, so we fall back instead of compiling/running an
// unsupported layout variant.
//
// Then each operand is rewritten independently only if needed:
// Rewrite cases:
// NN -> rewrite B only
// TN -> rewrite A and B
// NT -> already in target form
Expand All @@ -81,16 +81,23 @@ bool ck_tile_grouped_gemm(const NVTETensor* A,
const bool has_a_col = A0_te->has_columnwise_data();
const bool has_b_col = B0_te->has_columnwise_data();

if ((!transA_use || has_a_col) && (transB_use || has_b_col)) {
if (transA_use) {
use_a_colwise_data = true;
transA_use = false;
}
const bool can_make_a_nt = !transA_use || has_a_col;
const bool can_make_b_nt = transB_use || has_b_col;

if (!transB_use) {
use_b_colwise_data = true;
transB_use = true;
}
if (!can_make_a_nt || !can_make_b_nt) {
NVTE_WARN("ck_tile_grouped_gemm: FP8 grouped GEMM requires NT presentation. "
"Missing required columnwise_data for layout rewrite; falling back.");
return false;
}

if (transA_use) {
use_a_colwise_data = true;
transA_use = false;
}

if (!transB_use) {
use_b_colwise_data = true;
transB_use = true;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#pragma once

#include <hip/hip_runtime.h>
#include "common/util/cuda_runtime.h"

#include <array>
#include <type_traits>
Expand Down Expand Up @@ -70,6 +71,28 @@ static inline const transformer_engine::SimpleTensor& scale_inv_view(const trans
return t.scale_inv;
}

enum class GPUArch {
GFX942,
GFX950,
GFX1250,
UNKNOWN
};

static inline GPUArch detect_gpu_arch() {
int arch = cuda::sm_arch(0);

if (arch == 94) {
return GPUArch::GFX942;
}
if (arch == 95) {
return GPUArch::GFX950;
}
if (arch == 125 || arch == 1250) {
return GPUArch::GFX1250;
}
return GPUArch::UNKNOWN;
}

struct GroupedGemmRunContext {
const NVTETensor* A = nullptr;
const NVTETensor* B = nullptr;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ namespace grouped_gemm {
// Tile configs: FP16/BF16
// -------------------------

struct TileCfg_256x256x64 {
struct TileCfg_256x256x64_MFMA {
static constexpr ck_tile::index_t M_Tile = 256;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 64;
Expand All @@ -37,14 +37,37 @@ struct TileCfg_256x256x64 {
static constexpr ck_tile::index_t TilePartitionerM01 = 4;
};

struct TileCfg_256x128x64 : TileCfg_256x256x64 {
struct TileCfg_256x128x64_MFMA : TileCfg_256x256x64_MFMA {
static constexpr ck_tile::index_t N_Tile = 128;
};

struct TileCfg_256x128x64_padding : TileCfg_256x128x64 {
struct TileCfg_256x128x64_MFMA_padding : TileCfg_256x128x64_MFMA {
static constexpr bool kPadN = true;
};

struct TileCfg_256x256x64_WMMA {
static constexpr ck_tile::index_t M_Tile = 256;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 64;

static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;

static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 32;

static constexpr bool kPadM = true;
static constexpr bool kPadN = true;
static constexpr bool kPadK = true;

static constexpr bool DoubleSmemBuffer = false;

static constexpr ck_tile::index_t TilePartitionerGroupNum = 8;
static constexpr ck_tile::index_t TilePartitionerM01 = 4;
};

template <typename AType,
typename BType,
typename CType,
Expand Down Expand Up @@ -209,7 +232,26 @@ class GroupedGemmRunner : public RunnerInterface {
runner = std::make_unique<Runner>(); \
})

bool ck_tile_grouped_gemm_fp16_dispatch(DType a_dtype,
template <GPUArch Arch>
struct FP16TileCfg;

template <>
struct FP16TileCfg<GPUArch::GFX942> {
using type = TileCfg_256x256x64_MFMA;
};

template <>
struct FP16TileCfg<GPUArch::GFX950> {
using type = TileCfg_256x256x64_MFMA;
};

template <>
struct FP16TileCfg<GPUArch::GFX1250> {
using type = TileCfg_256x256x64_WMMA;
};

template <GPUArch Arch>
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does it need template over reguler if-else or switch-case?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The template is needed because the arch selection affects CK kernel template instantiation, not just runtime control flow. GPUArch must be a compile-time value so if constexpr can prune unsupported tile/kernel combinations for a given architecture. In this case, it prevents the MFMA configs from being instantiated for gfx1250.

bool ck_tile_grouped_gemm_fp16_dispatch_arch(DType a_dtype,
DType b_dtype,
DType d_dtype,
const GroupedGemmRunContext& ctx) {
Expand All @@ -229,13 +271,17 @@ bool ck_tile_grouped_gemm_fp16_dispatch(DType a_dtype,

TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(d_dtype, d_te_type, {
using CType = typename TETypeToCKType<d_te_type>::type;

if (ctx.N % 256 == 0) {
MAKE_RUNNER(TileCfg_256x256x64);
} else if (ctx.N % 128 == 0) {
MAKE_RUNNER(TileCfg_256x128x64);

if constexpr (Arch == GPUArch::GFX1250) {
MAKE_RUNNER(TileCfg_256x256x64_WMMA);
} else {
MAKE_RUNNER(TileCfg_256x128x64_padding);
if (ctx.N % 256 == 0) {
MAKE_RUNNER(TileCfg_256x256x64_MFMA);
} else if (ctx.N % 128 == 0) {
MAKE_RUNNER(TileCfg_256x128x64_MFMA);
} else {
MAKE_RUNNER(TileCfg_256x128x64_MFMA_padding);
}
}
});
});
Expand All @@ -249,6 +295,23 @@ bool ck_tile_grouped_gemm_fp16_dispatch(DType a_dtype,
return runner->run(s, ctx);
}

bool ck_tile_grouped_gemm_fp16_dispatch(DType a_dtype,
DType b_dtype,
DType d_dtype,
const GroupedGemmRunContext& ctx) {
switch (detect_gpu_arch()) {
case GPUArch::GFX942:
return ck_tile_grouped_gemm_fp16_dispatch_arch<GPUArch::GFX942>(a_dtype, b_dtype, d_dtype, ctx);
case GPUArch::GFX950:
return ck_tile_grouped_gemm_fp16_dispatch_arch<GPUArch::GFX950>(a_dtype, b_dtype, d_dtype, ctx);
case GPUArch::GFX1250:
return ck_tile_grouped_gemm_fp16_dispatch_arch<GPUArch::GFX1250>(a_dtype, b_dtype, d_dtype, ctx);
default:
NVTE_ERROR("ck_tile_grouped_gemm: available architectures = {gfx942, gfx950, gfx1250}");
return false;
}
}

#undef MAKE_RUNNER

} // namespace grouped_gemm
Expand Down
Loading
Loading