Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions transformer_engine/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ list(APPEND transformer_engine_cpp_sources
util/cuda_runtime.cpp
util/multi_stream.cpp
util/rtc.cpp
comm_gemm/comm_gemm.cpp
comm_gemm_overlap/userbuffers/ipcsocket.cc
comm_gemm_overlap/userbuffers/userbuffers-host.cpp
comm_gemm_overlap/comm_gemm_overlap.cpp
Expand Down Expand Up @@ -280,11 +281,6 @@ foreach(cuda_source IN LISTS transformer_engine_cuda_arch_specific_sources)
endif()
endforeach()

if (NVTE_WITH_CUBLASMP)
list(APPEND transformer_engine_SOURCES
comm_gemm/comm_gemm.cpp)
endif()

add_library(transformer_engine SHARED ${transformer_engine_SOURCES})
# Disable CMake's automatic architecture flag injection.
# All architectures are handled explicitly via per-source COMPILE_OPTIONS
Expand Down
47 changes: 46 additions & 1 deletion transformer_engine/common/comm_gemm/comm_gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

#include "transformer_engine/comm_gemm.h"

#include <cublasmp.h>
#include <cuda_runtime.h>

#include <map>
Expand All @@ -21,6 +20,10 @@
#include "../common.h"
#include "../util/logging.h"

#ifdef NVTE_WITH_CUBLASMP

#include <cublasmp.h>

using namespace transformer_engine;

namespace {
Expand Down Expand Up @@ -530,3 +533,45 @@ int64_t nvte_comm_gemm_numroc(NVTECommGemmCtx* ctx, int64_t global_size) {
NVTE_API_CALL(nvte_comm_gemm_numroc);
return cublasMpNumroc(global_size, block_size(ctx, global_size), ctx->rank, 0, ctx->nranks);
}

#else // NVTE_WITH_CUBLASMP

struct NVTECommGemmCtx {};

NVTECommGemmCtx* nvte_comm_gemm_ctx_create(ncclComm_t comm, int nranks, int rank) {
NVTE_ERROR("Transformer Engine has not been built with cuBLASMp support.");
}

void nvte_comm_gemm_ctx_destroy(NVTECommGemmCtx* ctx) {
NVTE_ERROR("Transformer Engine has not been built with cuBLASMp support.");
}

void nvte_all_gather_gemm(NVTECommGemmCtx* ctx, int64_t m, int64_t n, int64_t k, const NVTETensor a,
const NVTETensor b, const NVTETensor d, const NVTETensor bias,
const NVTETensor pre_act_out, bool transa, bool transb, bool grad,
bool accumulate, int comm_sm_count, cudaStream_t main_stream,
NVTECommGemmAlgoType algo) {
NVTE_ERROR("Transformer Engine has not been built with cuBLASMp support.");
}

void nvte_gemm_reduce_scatter(NVTECommGemmCtx* ctx, int64_t m, int64_t n, int64_t k,
const NVTETensor a, const NVTETensor b, const NVTETensor d,
const NVTETensor bias, const NVTETensor pre_act_out, bool transa,
bool transb, bool grad, bool accumulate, int comm_sm_count,
cudaStream_t main_stream, NVTECommGemmAlgoType algo) {
NVTE_ERROR("Transformer Engine has not been built with cuBLASMp support.");
}

void nvte_gemm_all_reduce(NVTECommGemmCtx* ctx, int64_t m, int64_t n, int64_t k, const NVTETensor a,
const NVTETensor b, const NVTETensor d, const NVTETensor bias,
const NVTETensor pre_act_out, bool transa, bool transb, bool grad,
bool accumulate, int comm_sm_count, cudaStream_t main_stream,
NVTECommGemmAlgoType algo) {
NVTE_ERROR("Transformer Engine has not been built with cuBLASMp support.");
}

int64_t nvte_comm_gemm_numroc(NVTECommGemmCtx* ctx, int64_t global_size) {
NVTE_ERROR("Transformer Engine has not been built with cuBLASMp support.");
}

#endif // NVTE_WITH_CUBLASMP
Loading