From 95f984ce6937fcd5f0c8aaea62c025e4af2b9f81 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 4 May 2026 19:14:20 -0500 Subject: [PATCH 1/3] ck_tile grouped gemm: more padding --- tests/pytorch/test_numerics.py | 122 ++++++++++++++++++ .../ck_grouped_gemm/ck_grouped_gemm_common.h | 12 +- .../ck_grouped_gemm/ck_grouped_gemm_fp16.cpp | 60 +++++++-- 3 files changed, 182 insertions(+), 12 deletions(-) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 4a768377e..f548b36d6 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -3078,6 +3078,128 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass): os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None) +@pytest.mark.skipif( + torch.cuda.get_device_capability() != (9, 0) and not IS_HIP_EXTENSION, + reason="Only enable CUTLASS/CK grouped gemm on Hopper or ROCm", +) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16], ids=str) +@pytest.mark.parametrize("layout", ["TN", "NN"]) +@pytest.mark.parametrize("accumulate", [False, True]) +@pytest.mark.parametrize( + "pad_dim", + ["K", "M", "N"], + ids=lambda d: f"pad{d}", +) +def test_grouped_gemm_unaligned(dtype, layout, accumulate, pad_dim): + """Test CK grouped GEMM with M, N, or K not aligned to CK tile size. + + CK constraints for bf16/fp16: + - Contiguous dim of A/B must be dword-aligned (even for 2-byte types). + RowMajor: contiguous dim is cols (K for A, N for B). + ColMajor: contiguous dim is rows (M for A, K for B). + - N: must be multiple of 16 (GetVectorSizeC, no dword fallback), tile 128/256 + - K tile: 64, M tile: 256 + """ + torch.manual_seed(0) + z = 8 + + # Unaligned values per dimension (all satisfy CK vector-load constraints). + # K: even but not multiple of tile (64). Same for all groups. + # M: not multiples of tile (256), varies per group. + # N: multiple of 16 but not multiple of tile (128). + unaligned_k = 2026 + unaligned_m = [100, 300, 150, 200, 50, 350, 250, 180] + unaligned_n = 2032 + + # Aligned defaults. + k_aligned = 2048 + m_aligned = 256 + n_aligned = 2048 + + os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1" + + if layout == "TN": + # TN GEMM: M=m_splits[i], N=A.rows, K=A.cols + if pad_dim == "K": + k_val = unaligned_k + m_vals = [m_aligned] * z + n_val = n_aligned + elif pad_dim == "M": + k_val = k_aligned + m_vals = unaligned_m + n_val = n_aligned + else: # N + k_val = k_aligned + m_vals = [m_aligned] * z + n_val = unaligned_n + + A = [torch.randn(n_val, k_val, dtype=dtype, device="cuda") for _ in range(z)] + B = [torch.randn(m, k_val, dtype=dtype, device="cuda") for m in m_vals] + total_m = sum(m_vals) + out = [torch.randn(total_m, n_val, dtype=dtype, device="cuda")] + out_ref = [o.clone() for o in torch.split(out[0], m_vals)] + m_splits = m_vals + grad = False + single_output = True + else: # NN + # NN GEMM: M=m_splits[i], N=A.cols, K=A.rows + if pad_dim == "K": + gemm_k = unaligned_k + m_vals = [m_aligned] * z + n_out = n_aligned + elif pad_dim == "M": + gemm_k = k_aligned + m_vals = unaligned_m + n_out = n_aligned + else: # N + gemm_k = k_aligned + m_vals = [m_aligned] * z + n_out = unaligned_n + + A = [torch.randn(gemm_k, n_out, dtype=dtype, device="cuda") for _ in range(z)] + B = [torch.randn(m, gemm_k, dtype=dtype, device="cuda") for m in m_vals] + total_m = sum(m_vals) + out = [torch.randn(total_m, n_out, dtype=dtype, device="cuda")] + out_ref = [o.clone() for o in torch.split(out[0], m_vals)] + m_splits = m_vals + grad = True + single_output = True + + # Reference: individual GEMMs + for i in range(z): + general_gemm( + A[i], + B[i], + dtype, + grad=grad, + accumulate=accumulate, + layout=layout, + out=out_ref[i], + ) + if single_output: + out_ref = [torch.cat(out_ref)] + + general_grouped_gemm( + A, + B, + out, + [None] * z, + dtype, + m_splits=m_splits, + grad=grad, + accumulate=accumulate, + layout=layout, + single_output=single_output, + ) + + for o, o_ref in zip(out, out_ref): + if IS_HIP_EXTENSION and accumulate and dtype == torch.bfloat16 and get_device_compute_capability() == (9, 4): + torch.testing.assert_close(o, o_ref, rtol=4e-2, atol=4e-2) + else: + torch.testing.assert_close(o, o_ref, rtol=1.5e-2, atol=1.5e-2) + + os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None) + @pytest.mark.parametrize("N", [32]) @pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize( diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h index c89f10232..75746ab8f 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h @@ -140,7 +140,17 @@ static inline bool launch_grouped_gemm_kernel(const DescContainer& descs, if (!Kernel::IsSupportedArgument(kargs)) { NVTE_WARN("ck_tile_grouped_gemm: CK_Tile kernel arguments not supported for this config. " - "Falling back."); + "transA=", ctx.transA, " transB=", ctx.transB, + " accumulate=", ctx.accumulate, " groups=", ctx.group_num, + ". Falling back. " + "CK_Tile constraints for bf16/fp16: " + "contiguous dim of A and B must be dword-aligned (even), " + "N must be multiple of 16 (GetVectorSizeC)."); + for (size_t i = 0; i < descs.size(); ++i) { + NVTE_WARN(" group ", i, ": M=", descs[i].M, " N=", descs[i].N, " K=", descs[i].K, + " stride_A=", descs[i].stride_A, " stride_B=", descs[i].stride_B, + " stride_E=", descs[i].stride_E); + } return false; } diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp index 660dbefb8..1f66cdf57 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp @@ -41,8 +41,11 @@ struct TileCfg_256x128x64 : TileCfg_256x256x64 { static constexpr ck_tile::index_t N_Tile = 128; }; -struct TileCfg_256x128x64_padding : TileCfg_256x128x64 { - static constexpr bool kPadN = true; +template +struct WithPadding : Base { + static constexpr bool kPadM = PadM_; + static constexpr bool kPadN = PadN_; + static constexpr bool kPadK = PadK_; }; template , \ accum_option>; \ runner = std::make_unique(); \ }) @@ -216,6 +219,37 @@ bool ck_tile_grouped_gemm_fp16_dispatch(DType a_dtype, const ck_tile::stream_config s{ctx.stream}; std::unique_ptr runner = nullptr; + // Check M and K alignment across all groups. + // All tile configs share the same M_Tile (256) and K_Tile (64). + constexpr ck_tile::index_t M_Tile = TileCfg_256x256x64::M_Tile; + constexpr ck_tile::index_t K_Tile = TileCfg_256x256x64::K_Tile; + + bool need_m_pad = false; + bool need_k_pad = false; + + for (int i = 0; i < ctx.group_num; ++i) { + const transformer_engine::Tensor* A_te = + transformer_engine::convertNVTETensorCheck(ctx.A[i]); + int64_t Ad0 = 0, Ad1 = 0; + if (get_flat_2d_dims(*A_te, Ad0, Ad1)) { + const int64_t M = ctx.transA ? Ad1 : Ad0; + const int64_t K = ctx.transA ? Ad0 : Ad1; + + if (M % M_Tile != 0) + need_m_pad = true; + if (K % K_Tile != 0) + need_k_pad = true; + if (need_m_pad && need_k_pad) + break; + } + } + + // CK tile kernel produces incorrect results with kPadK + ColMajor B. + // Fall back to cuBLAS for this combination. + if (need_k_pad && ctx.transB) { + return false; + } + TRANSFORMER_ENGINE_SWITCH_CONDITION(ctx.transA, kTransA, { using ALayout = std::conditional_t; @@ -230,13 +264,17 @@ bool ck_tile_grouped_gemm_fp16_dispatch(DType a_dtype, TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(d_dtype, d_te_type, { using CType = typename TETypeToCKType::type; - if (ctx.N % 256 == 0) { - MAKE_RUNNER(TileCfg_256x256x64); - } else if (ctx.N % 128 == 0) { - MAKE_RUNNER(TileCfg_256x128x64); - } else { - MAKE_RUNNER(TileCfg_256x128x64_padding); - } + TRANSFORMER_ENGINE_SWITCH_CONDITION(need_m_pad, kPadM, { + TRANSFORMER_ENGINE_SWITCH_CONDITION(need_k_pad, kPadK, { + if (ctx.N % 256 == 0) { + MAKE_RUNNER(TileCfg_256x256x64, kPadM, false, kPadK); + } else if (ctx.N % 128 == 0) { + MAKE_RUNNER(TileCfg_256x128x64, kPadM, false, kPadK); + } else { + MAKE_RUNNER(TileCfg_256x128x64, kPadM, true, kPadK); + } + }); + }); }); }); }); From 225c3dca9e86143214bd9ffdd640c03416e41e5a Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 15 May 2026 20:30:22 +0000 Subject: [PATCH 2/3] address review comments --- tests/pytorch/test_numerics.py | 15 ++++++++++++--- .../gemm/ck_grouped_gemm/ck_grouped_gemm_common.h | 3 +-- .../gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp | 2 +- 3 files changed, 14 insertions(+), 6 deletions(-) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 514838c2b..847c866da 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -3080,11 +3080,13 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass): reason="Only enable CUTLASS/CK grouped gemm on Hopper or ROCm", ) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16], ids=str) +# NT is excluded: transB=true means ColMajor B, and CK produces incorrect +# results with kPadK + ColMajor B (the dispatch falls back to cuBLAS). @pytest.mark.parametrize("layout", ["TN", "NN"]) @pytest.mark.parametrize("accumulate", [False, True]) @pytest.mark.parametrize( "pad_dim", - ["K", "M", "N"], + ["K", "M", "N", "MK"], ids=lambda d: f"pad{d}", ) def test_grouped_gemm_unaligned(dtype, layout, accumulate, pad_dim): @@ -3094,8 +3096,7 @@ def test_grouped_gemm_unaligned(dtype, layout, accumulate, pad_dim): - Contiguous dim of A/B must be dword-aligned (even for 2-byte types). RowMajor: contiguous dim is cols (K for A, N for B). ColMajor: contiguous dim is rows (M for A, K for B). - - N: must be multiple of 16 (GetVectorSizeC, no dword fallback), tile 128/256 - - K tile: 64, M tile: 256 + - K tile: 64, M tile: 256, N tile: 128/256 """ torch.manual_seed(0) z = 8 @@ -3125,6 +3126,10 @@ def test_grouped_gemm_unaligned(dtype, layout, accumulate, pad_dim): k_val = k_aligned m_vals = unaligned_m n_val = n_aligned + elif pad_dim == "MK": + k_val = unaligned_k + m_vals = unaligned_m + n_val = n_aligned else: # N k_val = k_aligned m_vals = [m_aligned] * z @@ -3148,6 +3153,10 @@ def test_grouped_gemm_unaligned(dtype, layout, accumulate, pad_dim): gemm_k = k_aligned m_vals = unaligned_m n_out = n_aligned + elif pad_dim == "MK": + gemm_k = unaligned_k + m_vals = unaligned_m + n_out = n_aligned else: # N gemm_k = k_aligned m_vals = [m_aligned] * z diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h index 75746ab8f..eb1c46f93 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h @@ -144,8 +144,7 @@ static inline bool launch_grouped_gemm_kernel(const DescContainer& descs, " accumulate=", ctx.accumulate, " groups=", ctx.group_num, ". Falling back. " "CK_Tile constraints for bf16/fp16: " - "contiguous dim of A and B must be dword-aligned (even), " - "N must be multiple of 16 (GetVectorSizeC)."); + "contiguous dim of A and B must be dword-aligned (even)."); for (size_t i = 0; i < descs.size(); ++i) { NVTE_WARN(" group ", i, ": M=", descs[i].M, " N=", descs[i].N, " K=", descs[i].K, " stride_A=", descs[i].stride_A, " stride_B=", descs[i].stride_B, diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp index 1f66cdf57..6b7976557 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp @@ -244,7 +244,7 @@ bool ck_tile_grouped_gemm_fp16_dispatch(DType a_dtype, } } - // CK tile kernel produces incorrect results with kPadK + ColMajor B. + // FIXME: CK tile kernel produces incorrect results with kPadK + ColMajor B. // Fall back to cuBLAS for this combination. if (need_k_pad && ctx.transB) { return false; From 29390173c5b0cca9e5a040276f166ad5180b3aa7 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 15 May 2026 22:16:47 +0000 Subject: [PATCH 3/3] NT workaround, split, address review comments --- tests/pytorch/test_numerics.py | 37 ++- transformer_engine/common/CMakeLists.txt | 4 + .../ck_grouped_gemm/ck_grouped_gemm_fp16.cpp | 278 +++--------------- .../ck_grouped_gemm_fp16_impl.h | 250 ++++++++++++++++ .../ck_grouped_gemm_fp16_nn.cpp | 52 ++++ .../ck_grouped_gemm_fp16_nt.cpp | 52 ++++ .../ck_grouped_gemm_fp16_tn.cpp | 52 ++++ .../ck_grouped_gemm_fp16_tt.cpp | 52 ++++ 8 files changed, 530 insertions(+), 247 deletions(-) create mode 100644 transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_impl.h create mode 100644 transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_nn.cpp create mode 100644 transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_nt.cpp create mode 100644 transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_tn.cpp create mode 100644 transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_tt.cpp diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 847c866da..0811439e4 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -3080,9 +3080,7 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass): reason="Only enable CUTLASS/CK grouped gemm on Hopper or ROCm", ) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16], ids=str) -# NT is excluded: transB=true means ColMajor B, and CK produces incorrect -# results with kPadK + ColMajor B (the dispatch falls back to cuBLAS). -@pytest.mark.parametrize("layout", ["TN", "NN"]) +@pytest.mark.parametrize("layout", ["TN", "NN", "NT"]) @pytest.mark.parametrize("accumulate", [False, True]) @pytest.mark.parametrize( "pad_dim", @@ -3143,7 +3141,7 @@ def test_grouped_gemm_unaligned(dtype, layout, accumulate, pad_dim): m_splits = m_vals grad = False single_output = True - else: # NN + elif layout == "NN": # NN GEMM: M=m_splits[i], N=A.cols, K=A.rows if pad_dim == "K": gemm_k = unaligned_k @@ -3170,6 +3168,36 @@ def test_grouped_gemm_unaligned(dtype, layout, accumulate, pad_dim): m_splits = m_vals grad = True single_output = True + else: # NT + # NT GEMM: out[i] = A[i]^T @ B[i], A[i]: (m_i, k), B[i]: (m_i, n), out[i]: (n, k) + if pad_dim == "K": + k_val = unaligned_k + m_vals = [m_aligned] * z + n_val = n_aligned + elif pad_dim == "M": + k_val = k_aligned + m_vals = unaligned_m + n_val = n_aligned + elif pad_dim == "MK": + k_val = unaligned_k + m_vals = unaligned_m + n_val = n_aligned + else: # N + k_val = k_aligned + m_vals = [m_aligned] * z + n_val = unaligned_n + + A = list(torch.split( + torch.randn(sum(m_vals), k_val, dtype=dtype, device="cuda"), m_vals + )) + B = list(torch.split( + torch.randn(sum(m_vals), n_val, dtype=dtype, device="cuda"), m_vals + )) + out = [torch.randn(n_val, k_val, dtype=dtype, device="cuda") for _ in range(z)] + out_ref = [o.clone() for o in out] + m_splits = m_vals + grad = True + single_output = False # Reference: individual GEMMs for i in range(z): @@ -3206,6 +3234,7 @@ def test_grouped_gemm_unaligned(dtype, layout, accumulate, pad_dim): os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None) + @pytest.mark.parametrize("N", [32]) @pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize( diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 774065fca..790159ec1 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -259,6 +259,10 @@ else() gemm/ck_grouped_gemm/ck_grouped_gemm.cpp gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp + gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_nn.cpp + gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_nt.cpp + gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_tn.cpp + gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_tt.cpp amd_detail/system.cpp) list(APPEND transformer_engine_cuda_sources fused_attn_rocm/fused_attn_aotriton.cpp diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp index 6b7976557..b01c4abe3 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp @@ -6,219 +6,15 @@ #include "ck_grouped_gemm_common.h" #include "ck_grouped_gemm_fp16.h" +#include "ck_grouped_gemm_fp16_impl.h" namespace transformer_engine { namespace grouped_gemm { -// ------------------------- -// Tile configs: FP16/BF16 -// ------------------------- - -struct TileCfg_256x256x64 { - static constexpr ck_tile::index_t M_Tile = 256; - static constexpr ck_tile::index_t N_Tile = 256; - static constexpr ck_tile::index_t K_Tile = 64; - - static constexpr ck_tile::index_t M_Warp = 2; - static constexpr ck_tile::index_t N_Warp = 2; - static constexpr ck_tile::index_t K_Warp = 1; - - static constexpr ck_tile::index_t M_Warp_Tile = 32; - static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = 16; - - static constexpr bool kPadM = false; - static constexpr bool kPadN = false; - static constexpr bool kPadK = false; - - static constexpr bool DoubleSmemBuffer = false; - - static constexpr ck_tile::index_t TilePartitionerGroupNum = 8; - static constexpr ck_tile::index_t TilePartitionerM01 = 4; -}; - -struct TileCfg_256x128x64 : TileCfg_256x256x64 { - static constexpr ck_tile::index_t N_Tile = 128; -}; - -template -struct WithPadding : Base { - static constexpr bool kPadM = PadM_; - static constexpr bool kPadN = PadN_; - static constexpr bool kPadK = PadK_; -}; - -template -class GroupedGemmRunner : public RunnerInterface { - public: - using GemmShape = GroupedGemmShape; - using Partitioner = GroupedGemmPartitioner; - - using UniversalTraits = - ck_tile::PersistentTileGemmUniversalTraits; - - static constexpr ck_tile::GemmPipelineScheduler Scheduler = - ck_tile::GemmPipelineScheduler::Intrawave; - - using Problem = - ck_tile::UniversalGemmPipelineProblem; - - using Pipeline = ck_tile::GemmPipelineAgBgCrCompV3; - - using ET = EpilogueTraits; - - using Epilogue = - ck_tile::CShuffleEpilogue>; - - using Kernel = ck_tile::GroupedGemmKernel; - - // GroupedGemmHostArgs<1> for the MultiD accumulate path, <0> for the overwrite path. - using HostArgs = std::conditional_t, - ck_tile::GroupedGemmHostArgs<0>>; - - public: - static std::vector build_descs(const GroupedGemmRunContext& ctx) { - if (!has_sufficient_workspace(ctx)) { - return {}; - } - - std::vector descs; - descs.reserve(ctx.group_num); - - for (int i = 0; i < ctx.group_num; ++i) { - const transformer_engine::Tensor* const A_te = - transformer_engine::convertNVTETensorCheck(ctx.A[i]); - const transformer_engine::Tensor* const B_te = - transformer_engine::convertNVTETensorCheck(ctx.B[i]); - transformer_engine::Tensor* D_te = - transformer_engine::convertNVTETensorCheck(ctx.D[i]); - - const auto& a = data_view(*A_te); - const auto& b = data_view(*B_te); - const auto& d = data_view(*D_te); - - int64_t Ad0 = 0, Ad1 = 0, Bd0 = 0, Bd1 = 0, Dd0 = 0, Dd1 = 0; - if (!get_flat_2d_dims(*A_te, Ad0, Ad1) || - !get_flat_2d_dims(*B_te, Bd0, Bd1) || - !get_flat_2d_dims(*D_te, Dd0, Dd1)) { - NVTE_ERROR("ck_tile_grouped_gemm: expected all groups to be rank>=2."); - } - - const int64_t M = ctx.transA ? Ad1 : Ad0; - const int64_t K = ctx.transA ? Ad0 : Ad1; - const int64_t N = ctx.transB ? Bd0 : Bd1; - const int64_t Kb = ctx.transB ? Bd1 : Bd0; - - if (Kb != K) { - NVTE_ERROR("ck_tile_grouped_gemm: K mismatch between A and B in group ", i); - } - - if (Dd0 != M || Dd1 != N) { - NVTE_ERROR("ck_tile_grouped_gemm: D shape mismatch in group ", i); - } - - const ck_tile::index_t stride_A = static_cast(Ad1); - const ck_tile::index_t stride_B = static_cast(Bd1); - const ck_tile::index_t stride_E = static_cast(Dd1); - - if constexpr (Accumulate) { - descs.emplace_back(a.dptr, - b.dptr, - std::array{d.dptr}, - d.dptr, - 1, - M, - N, - K, - stride_A, - stride_B, - std::array{stride_E}, - stride_E); - } else { - descs.emplace_back(a.dptr, - b.dptr, - std::array{}, - d.dptr, - 1, - M, - N, - K, - stride_A, - stride_B, - std::array{}, - stride_E); - } - } - - return descs; - } - - bool run(const ck_tile::stream_config& stream_cfg, - const GroupedGemmRunContext& ctx) override { - auto descs = build_descs(ctx); - if (descs.empty()) { - return false; - } - return launch_grouped_gemm_kernel(descs, ctx, stream_cfg); - } -}; - -#define MAKE_RUNNER(BaseCfg_, kPadM_, kPadN_, kPadK_) \ - TRANSFORMER_ENGINE_SWITCH_CONDITION(ctx.accumulate, accum_option, { \ - using Runner = GroupedGemmRunner, \ - accum_option>; \ - runner = std::make_unique(); \ - }) - bool ck_tile_grouped_gemm_fp16_dispatch(DType a_dtype, DType b_dtype, DType d_dtype, const GroupedGemmRunContext& ctx) { - const ck_tile::stream_config s{ctx.stream}; - std::unique_ptr runner = nullptr; - // Check M and K alignment across all groups. // All tile configs share the same M_Tile (256) and K_Tile (64). constexpr ck_tile::index_t M_Tile = TileCfg_256x256x64::M_Tile; @@ -245,49 +41,45 @@ bool ck_tile_grouped_gemm_fp16_dispatch(DType a_dtype, } // FIXME: CK tile kernel produces incorrect results with kPadK + ColMajor B. - // Fall back to cuBLAS for this combination. + // Workaround: use B's column-wise storage buffer (RowMajor) with transB=false, + // which preserves the same logical GEMM while avoiding the buggy path. + // Fall back to cuBLAS only if the column-wise buffer is unavailable. if (need_k_pad && ctx.transB) { - return false; + // Check all B tensors have columnwise_data available. + bool all_have_columnwise = true; + for (int i = 0; i < ctx.group_num; ++i) { + const transformer_engine::Tensor* B_te = + transformer_engine::convertNVTETensorCheck(ctx.B[i]); + if (!B_te->has_columnwise_data()) { + all_have_columnwise = false; + break; + } + } + if (!all_have_columnwise) { + return false; + } + // Dispatch with B's columnwise buffer as RowMajor (transB=false). + GroupedGemmRunContext ctx_nn = ctx; + ctx_nn.transB = false; + ctx_nn.use_b_columnwise_data = true; + if (!ctx_nn.transA) { + return ck_tile_grouped_gemm_fp16_dispatch_nn(a_dtype, d_dtype, need_m_pad, need_k_pad, ctx_nn); + } else { + return ck_tile_grouped_gemm_fp16_dispatch_tn(a_dtype, d_dtype, need_m_pad, need_k_pad, ctx_nn); + } } - TRANSFORMER_ENGINE_SWITCH_CONDITION(ctx.transA, kTransA, { - using ALayout = std::conditional_t; - - TRANSFORMER_ENGINE_SWITCH_CONDITION(ctx.transB, kTransB, { - using BLayout = std::conditional_t; - - TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(a_dtype, a_te_type, { - using AType = typename TETypeToCKType::type; - using BType = typename TETypeToCKType::type; - using CLayout = RowMajor; - - TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(d_dtype, d_te_type, { - using CType = typename TETypeToCKType::type; - - TRANSFORMER_ENGINE_SWITCH_CONDITION(need_m_pad, kPadM, { - TRANSFORMER_ENGINE_SWITCH_CONDITION(need_k_pad, kPadK, { - if (ctx.N % 256 == 0) { - MAKE_RUNNER(TileCfg_256x256x64, kPadM, false, kPadK); - } else if (ctx.N % 128 == 0) { - MAKE_RUNNER(TileCfg_256x128x64, kPadM, false, kPadK); - } else { - MAKE_RUNNER(TileCfg_256x128x64, kPadM, true, kPadK); - } - }); - }); - }); - }); - }); - }); - - if (!runner) { - return false; + // Dispatch to per-layout translation unit. + if (!ctx.transA && !ctx.transB) { + return ck_tile_grouped_gemm_fp16_dispatch_nn(a_dtype, d_dtype, need_m_pad, need_k_pad, ctx); + } else if (!ctx.transA && ctx.transB) { + return ck_tile_grouped_gemm_fp16_dispatch_nt(a_dtype, d_dtype, need_m_pad, need_k_pad, ctx); + } else if (ctx.transA && !ctx.transB) { + return ck_tile_grouped_gemm_fp16_dispatch_tn(a_dtype, d_dtype, need_m_pad, need_k_pad, ctx); + } else { + return ck_tile_grouped_gemm_fp16_dispatch_tt(a_dtype, d_dtype, need_m_pad, need_k_pad, ctx); } - - return runner->run(s, ctx); } -#undef MAKE_RUNNER - } // namespace grouped_gemm } // namespace transformer_engine diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_impl.h b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_impl.h new file mode 100644 index 000000000..b331e18d8 --- /dev/null +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_impl.h @@ -0,0 +1,250 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + +#pragma once + +#include "ck_grouped_gemm_common.h" + +namespace transformer_engine { +namespace grouped_gemm { + +// ------------------------- +// Tile configs: FP16/BF16 +// ------------------------- + +struct TileCfg_256x256x64 { + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 64; + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = 16; + + static constexpr bool kPadM = false; + static constexpr bool kPadN = false; + static constexpr bool kPadK = false; + + static constexpr bool DoubleSmemBuffer = false; + + static constexpr ck_tile::index_t TilePartitionerGroupNum = 8; + static constexpr ck_tile::index_t TilePartitionerM01 = 4; +}; + +struct TileCfg_256x128x64 : TileCfg_256x256x64 { + static constexpr ck_tile::index_t N_Tile = 128; +}; + +template +struct WithPadding : Base { + static constexpr bool kPadM = PadM_; + static constexpr bool kPadN = PadN_; + static constexpr bool kPadK = PadK_; +}; + +template +class GroupedGemmRunner : public RunnerInterface { + public: + using GemmShape = GroupedGemmShape; + using Partitioner = GroupedGemmPartitioner; + + using UniversalTraits = + ck_tile::PersistentTileGemmUniversalTraits; + + static constexpr ck_tile::GemmPipelineScheduler Scheduler = + ck_tile::GemmPipelineScheduler::Intrawave; + + using Problem = + ck_tile::UniversalGemmPipelineProblem; + + using Pipeline = ck_tile::GemmPipelineAgBgCrCompV3; + + using ET = EpilogueTraits; + + using Epilogue = + ck_tile::CShuffleEpilogue>; + + using Kernel = ck_tile::GroupedGemmKernel; + + // GroupedGemmHostArgs<1> for the MultiD accumulate path, <0> for the overwrite path. + using HostArgs = std::conditional_t, + ck_tile::GroupedGemmHostArgs<0>>; + + public: + static std::vector build_descs(const GroupedGemmRunContext& ctx) { + if (!has_sufficient_workspace(ctx)) { + return {}; + } + + std::vector descs; + descs.reserve(ctx.group_num); + + for (int i = 0; i < ctx.group_num; ++i) { + const transformer_engine::Tensor* const A_te = + transformer_engine::convertNVTETensorCheck(ctx.A[i]); + const transformer_engine::Tensor* const B_te = + transformer_engine::convertNVTETensorCheck(ctx.B[i]); + transformer_engine::Tensor* D_te = + transformer_engine::convertNVTETensorCheck(ctx.D[i]); + + const auto& a = data_view(*A_te); + + const transformer_engine::SimpleTensor* b_src = nullptr; + if (ctx.use_b_columnwise_data) { + b_src = &B_te->columnwise_data; + } else { + b_src = &B_te->data; + } + const auto& b = *b_src; + + const auto& d = data_view(*D_te); + + int64_t Ad0 = 0, Ad1 = 0, Bd0 = 0, Bd1 = 0, Dd0 = 0, Dd1 = 0; + if (!get_flat_2d_dims(*A_te, Ad0, Ad1)) { + NVTE_ERROR("ck_tile_grouped_gemm: expected rank>=2 for A in group ", i); + } + if (ctx.use_b_columnwise_data) { + if (!get_columnwise_storage_2d_dims(B_te->columnwise_data, Bd0, Bd1)) { + NVTE_ERROR("ck_tile_grouped_gemm: expected 2D columnwise_data for B in group ", i); + } + } else { + if (!get_flat_2d_dims(*B_te, Bd0, Bd1)) { + NVTE_ERROR("ck_tile_grouped_gemm: expected rank>=2 for B in group ", i); + } + } + if (!get_flat_2d_dims(*D_te, Dd0, Dd1)) { + NVTE_ERROR("ck_tile_grouped_gemm: expected rank>=2 for D in group ", i); + } + + const int64_t M = ctx.transA ? Ad1 : Ad0; + const int64_t K = ctx.transA ? Ad0 : Ad1; + const int64_t N = ctx.transB ? Bd0 : Bd1; + const int64_t Kb = ctx.transB ? Bd1 : Bd0; + + if (Kb != K) { + NVTE_ERROR("ck_tile_grouped_gemm: K mismatch between A and B in group ", i); + } + + if (Dd0 != M || Dd1 != N) { + NVTE_ERROR("ck_tile_grouped_gemm: D shape mismatch in group ", i); + } + + const ck_tile::index_t stride_A = static_cast(Ad1); + const ck_tile::index_t stride_B = static_cast(Bd1); + const ck_tile::index_t stride_E = static_cast(Dd1); + + if constexpr (Accumulate) { + descs.emplace_back(a.dptr, + b.dptr, + std::array{d.dptr}, + d.dptr, + 1, + M, + N, + K, + stride_A, + stride_B, + std::array{stride_E}, + stride_E); + } else { + descs.emplace_back(a.dptr, + b.dptr, + std::array{}, + d.dptr, + 1, + M, + N, + K, + stride_A, + stride_B, + std::array{}, + stride_E); + } + } + + return descs; + } + + bool run(const ck_tile::stream_config& stream_cfg, + const GroupedGemmRunContext& ctx) override { + auto descs = build_descs(ctx); + if (descs.empty()) { + return false; + } + return launch_grouped_gemm_kernel(descs, ctx, stream_cfg); + } +}; + +#define MAKE_RUNNER(BaseCfg_, kPadM_, kPadN_, kPadK_) \ + TRANSFORMER_ENGINE_SWITCH_CONDITION(ctx.accumulate, accum_option, { \ + using Runner = GroupedGemmRunner, \ + accum_option>; \ + runner = std::make_unique(); \ + }) + +// Per-layout dispatch function signature. +// Each layout file (NN, NT, TN, TT) implements one of these. +bool ck_tile_grouped_gemm_fp16_dispatch_nn(DType a_dtype, DType d_dtype, + bool need_m_pad, bool need_k_pad, + const GroupedGemmRunContext& ctx); +bool ck_tile_grouped_gemm_fp16_dispatch_nt(DType a_dtype, DType d_dtype, + bool need_m_pad, bool need_k_pad, + const GroupedGemmRunContext& ctx); +bool ck_tile_grouped_gemm_fp16_dispatch_tn(DType a_dtype, DType d_dtype, + bool need_m_pad, bool need_k_pad, + const GroupedGemmRunContext& ctx); +bool ck_tile_grouped_gemm_fp16_dispatch_tt(DType a_dtype, DType d_dtype, + bool need_m_pad, bool need_k_pad, + const GroupedGemmRunContext& ctx); + +} // namespace grouped_gemm +} // namespace transformer_engine diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_nn.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_nn.cpp new file mode 100644 index 000000000..963e19f1e --- /dev/null +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_nn.cpp @@ -0,0 +1,52 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + +#include "ck_grouped_gemm_fp16_impl.h" + +namespace transformer_engine { +namespace grouped_gemm { + +bool ck_tile_grouped_gemm_fp16_dispatch_nn(DType a_dtype, DType d_dtype, + bool need_m_pad, bool need_k_pad, + const GroupedGemmRunContext& ctx) { + const ck_tile::stream_config s{ctx.stream}; + std::unique_ptr runner = nullptr; + + using ALayout = RowMajor; + using BLayout = RowMajor; + + TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(a_dtype, a_te_type, { + using AType = typename TETypeToCKType::type; + using BType = typename TETypeToCKType::type; + using CLayout = RowMajor; + + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(d_dtype, d_te_type, { + using CType = typename TETypeToCKType::type; + + TRANSFORMER_ENGINE_SWITCH_CONDITION(need_m_pad, kPadM, { + TRANSFORMER_ENGINE_SWITCH_CONDITION(need_k_pad, kPadK, { + if (ctx.N % 256 == 0) { + MAKE_RUNNER(TileCfg_256x256x64, kPadM, false, kPadK); + } else if (ctx.N % 128 == 0) { + MAKE_RUNNER(TileCfg_256x128x64, kPadM, false, kPadK); + } else { + MAKE_RUNNER(TileCfg_256x128x64, kPadM, true, kPadK); + } + }); + }); + }); + }); + + if (!runner) { + return false; + } + return runner->run(s, ctx); +} + +#undef MAKE_RUNNER + +} // namespace grouped_gemm +} // namespace transformer_engine diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_nt.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_nt.cpp new file mode 100644 index 000000000..91b55a9bb --- /dev/null +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_nt.cpp @@ -0,0 +1,52 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + +#include "ck_grouped_gemm_fp16_impl.h" + +namespace transformer_engine { +namespace grouped_gemm { + +bool ck_tile_grouped_gemm_fp16_dispatch_nt(DType a_dtype, DType d_dtype, + bool need_m_pad, bool need_k_pad, + const GroupedGemmRunContext& ctx) { + const ck_tile::stream_config s{ctx.stream}; + std::unique_ptr runner = nullptr; + + using ALayout = RowMajor; + using BLayout = ColMajor; + + TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(a_dtype, a_te_type, { + using AType = typename TETypeToCKType::type; + using BType = typename TETypeToCKType::type; + using CLayout = RowMajor; + + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(d_dtype, d_te_type, { + using CType = typename TETypeToCKType::type; + + TRANSFORMER_ENGINE_SWITCH_CONDITION(need_m_pad, kPadM, { + TRANSFORMER_ENGINE_SWITCH_CONDITION(need_k_pad, kPadK, { + if (ctx.N % 256 == 0) { + MAKE_RUNNER(TileCfg_256x256x64, kPadM, false, kPadK); + } else if (ctx.N % 128 == 0) { + MAKE_RUNNER(TileCfg_256x128x64, kPadM, false, kPadK); + } else { + MAKE_RUNNER(TileCfg_256x128x64, kPadM, true, kPadK); + } + }); + }); + }); + }); + + if (!runner) { + return false; + } + return runner->run(s, ctx); +} + +#undef MAKE_RUNNER + +} // namespace grouped_gemm +} // namespace transformer_engine diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_tn.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_tn.cpp new file mode 100644 index 000000000..51671018f --- /dev/null +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_tn.cpp @@ -0,0 +1,52 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + +#include "ck_grouped_gemm_fp16_impl.h" + +namespace transformer_engine { +namespace grouped_gemm { + +bool ck_tile_grouped_gemm_fp16_dispatch_tn(DType a_dtype, DType d_dtype, + bool need_m_pad, bool need_k_pad, + const GroupedGemmRunContext& ctx) { + const ck_tile::stream_config s{ctx.stream}; + std::unique_ptr runner = nullptr; + + using ALayout = ColMajor; + using BLayout = RowMajor; + + TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(a_dtype, a_te_type, { + using AType = typename TETypeToCKType::type; + using BType = typename TETypeToCKType::type; + using CLayout = RowMajor; + + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(d_dtype, d_te_type, { + using CType = typename TETypeToCKType::type; + + TRANSFORMER_ENGINE_SWITCH_CONDITION(need_m_pad, kPadM, { + TRANSFORMER_ENGINE_SWITCH_CONDITION(need_k_pad, kPadK, { + if (ctx.N % 256 == 0) { + MAKE_RUNNER(TileCfg_256x256x64, kPadM, false, kPadK); + } else if (ctx.N % 128 == 0) { + MAKE_RUNNER(TileCfg_256x128x64, kPadM, false, kPadK); + } else { + MAKE_RUNNER(TileCfg_256x128x64, kPadM, true, kPadK); + } + }); + }); + }); + }); + + if (!runner) { + return false; + } + return runner->run(s, ctx); +} + +#undef MAKE_RUNNER + +} // namespace grouped_gemm +} // namespace transformer_engine diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_tt.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_tt.cpp new file mode 100644 index 000000000..1f558ebfc --- /dev/null +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_tt.cpp @@ -0,0 +1,52 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + +#include "ck_grouped_gemm_fp16_impl.h" + +namespace transformer_engine { +namespace grouped_gemm { + +bool ck_tile_grouped_gemm_fp16_dispatch_tt(DType a_dtype, DType d_dtype, + bool need_m_pad, bool need_k_pad, + const GroupedGemmRunContext& ctx) { + const ck_tile::stream_config s{ctx.stream}; + std::unique_ptr runner = nullptr; + + using ALayout = ColMajor; + using BLayout = ColMajor; + + TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(a_dtype, a_te_type, { + using AType = typename TETypeToCKType::type; + using BType = typename TETypeToCKType::type; + using CLayout = RowMajor; + + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(d_dtype, d_te_type, { + using CType = typename TETypeToCKType::type; + + TRANSFORMER_ENGINE_SWITCH_CONDITION(need_m_pad, kPadM, { + TRANSFORMER_ENGINE_SWITCH_CONDITION(need_k_pad, kPadK, { + if (ctx.N % 256 == 0) { + MAKE_RUNNER(TileCfg_256x256x64, kPadM, false, kPadK); + } else if (ctx.N % 128 == 0) { + MAKE_RUNNER(TileCfg_256x128x64, kPadM, false, kPadK); + } else { + MAKE_RUNNER(TileCfg_256x128x64, kPadM, true, kPadK); + } + }); + }); + }); + }); + + if (!runner) { + return false; + } + return runner->run(s, ctx); +} + +#undef MAKE_RUNNER + +} // namespace grouped_gemm +} // namespace transformer_engine