Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 17 additions & 10 deletions tests/pytorch/test_numerics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2999,7 +2999,15 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
@pytest.mark.parametrize("layout", ["TN", "NN", "NT"])
@pytest.mark.parametrize("accumulate", [False, True])
@pytest.mark.parametrize("use_cutlass", use_cutlass_grouped_gemm)
def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass):
@pytest.mark.parametrize("fp32_output", [False, True], ids=["out=input", "out=fp32"])
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make this `[False, True] if IS_HIP_EXTENSION else [False]`.

def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass, fp32_output, monkeypatch):
# Mixed-precision output (bf16/bf16/fp32, fp16/fp16/fp32) only goes
# through the CUTLASS / CK grouped GEMM path; the multi-stream cublasLt
# fallback requires A_dt == B_dt == D_dt, and accumulate is incompatible
# with the mixed-precision output path.
if fp32_output and (not use_cutlass or accumulate):
pytest.skip("fp32 output requires use_cutlass=True and accumulate=False")

torch.manual_seed(0)
z, m, k, n = shape

Expand All @@ -3008,10 +3016,12 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass):
assert m_splits.sum() == m and len(m_splits) == z
m_splits = m_splits.tolist()

out_dtype = torch.float32 if fp32_output else dtype

if layout == "TN":
A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight
B = list(torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits)) # input
out = [torch.randn(m, n, dtype=dtype, device="cuda")] # output
out = [torch.randn(m, n, dtype=out_dtype, device="cuda")] # output
out_ref = [o.clone() for o in torch.split(out[0], m_splits)]
grad = False
single_output = True
Expand All @@ -3020,7 +3030,7 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass):
B = list(
torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits)
) # grad_output
out = [torch.randn(m, k, dtype=dtype, device="cuda")] # dgrad
out = [torch.randn(m, k, dtype=out_dtype, device="cuda")] # dgrad
out_ref = [o.clone() for o in torch.split(out[0], m_splits)]
grad = True
single_output = True
Expand All @@ -3029,19 +3039,19 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass):
B = list(
torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits)
) # grad_output
out = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # wgrad
out = [torch.randn(n, k, dtype=out_dtype, device="cuda") for _ in range(z)] # wgrad
out_ref = [o.clone() for o in out]
grad = True
single_output = False

if use_cutlass:
os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1"
monkeypatch.setenv("NVTE_USE_CUTLASS_GROUPED_GEMM", "1")

for i in range(z):
general_gemm(
A[i],
B[i],
dtype,
out_dtype,
grad=grad,
accumulate=accumulate,
layout=layout,
Expand All @@ -3055,7 +3065,7 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass):
B,
out,
[None] * z,
dtype,
out_dtype,
m_splits=m_splits,
grad=grad,
accumulate=accumulate,
Expand All @@ -3074,9 +3084,6 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass):
else:
torch.testing.assert_close(o, o_ref, rtol=1.5e-2, atol=1.5e-2)

if use_cutlass:
os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None)


@pytest.mark.parametrize("N", [32])
@pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16])
Expand Down
4 changes: 2 additions & 2 deletions transformer_engine/common/gemm/cublaslt_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1167,8 +1167,8 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor
(is_fp8_dtype(A_dt) && is_fp8_dtype(B_dt))
) ||
(
(A_dt == B_dt) && (A_dt == D_dt) &&
(is_fp16_dtype(A_dt))
(A_dt == B_dt) && is_fp16_dtype(A_dt) &&
(is_fp16_dtype(D_dt) || D_dt == transformer_engine::DType::kFloat32)
);
#else
auto A_type = get_cuda_dtype(inputA->data.dtype);
Expand Down
Loading