diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 4a768377e..46da76448 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -2999,7 +2999,15 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model): @pytest.mark.parametrize("layout", ["TN", "NN", "NT"]) @pytest.mark.parametrize("accumulate", [False, True]) @pytest.mark.parametrize("use_cutlass", use_cutlass_grouped_gemm) -def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass): +@pytest.mark.parametrize("fp32_output", [False, True], ids=["out=input", "out=fp32"]) +def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass, fp32_output, monkeypatch): + # Mixed-precision output (bf16/bf16/fp32, fp16/fp16/fp32) only goes + # through the CUTLASS / CK grouped GEMM path; the multi-stream cublasLt + # fallback requires A_dt == B_dt == D_dt, and accumulate is incompatible + # with the mixed-precision output path. + if fp32_output and (not use_cutlass or accumulate): + pytest.skip("fp32 output requires use_cutlass=True and accumulate=False") + torch.manual_seed(0) z, m, k, n = shape @@ -3008,10 +3016,12 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass): assert m_splits.sum() == m and len(m_splits) == z m_splits = m_splits.tolist() + out_dtype = torch.float32 if fp32_output else dtype + if layout == "TN": A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight B = list(torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits)) # input - out = [torch.randn(m, n, dtype=dtype, device="cuda")] # output + out = [torch.randn(m, n, dtype=out_dtype, device="cuda")] # output out_ref = [o.clone() for o in torch.split(out[0], m_splits)] grad = False single_output = True @@ -3020,7 +3030,7 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass): B = list( torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits) ) # grad_output - out = [torch.randn(m, k, dtype=dtype, device="cuda")] # dgrad + out = [torch.randn(m, k, dtype=out_dtype, device="cuda")] # dgrad out_ref = [o.clone() for o in torch.split(out[0], m_splits)] grad = True single_output = True @@ -3029,19 +3039,19 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass): B = list( torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits) ) # grad_output - out = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # wgrad + out = [torch.randn(n, k, dtype=out_dtype, device="cuda") for _ in range(z)] # wgrad out_ref = [o.clone() for o in out] grad = True single_output = False if use_cutlass: - os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1" + monkeypatch.setenv("NVTE_USE_CUTLASS_GROUPED_GEMM", "1") for i in range(z): general_gemm( A[i], B[i], - dtype, + out_dtype, grad=grad, accumulate=accumulate, layout=layout, @@ -3055,7 +3065,7 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass): B, out, [None] * z, - dtype, + out_dtype, m_splits=m_splits, grad=grad, accumulate=accumulate, @@ -3074,9 +3084,6 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass): else: torch.testing.assert_close(o, o_ref, rtol=1.5e-2, atol=1.5e-2) - if use_cutlass: - os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None) - @pytest.mark.parametrize("N", [32]) @pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16]) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 7326f330f..44735d9bb 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -1167,8 +1167,8 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor (is_fp8_dtype(A_dt) && is_fp8_dtype(B_dt)) ) || ( - (A_dt == B_dt) && (A_dt == D_dt) && - (is_fp16_dtype(A_dt)) + (A_dt == B_dt) && is_fp16_dtype(A_dt) && + (is_fp16_dtype(D_dt) || D_dt == transformer_engine::DType::kFloat32) ); #else auto A_type = get_cuda_dtype(inputA->data.dtype);