diff --git a/include/sofieBLAS/backends/cuda/sofieBLAS_cublas.hpp b/include/sofieBLAS/backends/cuda/sofieBLAS_cublas.hpp
index cbf18e0..de36507 100644
--- a/include/sofieBLAS/backends/cuda/sofieBLAS_cublas.hpp
+++ b/include/sofieBLAS/backends/cuda/sofieBLAS_cublas.hpp
@@ -49,7 +49,6 @@ struct PairEq {
 
 class BlasCuda {
   cublasLtHandle_t ltHandle = nullptr;
-  cublasHandle_t handle = nullptr;
   cublasLtMatmulDesc_t operationDesc = nullptr;
   cublasLtMatmulPreference_t preference = nullptr;
   void *d_workspace = nullptr;
@@ -72,7 +71,6 @@ class BlasCuda {
   BlasCuda(alpaka::QueueCudaRtNonBlocking &queue) : m_queue{queue} {
     stream = static_cast<cudaStream_t>(m_queue.getNativeHandle());
     CHECK_CUBLAS(cublasLtCreate(&ltHandle));
-    CHECK_CUBLAS(cublasCreate(&handle));
     heuristic = {};
 
     CHECK_CUBLAS(cublasLtMatmulDescCreate(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F));
@@ -118,10 +116,10 @@ class BlasCuda {
     }
   }
 
-  void AddLayoutConfig(std::size_t m, std::size_t n, std::size_t k) {
-    CheckAndAddLayout(k, m);
-    CheckAndAddLayout(k, n);
-    CheckAndAddLayout(m, n);
+  void AddLayoutConfig(std::size_t m, std::size_t n, std::size_t k, std::size_t lda, std::size_t ldb, std::size_t ldc) {
+    CheckAndAddLayout(m, k, lda);
+    CheckAndAddLayout(k, n, ldb);
+    CheckAndAddLayout(m, n, ldc);
   }
 
   template <typename TIdx>
@@ -171,7 +169,6 @@ gemm(char transa, char transb, const unsigned int m,
         1,
         &localHeuristic,
         &returnedResults));
-
     if (returnedResults == 0) {
       cublasLtMatmulDescDestroy(localDesc);
       std::cerr << "No suitable cuBLASLt algorithm found!\n";
@@ -238,7 +235,8 @@ gemmrelu(char transa, char transb, const unsigned int m,
         1,
         &localHeuristic,
         &error_flag));
-
+    std::cout << "Requested workspace: "
+              << localHeuristic.workspaceSize << std::endl;
     if (error_flag == 0) {
       cublasLtMatmulDescDestroy(localDesc);
       std::cerr << "No suitable cuBLASLt algorithm found!\n";
@@ -310,14 +308,91 @@ gemmrelu(char transa, char transb, const unsigned int m,
         workspaceSize,
         stream));
   }
+
+  // matmul without bias
+  template <typename TIdx>
+  inline void
+  matmul(char transa, char transb, const unsigned int m,
+         const unsigned int n, const unsigned int k,
+         const float alpha,
+         alpaka::BufCudaRt<float, alpaka::DimInt<1u>, TIdx> const &A,
+         alpaka::BufCudaRt<float, alpaka::DimInt<1u>, TIdx> const &B,
+         const float beta,
+         alpaka::BufCudaRt<float, alpaka::DimInt<1u>, TIdx> &C)
+  {
+    matmul(transa, transb, m, n, k, alpha,
+           alpaka::getPtrNative(A),
+           alpaka::getPtrNative(B),
+           beta,
+           alpaka::getPtrNative(C));
+  }
+
+  inline void
+  matmul(char transa, char transb, const unsigned int m,
+         const unsigned int n, const unsigned int k,
+         const float alpha,
+         const float *A,
+         const float *B,
+         const float beta,
+         float *C)
+  {
+    cublasLtMatmulDesc_t localDesc = nullptr;
+    CHECK_CUBLAS(cublasLtMatmulDescCreate(&localDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F));
+
+    cublasOperation_t transB_op = charToCuBlasTranspose(transb);
+    CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
+        localDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transB_op, sizeof(transB_op)));
+
+    cublasOperation_t transA_op = charToCuBlasTranspose(transa);
+    CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
+        localDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transA_op, sizeof(transA_op)));
+
+    cublasLtMatmulHeuristicResult_t localHeuristic{};
+    int returnedResults = 0;
+    CHECK_CUBLAS(cublasLtMatmulAlgoGetHeuristic(
+        ltHandle,
+        localDesc,
+        LayoutStore.at({m, k}),
+        LayoutStore.at({k, n}),
+        LayoutStore.at({m, n}),
+        LayoutStore.at({m, n}),
+        preference,
+        1,
+        &localHeuristic,
+        &returnedResults));
+    if (returnedResults == 0) {
+      cublasLtMatmulDescDestroy(localDesc);
+      std::cerr << "No suitable cuBLASLt algorithm found!\n";
+      exit(EXIT_FAILURE);
+    }
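+
+    // Assumes the {m, k}, {k, n} and {m, n} layouts were registered up front
+    // via AddLayoutConfig(m, n, k, lda, ldb, ldc); the LayoutStore.at()
+    // lookups above and below throw std::out_of_range for shapes that were
+    // never registered.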
+
+    CHECK_CUBLAS(cublasLtMatmul(
+        ltHandle,
+        localDesc,
+        &alpha,
+        A, LayoutStore.at({m, k}),
+        B, LayoutStore.at({k, n}),
+        &beta,
+        C, LayoutStore.at({m, n}),
+        C, LayoutStore.at({m, n}),
+        &(localHeuristic.algo),
+        d_workspace,
+        workspaceSize,
+        stream));
+
+    cudaDeviceSynchronize();
+    CHECK_CUBLAS(cublasLtMatmulDescDestroy(localDesc));
+  }
+
 private:
   alpaka::QueueCudaRtNonBlocking m_queue;
 
-  void CheckAndAddLayout(size_t rows, size_t cols) {
+  void CheckAndAddLayout(size_t rows, size_t cols, size_t ld) {
     auto key = std::make_pair(rows, cols);
     if (LayoutStore.find(key) == LayoutStore.end()) {
       cublasLtMatrixLayout_t temp = nullptr;
-      size_t ld = rows;
       CHECK_CUBLAS(
           cublasLtMatrixLayoutCreate(&temp, CUDA_R_32F, rows, cols, ld));
       LayoutStore.emplace(key, temp);
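
For review context, a minimal caller-side sketch of the new entry points (not part of the patch). The names queue, dA, dB, dC, m, n, k are assumptions: an existing alpaka::QueueCudaRtNonBlocking and device pointers to m*k, k*n and m*n floats, stored column-major with no padding.

    // Construct the backend on an existing alpaka CUDA queue.
    BlasCuda blas{queue};

    // Register the {m,k}, {k,n} and {m,n} layouts once per shape; the
    // leading dimensions now come from the caller instead of being derived
    // inside CheckAndAddLayout (here: column-major, ld = row count).
    blas.AddLayoutConfig(m, n, k, /*lda=*/m, /*ldb=*/k, /*ldc=*/m);

    // C = alpha * op(A) * op(B) + beta * C via the new bias-free path.
    blas.matmul('n', 'n', m, n, k, 1.0f, dA, dB, 0.0f, dC);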