From 3c89426637254a098b27b8bc3094b65c2fa66a43 Mon Sep 17 00:00:00 2001 From: vcherepanov-nv Date: Tue, 5 May 2026 21:59:28 -0700 Subject: [PATCH] [Common] Always define cuBLASMp comm GEMM API (#2963) * Always define cuBLASMp comm GEMM API Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Vladimir Cherepanov Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- transformer_engine/common/CMakeLists.txt | 6 +-- .../common/comm_gemm/comm_gemm.cpp | 47 ++++++++++++++++++- 2 files changed, 47 insertions(+), 6 deletions(-) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 781fe48814..734941595d 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -159,6 +159,7 @@ list(APPEND transformer_engine_cpp_sources util/cuda_runtime.cpp util/multi_stream.cpp util/rtc.cpp + comm_gemm/comm_gemm.cpp comm_gemm_overlap/userbuffers/ipcsocket.cc comm_gemm_overlap/userbuffers/userbuffers-host.cpp comm_gemm_overlap/comm_gemm_overlap.cpp @@ -280,11 +281,6 @@ foreach(cuda_source IN LISTS transformer_engine_cuda_arch_specific_sources) endif() endforeach() -if (NVTE_WITH_CUBLASMP) -list(APPEND transformer_engine_SOURCES - comm_gemm/comm_gemm.cpp) -endif() - add_library(transformer_engine SHARED ${transformer_engine_SOURCES}) # Disable CMake's automatic architecture flag injection. # All architectures are handled explicitly via per-source COMPILE_OPTIONS diff --git a/transformer_engine/common/comm_gemm/comm_gemm.cpp b/transformer_engine/common/comm_gemm/comm_gemm.cpp index a7d78f7ac0..ce389c2006 100644 --- a/transformer_engine/common/comm_gemm/comm_gemm.cpp +++ b/transformer_engine/common/comm_gemm/comm_gemm.cpp @@ -6,7 +6,6 @@ #include "transformer_engine/comm_gemm.h" -#include #include #include @@ -21,6 +20,10 @@ #include "../common.h" #include "../util/logging.h" +#ifdef NVTE_WITH_CUBLASMP + +#include + using namespace transformer_engine; namespace { @@ -530,3 +533,45 @@ int64_t nvte_comm_gemm_numroc(NVTECommGemmCtx* ctx, int64_t global_size) { NVTE_API_CALL(nvte_comm_gemm_numroc); return cublasMpNumroc(global_size, block_size(ctx, global_size), ctx->rank, 0, ctx->nranks); } + +#else // NVTE_WITH_CUBLASMP + +struct NVTECommGemmCtx {}; + +NVTECommGemmCtx* nvte_comm_gemm_ctx_create(ncclComm_t comm, int nranks, int rank) { + NVTE_ERROR("Transformer Engine has not been built with cuBLASMp support."); +} + +void nvte_comm_gemm_ctx_destroy(NVTECommGemmCtx* ctx) { + NVTE_ERROR("Transformer Engine has not been built with cuBLASMp support."); +} + +void nvte_all_gather_gemm(NVTECommGemmCtx* ctx, int64_t m, int64_t n, int64_t k, const NVTETensor a, + const NVTETensor b, const NVTETensor d, const NVTETensor bias, + const NVTETensor pre_act_out, bool transa, bool transb, bool grad, + bool accumulate, int comm_sm_count, cudaStream_t main_stream, + NVTECommGemmAlgoType algo) { + NVTE_ERROR("Transformer Engine has not been built with cuBLASMp support."); +} + +void nvte_gemm_reduce_scatter(NVTECommGemmCtx* ctx, int64_t m, int64_t n, int64_t k, + const NVTETensor a, const NVTETensor b, const NVTETensor d, + const NVTETensor bias, const NVTETensor pre_act_out, bool transa, + bool transb, bool grad, bool accumulate, int comm_sm_count, + cudaStream_t main_stream, NVTECommGemmAlgoType algo) { + NVTE_ERROR("Transformer Engine has not been built with cuBLASMp support."); +} + +void nvte_gemm_all_reduce(NVTECommGemmCtx* ctx, int64_t m, int64_t n, int64_t k, const NVTETensor a, + const NVTETensor b, const NVTETensor d, const NVTETensor bias, + const NVTETensor pre_act_out, bool transa, bool transb, bool grad, + bool accumulate, int comm_sm_count, cudaStream_t main_stream, + NVTECommGemmAlgoType algo) { + NVTE_ERROR("Transformer Engine has not been built with cuBLASMp support."); +} + +int64_t nvte_comm_gemm_numroc(NVTECommGemmCtx* ctx, int64_t global_size) { + NVTE_ERROR("Transformer Engine has not been built with cuBLASMp support."); +} + +#endif // NVTE_WITH_CUBLASMP