diff --git a/.gitignore b/.gitignore index dc825ff..b68a335 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,6 @@ build *.egg-info dist +*.so +traces __pycache__ \ No newline at end of file diff --git a/benchmark.py b/benchmark.py index dc97582..155ed6c 100644 --- a/benchmark.py +++ b/benchmark.py @@ -1,28 +1,203 @@ import torch import grouped_gemm as gg +def benchmark(name, func, x, w, batch_sizes, iterations=50, trans_a=False, trans_b=True): + print(f"Profiling {name}...") -if __name__ == '__main__': - # Mixtral 8x7B sizes. - M = 16384 - K = 4096 - N = 14336 - E = 8 - x = torch.rand(M, K, dtype=torch.bfloat16, device='cuda') - w = torch.rand(E, K, N, dtype=torch.bfloat16, device='cuda') - - x.requires_grad_(True) - w.requires_grad_(True) - - batch_sizes = torch.tensor([M//E]*E) + # warmup + for _ in range(10): + out = func(x, w, batch_sizes, trans_b=trans_b) + out.sum().backward() + x.grad = None + w.grad = None + torch.cuda.synchronize() - iterations = 50 with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof: for _ in range(iterations): - out = gg.ops.gmm(x, w, batch_sizes) - grad = out.sum().backward() + out = func(x, w, batch_sizes, trans_b=trans_b) + out.sum().backward() + x.grad = None + w.grad = None torch.cuda.synchronize() - device_time = prof.key_averages().total_average().device_time_total - print(f"Total gpu time: {device_time/1000:.2f} ms") - print(f"time per iteration: {device_time/iterations/1000:.2f} ms") + + total_ms = prof.key_averages().total_average().device_time_total / 1000 + avg_ms = total_ms / iterations + + # FLOPs + TFLOPs + fwd_flops = grouped_gemm_flops(x, w, batch_sizes, trans_a=trans_a, trans_b=trans_b) + avg_s = avg_ms * 1e-3 + fwd_tflops = fwd_flops / avg_s / 1e12 + + # Your timing includes backward. If backward does two GEMMs + # (dA and dB), total GEMM FLOPs ≈ 3x forward. + fwd_bwd_tflops_est = 3.0 * fwd_tflops + + print(f" -> Total GPU time: {total_ms:.2f} ms") + print(f" -> Time per step: {avg_ms:.3f} ms") + print(f" -> Forward TFLOPs: {fwd_tflops:.2f}") + print(f" -> Fwd+Bwd TFLOPs (est): {fwd_bwd_tflops_est:.2f}") + + prof.export_chrome_trace(f'traces/{name}_trace.json') + return avg_ms, fwd_tflops + +def grouped_gemm_flops(x, w, batch_sizes, trans_a=False, trans_b=True): + """ + Returns forward FLOPs for grouped GEMM. + x: (tokens, K) when trans_a=False, else packed (sum k_i, m) + w: + - trans_b=True: (E, N, K) + - trans_b=False: (E, K, N) + batch_sizes: (E,) giving m_i (if trans_a=False) or k_i (if trans_a=True) + """ + bs = batch_sizes.detach().to("cpu", non_blocking=True).tolist() + E = len(bs) + + if not trans_a: + # fixed K, variable m_i + K = x.shape[1] + N = w.shape[1] if trans_b else w.shape[2] + M_total = sum(bs) + flops = 2.0 * M_total * K * N + else: + # variable K_i, fixed m and n + # x is logically A_i: (K_i, m) because A is transposed + m = x.shape[1] + n = w.shape[1] # b is (tokens, n) in this mode + flops = 0.0 + for k_i in bs: + flops += 2.0 * m * n * k_i + + return flops + +def make_batch_sizes(M, E, mode="uniform", device="cpu"): + if mode == "uniform": + m = M // E + sizes = [m] * E + sizes[0] += M - m * E + return torch.tensor(sizes, dtype=torch.long, device=device) + + elif mode == "mild_skew": + alpha = torch.full((E,), 2.0) + probs = torch.distributions.Dirichlet(alpha).sample() + sizes = torch.floor(probs * M).long() + diff = M - int(sizes.sum().item()) + for _ in range(abs(diff)): + idx = torch.randint(0, E, ()) + sizes[idx] += 1 if diff > 0 else -1 + return sizes.to(device) + + elif mode == "extreme_skew": + hot_E = max(1, E // 8) + cold_E = E - hot_E + + hot_tokens = int(0.8 * M) + cold_tokens = M - hot_tokens + + hot_base = hot_tokens // hot_E + cold_base = cold_tokens // max(cold_E, 1) + + sizes = [] + for i in range(E): + if i < hot_E: + sizes.append(hot_base) + else: + sizes.append(cold_base) + + sizes = torch.tensor(sizes, dtype=torch.long) + diff = M - int(sizes.sum().item()) + for _ in range(abs(diff)): + idx = torch.randint(0, E, ()) + sizes[idx] += 1 if diff > 0 else -1 + return sizes.to(device) + +if __name__ == '__main__': + + model_config_dict = { + "Qwen/Qwen3-30B-A3B":{ + "num_experts_per_tok": 8, + "hidden_size": 2048, + "num_experts": 128, + "moe_intermediate_size": 768, + }, + } + + model_name = "Qwen/Qwen3-30B-A3B" + model_config = model_config_dict[model_name] + seqlen = 8192 + test_case = "up_proj" + + M = seqlen * model_config["num_experts_per_tok"] + E = model_config["num_experts"] + + if test_case == "up_proj": + K = model_config["hidden_size"] + N = 2 * model_config["moe_intermediate_size"] + elif test_case == "down_proj": + K = model_config["moe_intermediate_size"] + N = model_config["hidden_size"] + + print(f"Config: {test_case} | Tokens: {M} | Experts: {E} | Shape: K={K}, N={N}") + + torch.manual_seed(42) + + x = torch.rand(M, K, dtype=torch.bfloat16, device='cuda', requires_grad=True) + w = torch.rand(E, N, K, dtype=torch.bfloat16, device='cuda', requires_grad=True) + + modes = ["uniform", "mild_skew", "extreme_skew"] + + for mode in modes: + print("\n" + "=" * 30) + print(f"Workload mode: {mode}") + batch_sizes = make_batch_sizes(M, E, mode=mode, device='cpu') + print(f"Statistics: min={int(batch_sizes.min())}, " + f"max={int(batch_sizes.max())}, " + f"mean={batch_sizes.float().mean().item():.1f}, " + f"var={batch_sizes.float().var(unbiased=False).item():.1f}" + ) + print("=" * 30) + + x.grad = None + w.grad = None + + time_base, tflops_base = benchmark( + "cuBLAS (Base)", + gg.ops.gmm_base, + x, w, batch_sizes + ) + + print("-" * 30) + + time_cublas, tflops_cublas = benchmark( + "cuBLAS (Batched)", + gg.ops.gmm_cuBLAS, + x, w, batch_sizes + ) + + print("-" * 30) + + time_CUTLASS, tflops_CUTLASS = benchmark( + "CUTLASS sm80", + gg.ops.gmm_CUTLASS_sm80, + x, w, batch_sizes + ) + + print("-" * 30) + + time_CUTLASS_sm90_cooperative, tflops_CUTLASS_sm90_cooperative = benchmark( + "CUTLASS sm90 cooperative", + gg.ops.gmm_CUTLASS_sm90_cooperative, + x, w, batch_sizes + ) + + print("-" * 30) + + time_CUTLASS_sm90_pingpong, tflops_CUTLASS_sm90_pingpong = benchmark( + "CUTLASS sm90 pingpong", + gg.ops.gmm_CUTLASS_sm90_pingpong, + x, w, batch_sizes + ) + + print("=" * 30) + # print(f"Speedup: {time_base / time_cublas:.2f}x") + # print("=" * 30) \ No newline at end of file diff --git a/csrc/grouped_gemm.cu b/csrc/grouped_gemm.cu index e550877..d1a91d5 100644 --- a/csrc/grouped_gemm.cu +++ b/csrc/grouped_gemm.cu @@ -8,12 +8,6 @@ #include #include -#include "cutlass/bfloat16.h" -#include "cutlass/complex.h" -#include "cutlass/gemm/kernel/gemm_grouped.h" -#include "cutlass/gemm/kernel/default_gemm_grouped.h" -#include "cutlass/gemm/device/gemm_grouped.h" - #include namespace grouped_gemm { @@ -35,341 +29,6 @@ namespace grouped_gemm { #define GROUPED_GEMM_STRINGIFY(x) \ GROUPED_GEMM_STRINGIFY_HELPER(x) -template -using GroupedGemmInputLayout = std::conditional_t; - -using GroupedGemmConfig = ::cutlass::gemm::device::DefaultGemmConfiguration< - ::cutlass::arch::OpClassTensorOp, - ::cutlass::arch::Sm80, - ::cutlass::bfloat16_t, - ::cutlass::bfloat16_t, - ::cutlass::bfloat16_t, - float ->; - -// TODO(tgale): Update this for SM90 when it's supported by CUTLASS. -template -using GroupedGemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< - // A operand. - ::cutlass::bfloat16_t, - GroupedGemmInputLayout, - ::cutlass::ComplexTransform::kNone, - GroupedGemmConfig::kAlignmentA, - // B operand. - ::cutlass::bfloat16_t, - GroupedGemmInputLayout, - ::cutlass::ComplexTransform::kNone, - GroupedGemmConfig::kAlignmentB, - // C operand. - ::cutlass::bfloat16_t, - ::cutlass::layout::RowMajor, - float, - ::cutlass::arch::OpClassTensorOp, - ::cutlass::arch::Sm80, - GroupedGemmConfig::ThreadblockShape, - GroupedGemmConfig::WarpShape, - GroupedGemmConfig::InstructionShape, - GroupedGemmConfig::EpilogueOutputOp, - // NOTE: Threadblock swizzling is currently not supported by CUTLASS's grouped kernels. - // This parameter is passed in at present to match the APIs of other kernels. The parameter - // is unused within the kernel. - ::cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, - // TODO(tgale): Tune this for SM90. - GroupedGemmConfig::kStages>::GemmKernel; - -template -using GemmGrouped = ::cutlass::gemm::device::GemmGrouped>; - -template -torch::Tensor CopyToDevice(const std::vector &x, const torch::Device &device) { - size_t bytes = x.size() * sizeof(T); - auto options = torch::TensorOptions().dtype(torch::kInt8).device(device); - torch::Tensor out = torch::empty(bytes, options); - - CUDA_CALL(cudaMemcpyAsync(out.data_ptr(), - x.data(), bytes, - cudaMemcpyHostToDevice, - c10::cuda::getCurrentCUDAStream())); - return out; -} - -template -static void ReorderArray(T* data, const std::vector& indices) { - // For now, simply create a copy of the data and then copy over to the original. - std::vector copy(data, data + indices.size()); - for (size_t i = 0; i < indices.size(); ++i) { - data[i] = copy.at(indices[i]); - } -} - -template -torch::Tensor TypedEmpty(size_t numel, const torch::Device& device) { - return torch::empty(numel * sizeof(T), torch::dtype(torch::kInt8).device(device)); -} - -struct RawGemmArguments { - torch::Tensor lda, ldb, ldc, ptr_a, ptr_b, ptr_c, problem_sizes; - int threadblock_count{}; -}; - -template < - typename Gemm, - typename ElementA, typename ElementB, typename ElementC -> -RawGemmArguments MakeArgumentsOnDevice(int num_experts, const torch::Device& device) { - TORCH_CHECK( - num_experts <= kMaxExperts, - "At most ", kMaxExperts, - " experts are supported when batch_sizes is a CUDA tensor, but got ", num_experts - ); - - return RawGemmArguments { - .lda = TypedEmpty(num_experts, device), - .ldb = TypedEmpty(num_experts, device), - .ldc = TypedEmpty(num_experts, device), - .ptr_a = TypedEmpty(num_experts, device), - .ptr_b = TypedEmpty(num_experts, device), - .ptr_c = TypedEmpty(num_experts, device), - .problem_sizes = TypedEmpty(num_experts, device), - - // We don't know the problem dimensions on the host, so we just base the number of threadblocks on occupancy here. - .threadblock_count = Gemm::sufficient(), - }; -} - -template < - bool kDynamicK, - typename Gemm, - typename ElementA, typename ElementB, typename ElementC, - typename LayoutA, typename LayoutB, typename LayoutC -> -RawGemmArguments MakeArgumentsOnHost(torch::Tensor a, - torch::Tensor b, - torch::Tensor c, - torch::Tensor batch_sizes, - ::cutlass::gemm::GemmCoord coord_template, - int64_t num_experts) { - std::vector<::cutlass::gemm::GemmCoord> problem_sizes_host(num_experts); - - // Create the host arrays of leading dimension data and pointer data. - std::vector lda_host(num_experts), ldb_host(num_experts), ldc_host(num_experts); - int64_t elements_a = 0, elements_b = 0, elements_c = 0; - - std::vector ptr_a_host(num_experts), ptr_b_host(num_experts), ptr_c_host(num_experts); - - for (int i = 0; i < num_experts; ++i) { - auto& problem = problem_sizes_host[i]; - problem = coord_template; - (kDynamicK ? problem.k() : problem.m()) = batch_sizes.data_ptr()[i]; - - lda_host[i] = LayoutA::packed({problem.m(), problem.k()}).stride(0); - ldb_host[i] = LayoutB::packed({problem.k(), problem.n()}).stride(0); - ldc_host[i] = LayoutC::packed({problem.m(), problem.n()}).stride(0); - - ptr_a_host[i] = (ElementA*)a.data_ptr() + elements_a; - ptr_b_host[i] = (ElementB*)b.data_ptr() + elements_b; - ptr_c_host[i] = (ElementC*)c.data_ptr() + elements_c; - - elements_a += problem.m() * problem.k(); - elements_b += problem.k() * problem.n(); - elements_c += problem.m() * problem.n(); - - if (problem.k() == 0) { - // CUTLASS doesn't handle problems with `k=0` correctly, see https://github.com/NVIDIA/cutlass/pull/1593. - // Until a fix is available on the CUTLASS side, handle these problems by ourselves: - // * set the output to zero with `cudaMemsetAsync()` - // * make this problem a no-op by setting `m=0` and `n=0` (CUTLASS can handle the outer dimensions being zero) - CUDA_CALL(cudaMemsetAsync(ptr_c_host[i], - 0, - problem.m() * problem.n() * sizeof(ElementC), - c10::cuda::getCurrentCUDAStream())); - - problem.m() = 0; - problem.n() = 0; - } - } - - // Only sort problems when K are different - if (kDynamicK) { - std::vector indices(num_experts); - std::iota(indices.begin(), indices.end(), 0); - std::stable_sort(indices.begin(), indices.end(), [&problem_sizes_host](size_t i, size_t j) { - return problem_sizes_host[i].k() > problem_sizes_host[j].k(); - }); - - ReorderArray(problem_sizes_host.data(), indices); - ReorderArray(lda_host.data(), indices); - ReorderArray(ldb_host.data(), indices); - ReorderArray(ldc_host.data(), indices); - ReorderArray(ptr_a_host.data(), indices); - ReorderArray(ptr_b_host.data(), indices); - ReorderArray(ptr_c_host.data(), indices); - } - - // Copy the problem sizes, pointers and leading dimension data to the device. - return RawGemmArguments { - .lda = CopyToDevice(lda_host, a.device()), - .ldb = CopyToDevice(ldb_host, a.device()), - .ldc = CopyToDevice(ldc_host, a.device()), - .ptr_a = CopyToDevice(ptr_a_host, a.device()), - .ptr_b = CopyToDevice(ptr_b_host, a.device()), - .ptr_c = CopyToDevice(ptr_c_host, a.device()), - .problem_sizes = CopyToDevice(problem_sizes_host, a.device()), - - // We know the problem dimensions on the host, so we can calculate the number of threadblocks based on that. - .threadblock_count = Gemm::sufficient(problem_sizes_host.data(), num_experts), - }; -} - -template < - bool kDynamicK, - typename Gemm, - typename ElementA, typename ElementB, typename ElementC, - typename LayoutA, typename LayoutB, typename LayoutC -> -typename Gemm::Arguments MakeArguments(torch::Tensor a, - torch::Tensor b, - torch::Tensor c, - torch::Tensor batch_sizes, - ::cutlass::gemm::GemmCoord coord_template, - int64_t num_experts) { - RawGemmArguments raw_args; - if (batch_sizes.is_cuda()) { - raw_args = MakeArgumentsOnDevice< - Gemm, ElementA, ElementB, ElementC - >(num_experts, a.device()); - } else { - raw_args = MakeArgumentsOnHost< - kDynamicK, - Gemm, - ElementA, ElementB, ElementC, - LayoutA, LayoutB, LayoutC - >(a, b, c, batch_sizes, coord_template, num_experts); - } - - // Validate the result. - if (!raw_args.threadblock_count) { - TORCH_CHECK(false, "Grouped GEMM execution not possible with HW"); - } - - typename Gemm::EpilogueOutputOp::Params epilogue_op(/*alpha=*/1.0f, /*beta=*/0.0f); - // We currently always use `GroupScheduleMode::kDeviceOnly`, which doesn't use `host_problem_sizes` at all, - // so we can safely pass `nullptr` for `host_problem_sizes`. - // TODO(tgale): Experiment with `GroupScheduleMode::kHostPrecompute` for `batch_sizes.is_cpu()`, where we - // know the problem dimensions on the host. - typename Gemm::Arguments arguments((cutlass::gemm::GemmCoord*)raw_args.problem_sizes.data_ptr(), - (int)num_experts, - (int)raw_args.threadblock_count, - epilogue_op, - (ElementA**)raw_args.ptr_a.data_ptr(), - (ElementB**)raw_args.ptr_b.data_ptr(), - (ElementC**)raw_args.ptr_c.data_ptr(), - (ElementC**)raw_args.ptr_c.data_ptr(), - /*lda=*/(int64_t*)raw_args.lda.data_ptr(), - /*ldb=*/(int64_t*)raw_args.ldb.data_ptr(), - /*ldc=*/(int64_t*)raw_args.ldc.data_ptr(), - /*ldd=*/(int64_t*)raw_args.ldc.data_ptr(), - /*host_problem_sizes=*/nullptr); - return arguments; -} - -template < - bool trans_a, - typename ElementA, typename ElementB, typename ElementC, - typename LayoutA, typename LayoutB, typename LayoutC, - typename Arguments -> -void FillCutlassArguments(int num_experts, - torch::Tensor batch_sizes, - torch::Tensor a, - torch::Tensor b, - torch::Tensor c, - const Arguments& arguments, - ::cutlass::gemm::GemmCoord coord_template) { - // Convert the batch sizes to the format CUTLASS understands on the device. - // Use a single block here because: - // * the number of elements to process is microscopically small - // * we don't need any additional global memory - FillArguments< - /*kDynamicK*/trans_a, - ElementA, ElementB, ElementC, - LayoutA, LayoutB, LayoutC - ><<<1, kMaxExperts, 0, c10::cuda::getCurrentCUDAStream()>>>( - num_experts, batch_sizes.data_ptr(), - (ElementA*)a.data_ptr(), (ElementB*)b.data_ptr(), (ElementC*)c.data_ptr(), - arguments, coord_template - ); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -template -void RemoveK0Problems(int num_experts, const Args& arguments) { - // For zeroing out the outputs (which might be arbitrarily large), we want to use - // as many threadblocks as possible in order to hit the maximum possible global memory bandwidth. - // `arguments.threadblock_count`, which we will use for the grouped GEMM proper, - // should be a good approximation for this. - // When the `k=0` case is fixed in CUTLASS, we can completely remove this function. - ZeroOutK0Outputs<><<< - arguments.threadblock_count, at::cuda::detail::CUDA_NUM_THREADS, 0, c10::cuda::getCurrentCUDAStream() - >>>( - num_experts, arguments - ); - IgnoreK0Problems<><<< - 1, kMaxExperts, 0, c10::cuda::getCurrentCUDAStream() - >>>( - num_experts, arguments - ); -} - -template -torch::Tensor CutlassGroupedGemm(torch::Tensor a, - torch::Tensor b, - torch::Tensor c, - torch::Tensor batch_sizes, - ::cutlass::gemm::GemmCoord coord_template) { - using Gemm = GemmGrouped; - using LayoutA = typename Gemm::LayoutA; - using LayoutB = typename Gemm::LayoutB; - using LayoutC = typename Gemm::LayoutC; - - using ElementA = typename Gemm::ElementA; - using ElementB = typename Gemm::ElementB; - using ElementC = typename Gemm::ElementC; - - Gemm gemm; - int64_t num_experts = batch_sizes.size(0); - auto arguments = MakeArguments< - /*kDynamicK*/trans_a, - Gemm, - ElementA, ElementB, ElementC, - LayoutA, LayoutB, LayoutC - >(a, b, c, batch_sizes, coord_template, num_experts); - int64_t workspace_size = gemm.get_workspace_size(arguments); - auto options = torch::TensorOptions().dtype(torch::kInt8).device(a.device()); - torch::Tensor workspace = torch::empty(workspace_size, options); - - if (batch_sizes.is_cuda()) { - FillCutlassArguments< - trans_a, - ElementA, ElementB, ElementC, - LayoutA, LayoutB, LayoutC - >(num_experts, batch_sizes, a, b, c, arguments, coord_template); - - RemoveK0Problems<>(num_experts, arguments); - } - - // Initialize the kernel. - if(gemm.initialize(arguments, workspace.data_ptr()) != cutlass::Status::kSuccess) { - TORCH_CHECK(false, "Failed to initialize CUTLASS Grouped GEMM"); - } - - // Execute the kernel in the current stream. - if(gemm.run(c10::cuda::getCurrentCUDAStream()) != cutlass::Status::kSuccess) { - TORCH_CHECK(false, "Failed to run CUTLASS Grouped GEMM"); - } - return c; -} - void CublasGemm(c10::BFloat16 *a, int64_t a_rows, int64_t a_cols, bool trans_a, c10::BFloat16 *b, int64_t b_rows, int64_t b_cols, bool trans_b, c10::BFloat16 *c, int64_t c_rows, int64_t c_cols) { @@ -467,7 +126,7 @@ void GroupedGemmVariableK(torch::Tensor a, // assumed to be batched with fixed sized batches. // // TODO(tgale): Validate alignment is true for every batch element. -void GroupedGemm(torch::Tensor a, +void GroupedGemm_base(torch::Tensor a, torch::Tensor b, torch::Tensor c, torch::Tensor batch_sizes, @@ -475,13 +134,8 @@ void GroupedGemm(torch::Tensor a, // NOTE: We only support 'trans_a' or 'trans_b', not both. TORCH_CHECK(!(trans_a && trans_b)); -#if !defined(GROUPED_GEMM_CUTLASS) // No way to run cuBLAS kernels if the problem dimensions are not known on the host. TORCH_CHECK(batch_sizes.is_cpu()); -#else - // CUTLASS can handle both CPU- and CUDA-resident problem dimensions. - TORCH_CHECK(batch_sizes.is_cuda() || batch_sizes.is_cpu()); -#endif TORCH_CHECK(batch_sizes.ndimension() == 1); TORCH_CHECK(batch_sizes.scalar_type() == torch::kInt64); @@ -491,14 +145,10 @@ void GroupedGemm(torch::Tensor a, TORCH_CHECK(a.ndimension() == 2); TORCH_CHECK(a.scalar_type() == torch::kBFloat16); -#if !defined(GROUPED_GEMM_CUTLASS) if (trans_a) { - // If we can't use CUTLASS for the transposed cases, defer to the variable 'k' helper using cuBLAS - // for the rest of the op. GroupedGemmVariableK(a, b, c, batch_sizes); return; } -#endif TORCH_CHECK(b.is_cuda()); TORCH_CHECK(c.is_cuda()); @@ -539,28 +189,8 @@ void GroupedGemm(torch::Tensor a, TORCH_CHECK(b.is_contiguous()); TORCH_CHECK(c.is_contiguous()); -#if !defined(GROUPED_GEMM_CUTLASS) CublasGroupedGemm(a, b, c, batch_sizes, trans_b); return; -#else - // The `coord_template` argument contains `kDynamicDim` as one of its dimensions - // as a placeholder. This placeholder is later expanded into the actual dimension - // for every element of the batch, either on the host or on the device - // (if we can't do in on the host). - const auto coord_template = trans_a - ? cutlass::gemm::GemmCoord(hidden_in, hidden_out, kDynamicDim) - : cutlass::gemm::GemmCoord(kDynamicDim, hidden_out, hidden_in); - if (trans_a) { - CutlassGroupedGemm(a, b, c, batch_sizes, coord_template); - return; - } - if (trans_b) { - CutlassGroupedGemm(a, b, c, batch_sizes, coord_template); - return; - } - CutlassGroupedGemm(a, b, c, batch_sizes, coord_template); - return; -#endif } } // namespace grouped_gemm diff --git a/csrc/grouped_gemm.h b/csrc/grouped_gemm.h index ae36a62..a5d1e3f 100644 --- a/csrc/grouped_gemm.h +++ b/csrc/grouped_gemm.h @@ -2,7 +2,7 @@ namespace grouped_gemm { -void GroupedGemm(torch::Tensor a, +void GroupedGemm_base(torch::Tensor a, torch::Tensor b, torch::Tensor c, torch::Tensor batch_sizes, diff --git a/csrc/grouped_gemm_cublas.cu b/csrc/grouped_gemm_cublas.cu new file mode 100644 index 0000000..9f4ae88 --- /dev/null +++ b/csrc/grouped_gemm_cublas.cu @@ -0,0 +1,381 @@ +#include "grouped_gemm_cublas.h" + +#include +#include +#include +#include + +#include +#include + +#include +#include + +namespace grouped_gemm { + +#define CUDA_CALL(expr) \ + do { \ + cudaError_t _status = (expr); \ + TORCH_CHECK(_status == cudaSuccess, \ + "CUDA Error: ", cudaGetErrorString(_status)); \ + } while (0) + +#define CUBLAS_CALL(expr) \ + do { \ + cublasStatus_t _status = (expr); \ + TORCH_CHECK(_status == CUBLAS_STATUS_SUCCESS, \ + "cuBLAS Error: ", static_cast(_status)); \ + } while (0) + +static torch::Tensor make_device_pointer_array(const std::vector& host, + const torch::Device& device) { + auto options = torch::TensorOptions().dtype(torch::kInt64).device(device); + torch::Tensor t = torch::empty({static_cast(host.size())}, options); + + CUDA_CALL(cudaMemcpyAsync( + t.data_ptr(), + host.data(), + host.size() * sizeof(void*), + cudaMemcpyHostToDevice, + at::cuda::getCurrentCUDAStream())); + + return t; +} + +static void CublasGroupedGemm_FixedK(torch::Tensor a, + torch::Tensor b, + torch::Tensor c, + torch::Tensor batch_sizes, + bool trans_b) { + TORCH_CHECK(a.is_cuda() && b.is_cuda() && c.is_cuda(), + "All tensors must be CUDA"); + TORCH_CHECK(a.scalar_type() == torch::kBFloat16 && + b.scalar_type() == torch::kBFloat16 && + c.scalar_type() == torch::kBFloat16, + "a, b, c must be bfloat16"); + TORCH_CHECK(a.dim() == 2, "a must be 2D (tokens, hidden_in)"); + TORCH_CHECK(b.dim() == 3, "b must be 3D (num_experts, *, *)"); + TORCH_CHECK(c.dim() == 2, "c must be 2D (tokens, hidden_out)"); + TORCH_CHECK(a.is_contiguous() && b.is_contiguous() && c.is_contiguous(), + "a, b, c must be contiguous"); + + TORCH_CHECK(batch_sizes.dim() == 1, + "batch_sizes must be 1D (num_experts,)"); + TORCH_CHECK(batch_sizes.scalar_type() == torch::kInt64, + "batch_sizes must be int64"); + + const int64_t num_experts = batch_sizes.size(0); + const int64_t tokens = a.size(0); + const int64_t hidden_in = a.size(1); + + TORCH_CHECK(b.size(0) == num_experts, + "b.size(0) must equal num_experts"); + + const int64_t b_rows = b.size(1); + const int64_t b_cols = b.size(2); + + const int64_t hidden_out = trans_b ? b_rows : b_cols; + + TORCH_CHECK(hidden_in == (trans_b ? b_cols : b_rows), + "Incompatible shapes between a and b"); + TORCH_CHECK(c.size(0) == tokens && c.size(1) == hidden_out, + "c must have shape (tokens, hidden_out)"); + + auto batch_sizes_cpu = batch_sizes.to(torch::kCPU); + const int64_t* bs_ptr = batch_sizes_cpu.data_ptr(); + + int64_t tokens_sum = 0; + for (int64_t i = 0; i < num_experts; ++i) { + TORCH_CHECK(bs_ptr[i] >= 0, "batch_sizes must be non-negative"); + tokens_sum += bs_ptr[i]; + } + TORCH_CHECK(tokens_sum == tokens, + "Sum of batch_sizes must equal total tokens"); + + const int m_gemm = static_cast(trans_b ? b_rows : b_cols); + const int k_gemm = static_cast(trans_b ? b_cols : b_rows); + + TORCH_CHECK(m_gemm > 0 && k_gemm > 0, + "Invalid GEMM dimensions m or k"); + + const bool trans_a = false; + const int lda_val = k_gemm; + const int ldb_val = trans_b ? k_gemm : m_gemm; + const int ldc_val = static_cast(hidden_out); + + std::vector transa_array(num_experts); + std::vector transb_array(num_experts); + std::vector m_array(num_experts); + std::vector n_array(num_experts); + std::vector k_array(num_experts); + std::vector lda_array(num_experts); + std::vector ldb_array(num_experts); + std::vector ldc_array(num_experts); + std::vector group_size(num_experts, 1); + std::vector alpha_array(num_experts, 1.0f); + std::vector beta_array(num_experts, 0.0f); + + const cublasOperation_t opA = trans_b ? CUBLAS_OP_T : CUBLAS_OP_N; + const cublasOperation_t opB = CUBLAS_OP_N; + + for (int64_t i = 0; i < num_experts; ++i) { + const int64_t tokens_i = bs_ptr[i]; + TORCH_CHECK(tokens_i <= std::numeric_limits::max(), + "tokens per expert exceed INT32 range"); + + transa_array[i] = opA; + transb_array[i] = opB; + m_array[i] = m_gemm; + n_array[i] = static_cast(tokens_i); + k_array[i] = k_gemm; + lda_array[i] = ldb_val; + ldb_array[i] = lda_val; + ldc_array[i] = ldc_val; + } + + std::vector Aarray_host(num_experts); + std::vector Barray_host(num_experts); + std::vector Carray_host(num_experts); + + const auto device = a.device(); + const auto* a_ptr_base = a.data_ptr(); + const auto* b_ptr_base = b.data_ptr(); + auto* c_ptr_base = c.data_ptr(); + + const int64_t stride_b = b_rows * b_cols; + + int64_t token_offset = 0; + for (int64_t i = 0; i < num_experts; ++i) { + const int64_t tokens_i = bs_ptr[i]; + + const c10::BFloat16* b_i = b_ptr_base + i * stride_b; + const c10::BFloat16* a_i = a_ptr_base + token_offset * hidden_in; + c10::BFloat16* c_i = c_ptr_base + token_offset * hidden_out; + + Aarray_host[i] = const_cast(b_i); + Barray_host[i] = const_cast(a_i); + Carray_host[i] = c_i; + + token_offset += tokens_i; + } + + TORCH_CHECK(token_offset == tokens, + "Internal error: token_offset mismatch"); + + torch::Tensor Aarray_dev = make_device_pointer_array(Aarray_host, device); + torch::Tensor Barray_dev = make_device_pointer_array(Barray_host, device); + torch::Tensor Carray_dev = make_device_pointer_array(Carray_host, device); + + const void* const* d_Aarray = + reinterpret_cast(Aarray_dev.data_ptr()); + const void* const* d_Barray = + reinterpret_cast(Barray_dev.data_ptr()); + void* const* d_Carray = + reinterpret_cast(Carray_dev.data_ptr()); + + cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); + CUBLAS_CALL(cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_HOST)); + + const cudaDataType Atype = CUDA_R_16BF; + const cudaDataType Btype = CUDA_R_16BF; + const cudaDataType Ctype = CUDA_R_16BF; + const cublasComputeType_t computeType = CUBLAS_COMPUTE_32F; + + CUBLAS_CALL(cublasGemmGroupedBatchedEx( + handle, + transa_array.data(), + transb_array.data(), + m_array.data(), + n_array.data(), + k_array.data(), + alpha_array.data(), + d_Aarray, + Atype, + lda_array.data(), + d_Barray, + Btype, + ldb_array.data(), + beta_array.data(), + d_Carray, + Ctype, + ldc_array.data(), + static_cast(num_experts), + group_size.data(), + computeType)); +} + +void CublasGroupedGemm_VariableK(torch::Tensor a, + torch::Tensor b, + torch::Tensor c, + torch::Tensor batch_sizes) { + TORCH_CHECK(a.is_cuda() && b.is_cuda() && c.is_cuda(), + "All tensors must be CUDA"); + TORCH_CHECK(a.scalar_type() == torch::kBFloat16 && + b.scalar_type() == torch::kBFloat16 && + c.scalar_type() == torch::kBFloat16, + "a, b, c must be bfloat16"); + TORCH_CHECK(a.dim() == 2, "a must be 2D"); + TORCH_CHECK(b.dim() == 2, "b must be 2D"); + TORCH_CHECK(c.dim() == 3, "c must be 3D"); + TORCH_CHECK(a.is_contiguous() && b.is_contiguous() && c.is_contiguous(), + "a, b, c must be contiguous"); + + TORCH_CHECK(batch_sizes.dim() == 1, + "batch_sizes must be 1D (num_experts,)"); + TORCH_CHECK(batch_sizes.scalar_type() == torch::kInt64, + "batch_sizes must be int64"); + + const int64_t num_experts = batch_sizes.size(0); + const int64_t m = a.size(1); + const int64_t n = b.size(1); + + TORCH_CHECK(c.size(0) == num_experts && + c.size(1) == m && + c.size(2) == n, + "c must have shape (num_experts, m, n)"); + + auto batch_sizes_cpu = batch_sizes.to(torch::kCPU); + const int64_t* k_host = batch_sizes_cpu.data_ptr(); + + int64_t sum_k = 0; + for (int64_t i = 0; i < num_experts; ++i) { + TORCH_CHECK(k_host[i] >= 0, "batch_sizes must be non-negative"); + sum_k += k_host[i]; + } + TORCH_CHECK(a.size(0) == sum_k, + "a.size(0) must equal sum of batch_sizes"); + TORCH_CHECK(b.size(0) == sum_k, + "b.size(0) must equal sum of batch_sizes"); + + const int m_gemm = static_cast(n); + const int n_gemm = static_cast(m); + + TORCH_CHECK(m_gemm > 0 && n_gemm > 0, + "Invalid GEMM dimensions m or n"); + + std::vector transa_array(num_experts); + std::vector transb_array(num_experts); + std::vector m_array(num_experts); + std::vector n_array(num_experts); + std::vector k_array(num_experts); + std::vector lda_array(num_experts); + std::vector ldb_array(num_experts); + std::vector ldc_array(num_experts); + std::vector group_size(num_experts, 1); + std::vector alpha_array(num_experts, 1.0f); + std::vector beta_array(num_experts, 0.0f); + + const cublasOperation_t opA = CUBLAS_OP_N; + const cublasOperation_t opB = CUBLAS_OP_T; + + const int lda_val = n_gemm; + const int ldb_val = m_gemm; + const int ldc_val = static_cast(n); + + for (int64_t i = 0; i < num_experts; ++i) { + TORCH_CHECK(k_host[i] <= std::numeric_limits::max(), + "k_i exceeds INT32 range"); + + transa_array[i] = opA; + transb_array[i] = opB; + m_array[i] = m_gemm; + n_array[i] = n_gemm; + k_array[i] = static_cast(k_host[i]); + lda_array[i] = ldb_val; + ldb_array[i] = lda_val; + ldc_array[i] = ldc_val; + } + + std::vector Aarray_host(num_experts); + std::vector Barray_host(num_experts); + std::vector Carray_host(num_experts); + + const auto device = a.device(); + const auto* a_ptr_base = a.data_ptr(); + const auto* b_ptr_base = b.data_ptr(); + auto* c_ptr_base = c.data_ptr(); + + int64_t offset_a = 0; + int64_t offset_b = 0; + + const int64_t c_stride = m * n; + + for (int64_t i = 0; i < num_experts; ++i) { + const int64_t k_i = k_host[i]; + + const c10::BFloat16* a_i = a_ptr_base + offset_a; + const c10::BFloat16* b_i = b_ptr_base + offset_b; + c10::BFloat16* c_i = c_ptr_base + i * c_stride; + + Aarray_host[i] = const_cast(b_i); + Barray_host[i] = const_cast(a_i); + Carray_host[i] = c_i; + + offset_a += k_i * m; + offset_b += k_i * n; + } + + TORCH_CHECK(offset_a == a.size(0) * m, + "Internal error: offset_a mismatch"); + TORCH_CHECK(offset_b == b.size(0) * n, + "Internal error: offset_b mismatch"); + + torch::Tensor Aarray_dev = make_device_pointer_array(Aarray_host, device); + torch::Tensor Barray_dev = make_device_pointer_array(Barray_host, device); + torch::Tensor Carray_dev = make_device_pointer_array(Carray_host, device); + + const void* const* d_Aarray = + reinterpret_cast(Aarray_dev.data_ptr()); + const void* const* d_Barray = + reinterpret_cast(Barray_dev.data_ptr()); + void* const* d_Carray = + reinterpret_cast(Carray_dev.data_ptr()); + + cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); + CUBLAS_CALL(cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_HOST)); + + const cudaDataType Atype = CUDA_R_16BF; + const cudaDataType Btype = CUDA_R_16BF; + const cudaDataType Ctype = CUDA_R_16BF; + const cublasComputeType_t computeType = CUBLAS_COMPUTE_32F; + + CUBLAS_CALL(cublasGemmGroupedBatchedEx( + handle, + transa_array.data(), + transb_array.data(), + m_array.data(), + n_array.data(), + k_array.data(), + alpha_array.data(), + d_Aarray, + Atype, + lda_array.data(), + d_Barray, + Btype, + ldb_array.data(), + beta_array.data(), + d_Carray, + Ctype, + ldc_array.data(), + static_cast(num_experts), + group_size.data(), + computeType)); +} + +void GroupedGemm_cuBLAS(torch::Tensor a, + torch::Tensor b, + torch::Tensor c, + torch::Tensor batch_sizes, + bool trans_a, + bool trans_b) { + TORCH_CHECK(!(trans_a && trans_b), + "Only one of trans_a / trans_b may be true"); + + if (trans_a) { + CublasGroupedGemm_VariableK(a, b, c, batch_sizes); + } else { + CublasGroupedGemm_FixedK(a, b, c, batch_sizes, trans_b); + } +} + +} // namespace grouped_gemm diff --git a/csrc/grouped_gemm_cublas.h b/csrc/grouped_gemm_cublas.h new file mode 100644 index 0000000..848d4f3 --- /dev/null +++ b/csrc/grouped_gemm_cublas.h @@ -0,0 +1,11 @@ +#include + +namespace grouped_gemm { + +void GroupedGemm_cuBLAS(torch::Tensor a, + torch::Tensor b, + torch::Tensor c, + torch::Tensor batch_sizes, + bool trans_a, bool trans_b); + +} // namespace grouped_gemm diff --git a/csrc/grouped_gemm_cutlass_sm80.cu b/csrc/grouped_gemm_cutlass_sm80.cu new file mode 100644 index 0000000..d023259 --- /dev/null +++ b/csrc/grouped_gemm_cutlass_sm80.cu @@ -0,0 +1,454 @@ +#include "grouped_gemm_cutlass_sm80.h" +#include "fill_arguments.cuh" + +#include +#include +#include +#include +#include +#include + +#include "cutlass/bfloat16.h" +#include "cutlass/complex.h" +#include "cutlass/gemm/kernel/gemm_grouped.h" +#include "cutlass/gemm/kernel/default_gemm_grouped.h" +#include "cutlass/gemm/device/gemm_grouped.h" + +#include + +namespace grouped_gemm { + +#define CUDA_CALL(code) \ + do { \ + cudaError_t status = code; \ + std::string err = cudaGetErrorString(status); \ + TORCH_CHECK(status == cudaSuccess, err); \ + } while (0) + +#define CUBLAS_CALL(code) \ + do { \ + cublasStatus_t status = code; \ + TORCH_CHECK(status == CUBLAS_STATUS_SUCCESS, "CuBLAS Error"); \ + } while (0) + +#define GROUPED_GEMM_STRINGIFY_HELPER(x) #x +#define GROUPED_GEMM_STRINGIFY(x) \ + GROUPED_GEMM_STRINGIFY_HELPER(x) + +template +using GroupedGemmInputLayout = std::conditional_t; + +using GroupedGemmConfig = ::cutlass::gemm::device::DefaultGemmConfiguration< + ::cutlass::arch::OpClassTensorOp, + ::cutlass::arch::Sm80, + ::cutlass::bfloat16_t, + ::cutlass::bfloat16_t, + ::cutlass::bfloat16_t, + float +>; + +// TODO(tgale): Update this for SM90 when it's supported by CUTLASS. +template +using GroupedGemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< + // A operand. + ::cutlass::bfloat16_t, + GroupedGemmInputLayout, + ::cutlass::ComplexTransform::kNone, + GroupedGemmConfig::kAlignmentA, + // B operand. + ::cutlass::bfloat16_t, + GroupedGemmInputLayout, + ::cutlass::ComplexTransform::kNone, + GroupedGemmConfig::kAlignmentB, + // C operand. + ::cutlass::bfloat16_t, + ::cutlass::layout::RowMajor, + float, + ::cutlass::arch::OpClassTensorOp, + ::cutlass::arch::Sm80, + GroupedGemmConfig::ThreadblockShape, + GroupedGemmConfig::WarpShape, + GroupedGemmConfig::InstructionShape, + GroupedGemmConfig::EpilogueOutputOp, + // NOTE: Threadblock swizzling is currently not supported by CUTLASS's grouped kernels. + // This parameter is passed in at present to match the APIs of other kernels. The parameter + // is unused within the kernel. + ::cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, + // TODO(tgale): Tune this for SM90. + GroupedGemmConfig::kStages>::GemmKernel; + +template +using GemmGrouped = ::cutlass::gemm::device::GemmGrouped>; + +template +torch::Tensor CopyToDevice(const std::vector &x, const torch::Device &device) { + size_t bytes = x.size() * sizeof(T); + auto options = torch::TensorOptions().dtype(torch::kInt8).device(device); + torch::Tensor out = torch::empty(bytes, options); + + CUDA_CALL(cudaMemcpyAsync(out.data_ptr(), + x.data(), bytes, + cudaMemcpyHostToDevice, + c10::cuda::getCurrentCUDAStream())); + return out; +} + +template +static void ReorderArray(T* data, const std::vector& indices) { + // For now, simply create a copy of the data and then copy over to the original. + std::vector copy(data, data + indices.size()); + for (size_t i = 0; i < indices.size(); ++i) { + data[i] = copy.at(indices[i]); + } +} + +template +torch::Tensor TypedEmpty(size_t numel, const torch::Device& device) { + return torch::empty(numel * sizeof(T), torch::dtype(torch::kInt8).device(device)); +} + +struct RawGemmArguments { + torch::Tensor lda, ldb, ldc, ptr_a, ptr_b, ptr_c, problem_sizes; + int threadblock_count{}; +}; + +template < + typename Gemm, + typename ElementA, typename ElementB, typename ElementC +> +RawGemmArguments MakeArgumentsOnDevice(int num_experts, const torch::Device& device) { + TORCH_CHECK( + num_experts <= kMaxExperts, + "At most ", kMaxExperts, + " experts are supported when batch_sizes is a CUDA tensor, but got ", num_experts + ); + + return RawGemmArguments { + .lda = TypedEmpty(num_experts, device), + .ldb = TypedEmpty(num_experts, device), + .ldc = TypedEmpty(num_experts, device), + .ptr_a = TypedEmpty(num_experts, device), + .ptr_b = TypedEmpty(num_experts, device), + .ptr_c = TypedEmpty(num_experts, device), + .problem_sizes = TypedEmpty(num_experts, device), + + // We don't know the problem dimensions on the host, so we just base the number of threadblocks on occupancy here. + .threadblock_count = Gemm::sufficient(), + }; +} + +template < + bool kDynamicK, + typename Gemm, + typename ElementA, typename ElementB, typename ElementC, + typename LayoutA, typename LayoutB, typename LayoutC +> +RawGemmArguments MakeArgumentsOnHost(torch::Tensor a, + torch::Tensor b, + torch::Tensor c, + torch::Tensor batch_sizes, + ::cutlass::gemm::GemmCoord coord_template, + int64_t num_experts) { + std::vector<::cutlass::gemm::GemmCoord> problem_sizes_host(num_experts); + + // Create the host arrays of leading dimension data and pointer data. + std::vector lda_host(num_experts), ldb_host(num_experts), ldc_host(num_experts); + int64_t elements_a = 0, elements_b = 0, elements_c = 0; + + std::vector ptr_a_host(num_experts), ptr_b_host(num_experts), ptr_c_host(num_experts); + + for (int i = 0; i < num_experts; ++i) { + auto& problem = problem_sizes_host[i]; + problem = coord_template; + (kDynamicK ? problem.k() : problem.m()) = batch_sizes.data_ptr()[i]; + + lda_host[i] = LayoutA::packed({problem.m(), problem.k()}).stride(0); + ldb_host[i] = LayoutB::packed({problem.k(), problem.n()}).stride(0); + ldc_host[i] = LayoutC::packed({problem.m(), problem.n()}).stride(0); + + ptr_a_host[i] = (ElementA*)a.data_ptr() + elements_a; + ptr_b_host[i] = (ElementB*)b.data_ptr() + elements_b; + ptr_c_host[i] = (ElementC*)c.data_ptr() + elements_c; + + elements_a += problem.m() * problem.k(); + elements_b += problem.k() * problem.n(); + elements_c += problem.m() * problem.n(); + + if (problem.k() == 0) { + // CUTLASS doesn't handle problems with `k=0` correctly, see https://github.com/NVIDIA/cutlass/pull/1593. + // Until a fix is available on the CUTLASS side, handle these problems by ourselves: + // * set the output to zero with `cudaMemsetAsync()` + // * make this problem a no-op by setting `m=0` and `n=0` (CUTLASS can handle the outer dimensions being zero) + CUDA_CALL(cudaMemsetAsync(ptr_c_host[i], + 0, + problem.m() * problem.n() * sizeof(ElementC), + c10::cuda::getCurrentCUDAStream())); + + problem.m() = 0; + problem.n() = 0; + } + } + + // Only sort problems when K are different + if (kDynamicK) { + std::vector indices(num_experts); + std::iota(indices.begin(), indices.end(), 0); + std::stable_sort(indices.begin(), indices.end(), [&problem_sizes_host](size_t i, size_t j) { + return problem_sizes_host[i].k() > problem_sizes_host[j].k(); + }); + + ReorderArray(problem_sizes_host.data(), indices); + ReorderArray(lda_host.data(), indices); + ReorderArray(ldb_host.data(), indices); + ReorderArray(ldc_host.data(), indices); + ReorderArray(ptr_a_host.data(), indices); + ReorderArray(ptr_b_host.data(), indices); + ReorderArray(ptr_c_host.data(), indices); + } + + // Copy the problem sizes, pointers and leading dimension data to the device. + return RawGemmArguments { + .lda = CopyToDevice(lda_host, a.device()), + .ldb = CopyToDevice(ldb_host, a.device()), + .ldc = CopyToDevice(ldc_host, a.device()), + .ptr_a = CopyToDevice(ptr_a_host, a.device()), + .ptr_b = CopyToDevice(ptr_b_host, a.device()), + .ptr_c = CopyToDevice(ptr_c_host, a.device()), + .problem_sizes = CopyToDevice(problem_sizes_host, a.device()), + + // We know the problem dimensions on the host, so we can calculate the number of threadblocks based on that. + .threadblock_count = Gemm::sufficient(problem_sizes_host.data(), num_experts), + }; +} + +template < + bool kDynamicK, + typename Gemm, + typename ElementA, typename ElementB, typename ElementC, + typename LayoutA, typename LayoutB, typename LayoutC +> +typename Gemm::Arguments MakeArguments(torch::Tensor a, + torch::Tensor b, + torch::Tensor c, + torch::Tensor batch_sizes, + ::cutlass::gemm::GemmCoord coord_template, + int64_t num_experts) { + RawGemmArguments raw_args; + if (batch_sizes.is_cuda()) { + raw_args = MakeArgumentsOnDevice< + Gemm, ElementA, ElementB, ElementC + >(num_experts, a.device()); + } else { + raw_args = MakeArgumentsOnHost< + kDynamicK, + Gemm, + ElementA, ElementB, ElementC, + LayoutA, LayoutB, LayoutC + >(a, b, c, batch_sizes, coord_template, num_experts); + } + + // Validate the result. + if (!raw_args.threadblock_count) { + TORCH_CHECK(false, "Grouped GEMM execution not possible with HW"); + } + + typename Gemm::EpilogueOutputOp::Params epilogue_op(/*alpha=*/1.0f, /*beta=*/0.0f); + // We currently always use `GroupScheduleMode::kDeviceOnly`, which doesn't use `host_problem_sizes` at all, + // so we can safely pass `nullptr` for `host_problem_sizes`. + // TODO(tgale): Experiment with `GroupScheduleMode::kHostPrecompute` for `batch_sizes.is_cpu()`, where we + // know the problem dimensions on the host. + typename Gemm::Arguments arguments((cutlass::gemm::GemmCoord*)raw_args.problem_sizes.data_ptr(), + (int)num_experts, + (int)raw_args.threadblock_count, + epilogue_op, + (ElementA**)raw_args.ptr_a.data_ptr(), + (ElementB**)raw_args.ptr_b.data_ptr(), + (ElementC**)raw_args.ptr_c.data_ptr(), + (ElementC**)raw_args.ptr_c.data_ptr(), + /*lda=*/(int64_t*)raw_args.lda.data_ptr(), + /*ldb=*/(int64_t*)raw_args.ldb.data_ptr(), + /*ldc=*/(int64_t*)raw_args.ldc.data_ptr(), + /*ldd=*/(int64_t*)raw_args.ldc.data_ptr(), + /*host_problem_sizes=*/nullptr); + return arguments; +} + +template < + bool trans_a, + typename ElementA, typename ElementB, typename ElementC, + typename LayoutA, typename LayoutB, typename LayoutC, + typename Arguments +> +void FillCutlassArguments(int num_experts, + torch::Tensor batch_sizes, + torch::Tensor a, + torch::Tensor b, + torch::Tensor c, + const Arguments& arguments, + ::cutlass::gemm::GemmCoord coord_template) { + // Convert the batch sizes to the format CUTLASS understands on the device. + // Use a single block here because: + // * the number of elements to process is microscopically small + // * we don't need any additional global memory + FillArguments< + /*kDynamicK*/trans_a, + ElementA, ElementB, ElementC, + LayoutA, LayoutB, LayoutC + ><<<1, kMaxExperts, 0, c10::cuda::getCurrentCUDAStream()>>>( + num_experts, batch_sizes.data_ptr(), + (ElementA*)a.data_ptr(), (ElementB*)b.data_ptr(), (ElementC*)c.data_ptr(), + arguments, coord_template + ); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +void RemoveK0Problems(int num_experts, const Args& arguments) { + // For zeroing out the outputs (which might be arbitrarily large), we want to use + // as many threadblocks as possible in order to hit the maximum possible global memory bandwidth. + // `arguments.threadblock_count`, which we will use for the grouped GEMM proper, + // should be a good approximation for this. + // When the `k=0` case is fixed in CUTLASS, we can completely remove this function. + ZeroOutK0Outputs<><<< + arguments.threadblock_count, at::cuda::detail::CUDA_NUM_THREADS, 0, c10::cuda::getCurrentCUDAStream() + >>>( + num_experts, arguments + ); + IgnoreK0Problems<><<< + 1, kMaxExperts, 0, c10::cuda::getCurrentCUDAStream() + >>>( + num_experts, arguments + ); +} + +template +torch::Tensor CutlassGroupedGemm(torch::Tensor a, + torch::Tensor b, + torch::Tensor c, + torch::Tensor batch_sizes, + ::cutlass::gemm::GemmCoord coord_template) { + using Gemm = GemmGrouped; + using LayoutA = typename Gemm::LayoutA; + using LayoutB = typename Gemm::LayoutB; + using LayoutC = typename Gemm::LayoutC; + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementC = typename Gemm::ElementC; + + Gemm gemm; + int64_t num_experts = batch_sizes.size(0); + auto arguments = MakeArguments< + /*kDynamicK*/trans_a, + Gemm, + ElementA, ElementB, ElementC, + LayoutA, LayoutB, LayoutC + >(a, b, c, batch_sizes, coord_template, num_experts); + int64_t workspace_size = gemm.get_workspace_size(arguments); + auto options = torch::TensorOptions().dtype(torch::kInt8).device(a.device()); + torch::Tensor workspace = torch::empty(workspace_size, options); + + if (batch_sizes.is_cuda()) { + FillCutlassArguments< + trans_a, + ElementA, ElementB, ElementC, + LayoutA, LayoutB, LayoutC + >(num_experts, batch_sizes, a, b, c, arguments, coord_template); + + RemoveK0Problems<>(num_experts, arguments); + } + + // Initialize the kernel. + if(gemm.initialize(arguments, workspace.data_ptr()) != cutlass::Status::kSuccess) { + TORCH_CHECK(false, "Failed to initialize CUTLASS Grouped GEMM"); + } + + // Execute the kernel in the current stream. + if(gemm.run(c10::cuda::getCurrentCUDAStream()) != cutlass::Status::kSuccess) { + TORCH_CHECK(false, "Failed to run CUTLASS Grouped GEMM"); + } + return c; +} + +// NOTE: We only support dynamic group sizes for the 'a' tensor. Tensor 'b' is +// assumed to be batched with fixed sized batches. +// +// TODO(tgale): Validate alignment is true for every batch element. +void GroupedGemm_CUTLASS_sm80(torch::Tensor a, + torch::Tensor b, + torch::Tensor c, + torch::Tensor batch_sizes, + bool trans_a, bool trans_b) { + // NOTE: We only support 'trans_a' or 'trans_b', not both. + TORCH_CHECK(!(trans_a && trans_b)); + + // CUTLASS can handle both CPU- and CUDA-resident problem dimensions. + TORCH_CHECK(batch_sizes.is_cuda() || batch_sizes.is_cpu()); + TORCH_CHECK(batch_sizes.ndimension() == 1); + TORCH_CHECK(batch_sizes.scalar_type() == torch::kInt64); + + // We expected a CUDA tensor with two dimensions and shape + // (tokens, hidden_in) for 'a'. + TORCH_CHECK(a.is_cuda()); + TORCH_CHECK(a.ndimension() == 2); + TORCH_CHECK(a.scalar_type() == torch::kBFloat16); + + TORCH_CHECK(b.is_cuda()); + TORCH_CHECK(c.is_cuda()); + TORCH_CHECK(b.scalar_type() == torch::kBFloat16); + TORCH_CHECK(c.scalar_type() == torch::kBFloat16); + + // The expected shapes of 'b' and 'c' are: + // * when 'trans_a' is set: b=(tokens, hidden_out), c=(num_experts, hidden_in, hidden_out) + // * when 'trans_b' is set: b=(num_experts, hidden_out, hidden_in), c=(tokens, hidden_out) + // * otherwise: b=(num_experts, hidden_in, hidden_out), c=(tokens, hidden + size_t hidden_in{}, hidden_out{}; + if (trans_a) { + hidden_in = a.size(1); + hidden_out = b.size(1); + + TORCH_CHECK(b.ndimension() == 2); + TORCH_CHECK(c.ndimension() == 3); + TORCH_CHECK(b.size(0) == a.size(0)); + TORCH_CHECK(c.size(0) == batch_sizes.size(0)); + TORCH_CHECK(c.size(1) == hidden_in); + TORCH_CHECK(c.size(2) == hidden_out); + } else { + TORCH_CHECK(b.ndimension() == 3); + TORCH_CHECK(c.ndimension() == 2); + + // Validate the contraction dimensions match. + int64_t tokens = a.size(0), num_experts = b.size(0); + hidden_in = trans_b ? b.size(2) : b.size(1); + hidden_out = trans_b ? b.size(1) : b.size(2); + TORCH_CHECK(hidden_in == a.size(1)); + + // Validate that we have one size per expert. + TORCH_CHECK(batch_sizes.size(0) == num_experts); + } + + // NOTE: We support transposition through the 'trans_b' flag. + TORCH_CHECK(a.is_contiguous()); + TORCH_CHECK(b.is_contiguous()); + TORCH_CHECK(c.is_contiguous()); + + // The `coord_template` argument contains `kDynamicDim` as one of its dimensions + // as a placeholder. This placeholder is later expanded into the actual dimension + // for every element of the batch, either on the host or on the device + // (if we can't do in on the host). + const auto coord_template = trans_a + ? cutlass::gemm::GemmCoord(hidden_in, hidden_out, kDynamicDim) + : cutlass::gemm::GemmCoord(kDynamicDim, hidden_out, hidden_in); + if (trans_a) { + CutlassGroupedGemm(a, b, c, batch_sizes, coord_template); + return; + } + if (trans_b) { + CutlassGroupedGemm(a, b, c, batch_sizes, coord_template); + return; + } + CutlassGroupedGemm(a, b, c, batch_sizes, coord_template); + return; +} + +} // namespace grouped_gemm diff --git a/csrc/grouped_gemm_cutlass_sm80.h b/csrc/grouped_gemm_cutlass_sm80.h new file mode 100644 index 0000000..a48dc29 --- /dev/null +++ b/csrc/grouped_gemm_cutlass_sm80.h @@ -0,0 +1,11 @@ +#include + +namespace grouped_gemm { + +void GroupedGemm_CUTLASS_sm80(torch::Tensor a, + torch::Tensor b, + torch::Tensor c, + torch::Tensor batch_sizes, + bool trans_a, bool trans_b); + +} // namespace grouped_gemm diff --git a/csrc/grouped_gemm_cutlass_sm90.cu b/csrc/grouped_gemm_cutlass_sm90.cu new file mode 100644 index 0000000..4c01c37 --- /dev/null +++ b/csrc/grouped_gemm_cutlass_sm90.cu @@ -0,0 +1,521 @@ +/* + Adapted from the CUTLASS example on fp8 grouped gemm on hopper gpus: + https://github.com/NVIDIA/cutlass/blob/v4.0.0/examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm.cu +*/ + +#include +#include +#include +#include + +#include + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" + +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/device_memory.h" +#include "cutlass/layout/matrix.h" + +/** + * Panic wrapper for unwinding CUTLASS errors + */ +#define CUTLASS_CHECK(status) \ + { \ + cutlass::Status error = status; \ + if (error != cutlass::Status::kSuccess) { \ + std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) << " at: " << __LINE__ \ + << std::endl; \ + exit(EXIT_FAILURE); \ + } \ + } + +/** + * Panic wrapper for unwinding CUDA runtime errors + */ +#define CUDA_CHECK(status) \ + { \ + cudaError_t error = status; \ + if (error != cudaSuccess) { \ + std::cerr << "Got bad cuda status: " << cudaGetErrorString(error) \ + << " at line: " << __LINE__ << std::endl; \ + exit(EXIT_FAILURE); \ + } \ + } + +namespace grouped_gemm { + +using namespace cute; + +// Per-group GEMM problem shape: (M, N, K) +using ProblemShape = cutlass::gemm::GroupProblemShape>; + +// Element types +using ElementA = cutlass::bfloat16_t; +using ElementB = cutlass::bfloat16_t; +using ElementD = cutlass::bfloat16_t; +using ElementAccumulator = float; + +#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) + +// Layout helper: RowMajor vs ColumnMajor depending on transpose flag +template +using GroupedGemmInputLayout = std::conditional_t; +using LayoutD = cutlass::layout::RowMajor; + +// 16-byte alignment in elements +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + +// Arch / opclass +using ArchTag = cutlass::arch::Sm90; +using OperatorClass = cutlass::arch::OpClassTensorOp; + +// Tile configs +struct CooperativeConfig { + using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; + using TileShape = Shape<_128,_256,_64>; + using ClusterShape = Shape<_1,_2,_1>; +}; + +struct PingpongConfig { + using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; + using TileShape = Shape<_64,_256,_64>; + using ClusterShape = Shape<_1,_2,_1>; +}; + +template +struct GemmGivenSchedule { + using TileShape = typename ScheduleConfig::TileShape; + using ClusterShape = typename ScheduleConfig::ClusterShape; + using KernelSchedule = typename ScheduleConfig::KernelSchedule; + using EpilogueSchedule = typename ScheduleConfig::EpilogueSchedule; + + using LayoutA = GroupedGemmInputLayout; + using LayoutB = GroupedGemmInputLayout; + + // Collective epilogue: D = alpha * Acc + beta * D + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + TileShape, + ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, + ElementAccumulator, + void, LayoutD*, 1, // Bias (disabled via void) + ElementD, LayoutD*, AlignmentD, // D + EpilogueSchedule, + cutlass::epilogue::fusion::LinearCombination + >::CollectiveOp; + + // Collective mainloop + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA*, AlignmentA, + ElementB, LayoutB*, AlignmentB, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage)) + >, + KernelSchedule + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +}; + +// Concrete kernel types +using Gemm_NN_Cooperative = GemmGivenSchedule::Gemm; +using Gemm_NN_Pingpong = GemmGivenSchedule::Gemm; + +using Gemm_TN_Cooperative = GemmGivenSchedule::Gemm; +using Gemm_TN_Pingpong = GemmGivenSchedule::Gemm; + +using Gemm_NT_Cooperative = GemmGivenSchedule::Gemm; +using Gemm_NT_Pingpong = GemmGivenSchedule::Gemm; + + +// ----------------------------------------------------------------------------- +// Host-side argument preparation +// ----------------------------------------------------------------------------- + +template +struct ArgumentsPreparer { + using StrideA = typename GemmT::GemmKernel::InternalStrideA; + using StrideB = typename GemmT::GemmKernel::InternalStrideB; + using StrideD = typename GemmT::GemmKernel::InternalStrideD; + + // Host metadata + std::vector stride_A_host; + std::vector stride_B_host; + std::vector stride_D_host; + std::vector problem_sizes_host; + + std::vector ptr_A_host; + std::vector ptr_B_host; + std::vector ptr_D_host; + + int prepare_standard( + const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& c, + const torch::Tensor& batch_sizes + ) { + TORCH_CHECK(a.dim() == 2, "a must be [total_M, K]"); + TORCH_CHECK(b.dim() == 3, "b must be [groups, K, N] or [groups, N, K]"); + TORCH_CHECK(c.dim() == 2, "c must be [total_M, N]"); + + int64_t groups64 = batch_sizes.size(0); + int groups = static_cast(groups64); + + int64_t total_M = a.size(0); + int64_t K = a.size(1); + int64_t N = c.size(1); + + TORCH_CHECK(c.size(0) == total_M, + "c must have shape [total_M, N]"); + + auto* a_base = reinterpret_cast(a.data_ptr()); + auto* b_base = reinterpret_cast(b.data_ptr()); + auto* c_base = reinterpret_cast(c.data_ptr()); + + int64_t a_offset = 0; + int64_t c_offset = 0; + + stride_A_host.resize(groups); + stride_B_host.resize(groups); + stride_D_host.resize(groups); + problem_sizes_host.resize(groups); + ptr_A_host.resize(groups); + ptr_B_host.resize(groups); + ptr_D_host.resize(groups); + + auto* batch_sizes_ptr = batch_sizes.data_ptr(); + for (int g = 0; g < groups; ++g) { + int64_t M = batch_sizes_ptr[g]; + if (M == 0) { + ptr_A_host[g] = reinterpret_cast(a.data_ptr()); + ptr_B_host[g] = reinterpret_cast(b.data_ptr()); + ptr_D_host[g] = reinterpret_cast(c.data_ptr()); + + stride_A_host[g] = StrideA{}; + stride_B_host[g] = StrideB{}; + stride_D_host[g] = StrideD{}; + + problem_sizes_host[g] = {0, 0, 0}; + continue; + } + + ptr_A_host[g] = a_base + a_offset; + stride_A_host[g] = cutlass::make_cute_packed_stride(StrideA{}, {int(M), int(K), 1}); + + ptr_B_host[g] = b_base + int64_t(g) * K * N; + stride_B_host[g] = cutlass::make_cute_packed_stride(StrideB{}, {int(N), int(K), 1}); + + ptr_D_host[g] = c_base + c_offset; + stride_D_host[g] = cutlass::make_cute_packed_stride(StrideD{}, {int(M), int(N), 1}); + + problem_sizes_host[g] = {int(M), int(N), int(K)}; + + a_offset += M * K; + c_offset += M * N; + } + + return groups; + } + + int prepare_dynamicK( + const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& c, + const torch::Tensor& batch_sizes + ) { + TORCH_CHECK(a.dim() == 2, "For dynamicK, a must be [total_K, M]"); + TORCH_CHECK(b.dim() == 2, "For dynamicK, b must be [total_K, N]"); + TORCH_CHECK(c.dim() == 3, "For dynamicK, c must be [groups, M, N]"); + + int64_t groups64 = batch_sizes.size(0); + int groups = static_cast(groups64); + + int64_t total_K = a.size(0); + int64_t M = a.size(1); + int64_t N = b.size(1); + + TORCH_CHECK(c.size(0) == groups64 && + c.size(1) == M && + c.size(2) == N, + "c must be [groups, M, N]"); + + auto* a_base = reinterpret_cast(a.data_ptr()); + auto* b_base = reinterpret_cast(b.data_ptr()); + auto* c_base = reinterpret_cast(c.data_ptr()); + + int64_t K_offset = 0; + + stride_A_host.resize(groups); + stride_B_host.resize(groups); + stride_D_host.resize(groups); + problem_sizes_host.resize(groups); + ptr_A_host.resize(groups); + ptr_B_host.resize(groups); + ptr_D_host.resize(groups); + + auto* batch_sizes_ptr = batch_sizes.data_ptr(); + for (int g = 0; g < groups; ++g) { + int64_t K = batch_sizes_ptr[g]; + if (K == 0) { + ptr_A_host[g] = reinterpret_cast(a.data_ptr()); + ptr_B_host[g] = reinterpret_cast(b.data_ptr()); + ptr_D_host[g] = reinterpret_cast(c.data_ptr()); + + stride_A_host[g] = StrideA{}; + stride_B_host[g] = StrideB{}; + stride_D_host[g] = StrideD{}; + + problem_sizes_host[g] = {0, 0, 0}; + continue; + } + + ptr_A_host[g] = a_base + K_offset * M; + stride_A_host[g] = cutlass::make_cute_packed_stride(StrideA{}, {int(M), int(K), 1}); + + ptr_B_host[g] = b_base + K_offset * N; + stride_B_host[g] = cutlass::make_cute_packed_stride(StrideB{}, {int(N), int(K), 1}); + + ptr_D_host[g] = c_base + int64_t(g) * M * N; + stride_D_host[g] = cutlass::make_cute_packed_stride(StrideD{}, {int(M), int(N), 1}); + + problem_sizes_host[g] = {int(M), int(N), int(K)}; + + K_offset += K; + } + + TORCH_CHECK(K_offset == total_K, "Sum(batch_sizes) (", K_offset, ") must equal total_K (", total_K, ")"); + + return groups; + } +}; + +template +void run( + const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& c, + const torch::Tensor& batch_sizes, + int device_id +) { + GemmT gemm; + ArgumentsPreparer prep; + + // Prepare host metadata + int groups = 0; + if (TransA) { + groups = prep.prepare_dynamicK(a, b, c, batch_sizes); + } else { + groups = prep.prepare_standard(a, b, c, batch_sizes); + } + + // Device allocations for per-group metadata + cutlass::DeviceAllocation problem_sizes_dev; + cutlass::DeviceAllocation ptr_A_dev; + cutlass::DeviceAllocation ptr_B_dev; + cutlass::DeviceAllocation ptr_D_dev; + cutlass::DeviceAllocation stride_A_dev; + cutlass::DeviceAllocation stride_B_dev; + cutlass::DeviceAllocation stride_D_dev; + + problem_sizes_dev.reset(groups); + problem_sizes_dev.copy_from_host(prep.problem_sizes_host.data()); + + ptr_A_dev.reset(groups); + ptr_A_dev.copy_from_host(prep.ptr_A_host.data()); + + ptr_B_dev.reset(groups); + ptr_B_dev.copy_from_host(prep.ptr_B_host.data()); + + ptr_D_dev.reset(groups); + ptr_D_dev.copy_from_host(prep.ptr_D_host.data()); + + stride_A_dev.reset(groups); + stride_A_dev.copy_from_host(prep.stride_A_host.data()); + + stride_B_dev.reset(groups); + stride_B_dev.copy_from_host(prep.stride_B_host.data()); + + stride_D_dev.reset(groups); + stride_D_dev.copy_from_host(prep.stride_D_host.data()); + + // Hardware info + cutlass::KernelHardwareInfo kernel_hw_info = + cutlass::KernelHardwareInfo::make_kernel_hardware_info(device_id); + + // Build arguments + typename GemmT::Arguments arguments{}; + + // Value-init thread epilogue params, then override alpha/beta + decltype(arguments.epilogue.thread) fusion_args; + fusion_args.alpha = 1.0f; + fusion_args.beta = 0.0f; + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + fusion_args.alpha_ptr_array = nullptr; + fusion_args.beta_ptr_array = nullptr; + fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 0}; + fusion_args.dBeta = {cute::_0{}, cute::_0{}, 0}; + + arguments = typename GemmT::Arguments{ + cutlass::gemm::GemmUniversalMode::kGrouped, + {groups, problem_sizes_dev.get(), prep.problem_sizes_host.data()}, + {ptr_A_dev.get(), stride_A_dev.get(), ptr_B_dev.get(), stride_B_dev.get()}, + {fusion_args, nullptr, nullptr, ptr_D_dev.get(), stride_D_dev.get()}, + kernel_hw_info + }; + + // Workspace & run + size_t workspace_size = GemmT::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + CUTLASS_CHECK(gemm.can_implement(arguments)); + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + CUTLASS_CHECK(gemm.run()); +} + +#endif // CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED + +void GroupedGemm_CUTLASS_sm90( + torch::Tensor a, + torch::Tensor b, + torch::Tensor c, + torch::Tensor batch_sizes, + bool trans_a, + bool trans_b, + bool use_pingpong +) { +#if !defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) + std::cerr << "CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED not defined. " + "Compile with CUDA 12.3+ and SM90 targets.\n"; + return; +#else + // CUTLASS must be compiled with CUDA 12.3 Toolkit to run this function + if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 3)) { + std::cerr << "GroupedGemm_CUTLASS_sm90 requires CUDA 12.3 or newer.\n"; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return; + } + + // NOTE: We only support 'trans_a' or 'trans_b', not both. + TORCH_CHECK(!(trans_a && trans_b)); + + // CUTLASS can handle both CPU- and CUDA-resident problem dimensions. + torch::Tensor batch_sizes_cpu = batch_sizes; + if (!batch_sizes.is_cpu()) { + batch_sizes_cpu = batch_sizes.to(torch::kCPU); + } + batch_sizes_cpu = batch_sizes_cpu.contiguous(); + TORCH_CHECK(batch_sizes_cpu.ndimension() == 1); + TORCH_CHECK(batch_sizes_cpu.scalar_type() == torch::kInt64); + + // We expected a CUDA tensor with two dimensions and shape + // (tokens, hidden_in) for 'a'. + TORCH_CHECK(a.is_cuda()); + TORCH_CHECK(a.ndimension() == 2); + TORCH_CHECK(a.scalar_type() == torch::kBFloat16); + + TORCH_CHECK(b.is_cuda()); + TORCH_CHECK(c.is_cuda()); + TORCH_CHECK(b.scalar_type() == torch::kBFloat16); + TORCH_CHECK(c.scalar_type() == torch::kBFloat16); + + // The expected shapes of 'b' and 'c' are: + // * when 'trans_a' is set: b=(tokens, hidden_out), c=(num_experts, hidden_in, hidden_out) + // * when 'trans_b' is set: b=(num_experts, hidden_out, hidden_in), c=(tokens, hidden_out) + // * otherwise: b=(num_experts, hidden_in, hidden_out), c=(tokens, hidden + size_t hidden_in{}, hidden_out{}; + if (trans_a) { + hidden_in = a.size(1); + hidden_out = b.size(1); + + TORCH_CHECK(b.ndimension() == 2); + TORCH_CHECK(c.ndimension() == 3); + TORCH_CHECK(b.size(0) == a.size(0)); + TORCH_CHECK(c.size(0) == batch_sizes_cpu.size(0)); + TORCH_CHECK(c.size(1) == hidden_in); + TORCH_CHECK(c.size(2) == hidden_out); + + auto c_view = c.view({batch_sizes_cpu.size(0), (long)hidden_in, (long)hidden_out}); + + for (int64_t g = 0; g < batch_sizes_cpu.size(0); ++g) { + if (batch_sizes_cpu[g].item() == 0) { + c_view[g].zero_(); + } + } + } else { //trans_a == false + TORCH_CHECK(b.ndimension() == 3); + TORCH_CHECK(c.ndimension() == 2); + + // Validate the contraction dimensions match. + int64_t tokens = a.size(0), num_experts = b.size(0); + hidden_in = trans_b ? b.size(2) : b.size(1); + hidden_out = trans_b ? b.size(1) : b.size(2); + TORCH_CHECK(hidden_in == a.size(1)); + + // Validate that we have one size per expert. + TORCH_CHECK(batch_sizes_cpu.size(0) == num_experts); + } + + // NOTE: We support transposition through the 'trans_b' flag. + TORCH_CHECK(a.is_contiguous()); + TORCH_CHECK(b.is_contiguous()); + TORCH_CHECK(c.is_contiguous()); + + // Make sure all inputs are on the same device + int device_id = a.get_device(); + cudaDeviceProp props; + CUDA_CHECK(cudaGetDeviceProperties(&props, device_id)); + if (props.major != 9 || props.minor != 0) { + std::cerr + << "GroupedGemm_CUTLASS_sm90 requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n"; + return; + } + TORCH_CHECK(b.get_device() == device_id, "a and b must be on the same device"); + TORCH_CHECK(c.get_device() == device_id, "a and c must be on the same device"); + + // Dispatch on (trans_a, trans_b, use_pingpong) + if ((!use_pingpong) && (!trans_a) && (!trans_b)) { + run(a, b, c, batch_sizes_cpu, device_id); + } else if ((use_pingpong) && (!trans_a) && (!trans_b)) { + run(a, b, c, batch_sizes_cpu, device_id); + } else if ((!use_pingpong) && (trans_a) && (!trans_b)) { + run(a, b, c, batch_sizes_cpu, device_id); + } else if ((use_pingpong) && (trans_a) && (!trans_b)) { + run(a, b, c, batch_sizes_cpu, device_id); + } else if ((!use_pingpong) && (!trans_a) && (trans_b)) { + run(a, b, c, batch_sizes_cpu, device_id); + } else if ((use_pingpong) && (!trans_a) && (trans_b)) { + run(a, b, c, batch_sizes_cpu, device_id); + } else { + TORCH_CHECK(false, "GEMM option (trans_a, trans_b, use_pingpong) not supported."); + } +#endif // CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED +return; +} + +} // namespace grouped_gemm diff --git a/csrc/grouped_gemm_cutlass_sm90.h b/csrc/grouped_gemm_cutlass_sm90.h new file mode 100644 index 0000000..444f9dd --- /dev/null +++ b/csrc/grouped_gemm_cutlass_sm90.h @@ -0,0 +1,14 @@ +#include + +namespace grouped_gemm { + +void GroupedGemm_CUTLASS_sm90( + torch::Tensor a, + torch::Tensor b, + torch::Tensor c, + torch::Tensor batch_sizes, + bool trans_a, + bool trans_b, + bool use_pingpong); + +} // namespace grouped_gemm diff --git a/csrc/ops.cu b/csrc/ops.cu index 8ea2a25..4848860 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -1,11 +1,17 @@ #include "grouped_gemm.h" +#include "grouped_gemm_cublas.h" +#include "grouped_gemm_cutlass_sm80.h" +#include "grouped_gemm_cutlass_sm90.h" #include namespace grouped_gemm { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("gmm", &GroupedGemm, "Grouped GEMM."); + m.def("gmm_base", &GroupedGemm_base, "Grouped GEMM base."); + m.def("gmm_cuBLAS", &GroupedGemm_cuBLAS, "Grouped GEMM cuBLAS."); + m.def("gmm_CUTLASS_sm80", &GroupedGemm_CUTLASS_sm80, "Grouped GEMM CUTLASS for Ampere (sm80)."); + m.def("gmm_CUTLASS_sm90", &GroupedGemm_CUTLASS_sm90, "Grouped GEMM CUTLASS for Hopper (sm90)."); } } // namespace grouped_gemm diff --git a/grouped_gemm/backend.py b/grouped_gemm/backend.py index 32c99c1..afab498 100644 --- a/grouped_gemm/backend.py +++ b/grouped_gemm/backend.py @@ -21,9 +21,32 @@ def _allocate_output(a, b, batch_sizes, trans_a, trans_b): ) return torch.empty(*shape, device=a.device, dtype=a.dtype) -def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None): +def gmm_base(a, b, batch_sizes, trans_a=False, trans_b=False, c=None): if c is None: c = _allocate_output(a, b, batch_sizes, trans_a, trans_b) - backend.gmm(a, b, c, batch_sizes, trans_a, trans_b) + backend.gmm_base(a, b, c, batch_sizes, trans_a, trans_b) return c +def gmm_cuBLAS(a, b, batch_sizes, trans_a=False, trans_b=False, c=None): + if c is None: + c = _allocate_output(a, b, batch_sizes, trans_a, trans_b) + backend.gmm_cuBLAS(a, b, c, batch_sizes, trans_a, trans_b) + return c + +def gmm_CUTLASS_sm80(a, b, batch_sizes, trans_a=False, trans_b=False, c=None): + if c is None: + c = _allocate_output(a, b, batch_sizes, trans_a, trans_b) + backend.gmm_CUTLASS_sm80(a, b, c, batch_sizes, trans_a, trans_b) + return c + +def gmm_CUTLASS_sm90_cooperative(a, b, batch_sizes, trans_a=False, trans_b=False, c=None): + if c is None: + c = _allocate_output(a, b, batch_sizes, trans_a, trans_b) + backend.gmm_CUTLASS_sm90(a, b, c, batch_sizes, trans_a, trans_b, False) + return c + +def gmm_CUTLASS_sm90_pingpong(a, b, batch_sizes, trans_a=False, trans_b=False, c=None): + if c is None: + c = _allocate_output(a, b, batch_sizes, trans_a, trans_b) + backend.gmm_CUTLASS_sm90(a, b, c, batch_sizes, trans_a, trans_b, True) + return c \ No newline at end of file diff --git a/grouped_gemm/ops.py b/grouped_gemm/ops.py index 1442a84..d829cb6 100644 --- a/grouped_gemm/ops.py +++ b/grouped_gemm/ops.py @@ -2,32 +2,44 @@ import torch -class GroupedGemm(torch.autograd.Function): - +class GroupedGemmTemplate(torch.autograd.Function): @staticmethod - def forward(ctx, a, b, batch_sizes, trans_b): + def forward(ctx, a, b, batch_sizes, trans_b, gemm_op): ctx.save_for_backward(a, b, batch_sizes) ctx.trans_b = trans_b - return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b) + ctx.gemm_op = gemm_op + + return gemm_op(a, b, batch_sizes, trans_a=False, trans_b=trans_b) @staticmethod def backward(ctx, grad): grad = grad.contiguous() a, b, batch_sizes = ctx.saved_tensors trans_b = ctx.trans_b + gemm_op = ctx.gemm_op agrad = None if ctx.needs_input_grad[0]: - agrad = backend.gmm( - grad, b, batch_sizes, trans_a=False, trans_b=not trans_b) + agrad = gemm_op(grad, b, batch_sizes, trans_a=False, trans_b=not trans_b) bgrad = None if ctx.needs_input_grad[1]: lhs, rhs = (grad, a) if trans_b else (a, grad) - bgrad = backend.gmm( - lhs, rhs, batch_sizes, trans_a=True, trans_b=False) - return agrad, bgrad, None, None + bgrad = gemm_op(lhs, rhs, batch_sizes, trans_a=True, trans_b=False) + + return agrad, bgrad, None, None, None + +def gmm_base(a, b, batch_sizes, trans_b=False): + return GroupedGemmTemplate.apply(a, b, batch_sizes, trans_b, backend.gmm_base) + +def gmm_cuBLAS(a, b, batch_sizes, trans_b=False): + return GroupedGemmTemplate.apply(a, b, batch_sizes, trans_b, backend.gmm_cuBLAS) + +def gmm_CUTLASS_sm80(a, b, batch_sizes, trans_b=False): + return GroupedGemmTemplate.apply(a, b, batch_sizes, trans_b, backend.gmm_CUTLASS_sm80) +def gmm_CUTLASS_sm90_cooperative(a, b, batch_sizes, trans_b=False): + return GroupedGemmTemplate.apply(a, b, batch_sizes, trans_b, backend.gmm_CUTLASS_sm90_cooperative) -def gmm(a, b, batch_sizes, trans_b=False): - return GroupedGemm.apply(a, b, batch_sizes, trans_b) +def gmm_CUTLASS_sm90_pingpong(a, b, batch_sizes, trans_b=False): + return GroupedGemmTemplate.apply(a, b, batch_sizes, trans_b, backend.gmm_CUTLASS_sm90_pingpong) \ No newline at end of file diff --git a/grouped_gemm/ops_test.py b/grouped_gemm/ops_test.py index ceafdd9..ec49327 100644 --- a/grouped_gemm/ops_test.py +++ b/grouped_gemm/ops_test.py @@ -16,16 +16,15 @@ def allclose(x, y, pct=2.0): return False return True - def add_flags(x): out = [] for y in x: for trans_b in (False, True): for batch_sizes_on_device in (False, True): out.append(y + (trans_b, batch_sizes_on_device)) + # out.append(y + (trans_b, False)) return out - _TEST_PROBLEMS = add_flags(( (1, 128, 128, 128), (8, 128, 128, 128), @@ -69,7 +68,7 @@ def testGroupedGemm_FixedSizes(self, z, m, k, n, trans_b, batch_sizes_on_device) a_ref = a.detach().clone().requires_grad_(True) b_ref = b.detach().clone().requires_grad_(True) - out = ops.gmm(a, b, batch_sizes, trans_b) + out = ops.gmm_CUTLASS_sm90_pingpong(a, b, batch_sizes, trans_b=trans_b) expected_out = gmm(a_ref, b_ref, batch_sizes, trans_b) self.assertTrue(allclose(out, expected_out)) @@ -98,7 +97,7 @@ def testGroupedGemm_VariableSizes(self, z, m, k, n, trans_b, batch_sizes_on_devi a_ref = a.detach().clone().requires_grad_(True) b_ref = b.detach().clone().requires_grad_(True) - out = ops.gmm(a, b, batch_sizes, trans_b) + out = ops.gmm_CUTLASS_sm90_pingpong(a, b, batch_sizes, trans_b=trans_b) expected_out = gmm(a_ref, b_ref, batch_sizes, trans_b) self.assertTrue(allclose(out, expected_out)) @@ -109,10 +108,9 @@ def testGroupedGemm_VariableSizes(self, z, m, k, n, trans_b, batch_sizes_on_devi self.assertTrue(allclose(b.grad, b_ref.grad)) -@parameterized.parameters(False, True) class EdgeCasesTest(unittest.TestCase): - def testGroupedGemm_ZeroSize(self, batch_sizes_on_device): + def testGroupedGemm_ZeroSize(self, batch_sizes_on_device=False): torch.manual_seed(0) m = 16384 k = 4096 @@ -130,7 +128,7 @@ def testGroupedGemm_ZeroSize(self, batch_sizes_on_device): a_ref = a.detach().clone().requires_grad_(True) b_ref = b.detach().clone().requires_grad_(True) - out = ops.gmm(a, b, batch_sizes) + out = ops.gmm_CUTLASS_sm90_pingpong(a, b, batch_sizes) expected_out = gmm(a_ref, b_ref, batch_sizes) self.assertTrue(allclose(out, expected_out)) @@ -140,7 +138,7 @@ def testGroupedGemm_ZeroSize(self, batch_sizes_on_device): self.assertTrue(allclose(a.grad, a_ref.grad)) self.assertTrue(allclose(b.grad, b_ref.grad)) - def testGroupedGemm_ZeroK(self, batch_sizes_on_device): + def testGroupedGemm_ZeroK(self, batch_sizes_on_device=False): sz = 128 total_tokens = 192 @@ -151,7 +149,7 @@ def testGroupedGemm_ZeroK(self, batch_sizes_on_device): if batch_sizes_on_device: batch_sizes = batch_sizes.cuda() - ops.backend.gmm(a, b, batch_sizes, trans_a=True, c=c) + ops.backend.gmm_CUTLASS_sm90_pingpong(a, b, batch_sizes, trans_a=True, c=c) self.assertTrue((c[0] == 0).all()) self.assertTrue((c[1] == 128).all()) self.assertTrue((c[2] == 0).all()) diff --git a/setup.py b/setup.py index dc81c61..3fa0391 100644 --- a/setup.py +++ b/setup.py @@ -15,6 +15,8 @@ else: device_capability = torch.cuda.get_device_capability() device_capability = f"{device_capability[0]}{device_capability[1]}" + if device_capability == "90": + device_capability = f"{device_capability}a" cwd = Path(os.path.dirname(os.path.abspath(__file__))) @@ -33,9 +35,16 @@ ext_modules = [ CUDAExtension( "grouped_gemm_backend", - ["csrc/ops.cu", "csrc/grouped_gemm.cu"], + [ + "csrc/ops.cu", + "csrc/grouped_gemm.cu", + "csrc/grouped_gemm_cublas.cu", + "csrc/grouped_gemm_cutlass_sm80.cu", + "csrc/grouped_gemm_cutlass_sm90.cu" + ], include_dirs = [ f"{cwd}/third_party/cutlass/include/", + f"{cwd}/third_party/cutlass/tools/util/include/", f"{cwd}/csrc" ], extra_compile_args={ diff --git a/third_party/cutlass b/third_party/cutlass index 8783c41..b995f93 160000 --- a/third_party/cutlass +++ b/third_party/cutlass @@ -1 +1 @@ -Subproject commit 8783c41851cd3582490e04e69e0cd756a8c1db7f +Subproject commit b995f933179c22d3fe0d871c3a53d11e4681950f