Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 160 additions & 0 deletions tests/pytorch/test_numerics.py
Original file line number Diff line number Diff line change
Expand Up @@ -3075,6 +3075,166 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass):
os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None)


@pytest.mark.skipif(
    torch.cuda.get_device_capability() != (9, 0) and not IS_HIP_EXTENSION,
    reason="Only enable CUTLASS/CK grouped gemm on Hopper or ROCm",
)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16], ids=str)
@pytest.mark.parametrize("layout", ["TN", "NN", "NT"])
@pytest.mark.parametrize("accumulate", [False, True])
@pytest.mark.parametrize(
    "pad_dim",
    ["K", "M", "N", "MK"],
    ids=lambda d: f"pad{d}",
)
def test_grouped_gemm_unaligned(dtype, layout, accumulate, pad_dim):
    """Test CK grouped GEMM with M, N, or K not aligned to CK tile size.

    The grouped result from ``general_grouped_gemm`` is compared against a
    reference built by calling ``general_gemm`` once per group.

    ``pad_dim`` names the dimension(s) that receive an unaligned size
    ("K", "M", "N", or "MK" for both M and K); every other dimension keeps
    its tile-aligned default.

    CK constraints for bf16/fp16:
    - Contiguous dim of A/B must be dword-aligned (even for 2-byte types).
      RowMajor: contiguous dim is cols (K for A, N for B).
      ColMajor: contiguous dim is rows (M for A, K for B).
    - K tile: 64, M tile: 256, N tile: 128/256
    """
    torch.manual_seed(0)
    z = 8  # number of groups

    # Unaligned values per dimension (all satisfy CK vector-load constraints).
    #   K: even but not multiple of tile (64). Same for all groups.
    #   M: not multiples of tile (256), varies per group.
    #   N: multiple of 16 but not multiple of tile (128).
    unaligned_k = 2026
    unaligned_m = [100, 300, 150, 200, 50, 350, 250, 180]
    unaligned_n = 2032

    # Aligned defaults.
    k_aligned = 2048
    m_aligned = 256
    n_aligned = 2048

    # The dimension choice depends only on pad_dim, so it is made once and
    # shared by all three layouts.
    k_val = unaligned_k if "K" in pad_dim else k_aligned
    m_vals = list(unaligned_m) if "M" in pad_dim else [m_aligned] * z
    n_val = unaligned_n if "N" in pad_dim else n_aligned
    total_m = sum(m_vals)
    m_splits = m_vals

    os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1"
    try:
        if layout == "TN":
            # TN GEMM: M=m_splits[i], N=A.rows, K=A.cols
            A = [torch.randn(n_val, k_val, dtype=dtype, device="cuda") for _ in range(z)]
            B = [torch.randn(m, k_val, dtype=dtype, device="cuda") for m in m_vals]
            out = [torch.randn(total_m, n_val, dtype=dtype, device="cuda")]
            out_ref = [o.clone() for o in torch.split(out[0], m_vals)]
            grad = False
            single_output = True
        elif layout == "NN":
            # NN GEMM: M=m_splits[i], N=A.cols, K=A.rows
            A = [torch.randn(k_val, n_val, dtype=dtype, device="cuda") for _ in range(z)]
            B = [torch.randn(m, k_val, dtype=dtype, device="cuda") for m in m_vals]
            out = [torch.randn(total_m, n_val, dtype=dtype, device="cuda")]
            out_ref = [o.clone() for o in torch.split(out[0], m_vals)]
            grad = True
            single_output = True
        else:  # NT
            # NT GEMM: out[i] = A[i]^T @ B[i],
            # A[i]: (m_i, k), B[i]: (m_i, n), out[i]: (n, k)
            A = list(
                torch.split(
                    torch.randn(total_m, k_val, dtype=dtype, device="cuda"), m_vals
                )
            )
            B = list(
                torch.split(
                    torch.randn(total_m, n_val, dtype=dtype, device="cuda"), m_vals
                )
            )
            out = [torch.randn(n_val, k_val, dtype=dtype, device="cuda") for _ in range(z)]
            out_ref = [o.clone() for o in out]
            grad = True
            single_output = False

        # Reference: individual GEMMs, one per group, written into the
        # pre-filled clones so that accumulate=True starts from the same data.
        for a, b, o_ref in zip(A, B, out_ref):
            general_gemm(
                a,
                b,
                dtype,
                grad=grad,
                accumulate=accumulate,
                layout=layout,
                out=o_ref,
            )
        if single_output:
            out_ref = [torch.cat(out_ref)]

        general_grouped_gemm(
            A,
            B,
            out,
            [None] * z,
            dtype,
            m_splits=m_splits,
            grad=grad,
            accumulate=accumulate,
            layout=layout,
            single_output=single_output,
        )

        # The tolerance depends only on the test configuration, so decide it
        # once instead of once per output tensor. The wider bound applies to
        # bf16 accumulation on ROCm devices reporting capability (9, 4).
        relaxed = (
            IS_HIP_EXTENSION
            and accumulate
            and dtype == torch.bfloat16
            and get_device_compute_capability() == (9, 4)
        )
        tol = 4e-2 if relaxed else 1.5e-2

        for o, o_ref in zip(out, out_ref):
            torch.testing.assert_close(o, o_ref, rtol=tol, atol=tol)
    finally:
        # Always clear the override, even when an assertion above fails, so
        # the backend selection does not leak into later tests.
        os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None)


@pytest.mark.parametrize("N", [32])
@pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize(
Expand Down
4 changes: 4 additions & 0 deletions transformer_engine/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,10 @@ else()
gemm/ck_grouped_gemm/ck_grouped_gemm.cpp
gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp
gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp
gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_nn.cpp
gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_nt.cpp
gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_tn.cpp
gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_tt.cpp
amd_detail/system.cpp)
list(APPEND transformer_engine_cuda_sources
fused_attn_rocm/fused_attn_aotriton.cpp
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,16 @@ static inline bool launch_grouped_gemm_kernel(const DescContainer& descs,

if (!Kernel::IsSupportedArgument(kargs)) {
NVTE_WARN("ck_tile_grouped_gemm: CK_Tile kernel arguments not supported for this config. "
"Falling back.");
"transA=", ctx.transA, " transB=", ctx.transB,
" accumulate=", ctx.accumulate, " groups=", ctx.group_num,
". Falling back. "
"CK_Tile constraints for bf16/fp16: "
"contiguous dim of A and B must be dword-aligned (even).");
for (size_t i = 0; i < descs.size(); ++i) {
NVTE_WARN(" group ", i, ": M=", descs[i].M, " N=", descs[i].N, " K=", descs[i].K,
" stride_A=", descs[i].stride_A, " stride_B=", descs[i].stride_B,
" stride_E=", descs[i].stride_E);
}
return false;
}

Expand Down
Loading
Loading