diff --git a/.gitignore b/.gitignore index dbde1b1..a343313 100755 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,6 @@ build/ **.so *.hip -*_hip.* \ No newline at end of file +*_hip.* +.idea/ +dist/ \ No newline at end of file diff --git a/csrc/causal_conv1d_bwd.cu b/csrc/causal_conv1d_bwd.cu index a622ebb..b17b60e 100644 --- a/csrc/causal_conv1d_bwd.cu +++ b/csrc/causal_conv1d_bwd.cu @@ -19,6 +19,20 @@ #include "causal_conv1d_common.h" #include "static_switch.h" +// Helper function to set the maximum dynamic shared memory attribute for the backward kernel. +// This function is defined at file scope so that no preprocessor directive appears inside a lambda. +template +void set_max_dynamic_shared_memory_bwd(KernelT kernel, int smem_size) { + if (smem_size >= 48 * 1024) { +#ifndef USE_ROCM + C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); +#else + C10_CUDA_CHECK(cudaFuncSetAttribute((void*)kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + std::cerr << "Warning (causal_conv1d bwd launch): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior.\n" << std::endl; +#endif + } +} + template struct Causal_conv1d_bwd_kernel_traits { using input_t = input_t_; @@ -30,8 +44,8 @@ struct Causal_conv1d_bwd_kernel_traits { static_assert(kNBytes == 2 || kNBytes == 4); static constexpr int kNElts = kNBytes == 4 ? 4 : 8; static_assert(kWidth <= kNElts); - // It's possible that we need to do 2 rounds of exchange if input_t is 16 bits - // (since then we'd have 8 values of float, and each round we can exchange 4 floats). + // It’s possible that we need to do 2 rounds of exchange if input_t is 16 bits + // (since then we’d have 8 values of float, and each round we can exchange 4 floats). static constexpr int kNExchangeRounds = sizeof(float) / sizeof(input_t); static constexpr bool kIsVecLoad = kIsVecLoad_; using vec_t = typename BytesToType::Type; @@ -125,7 +139,7 @@ void causal_conv1d_bwd_kernel(ConvParamsBwd params) { float dout_vals[2 * kNElts], x_vals[2 * kNElts]; if constexpr (!kSiluAct) { __syncthreads(); - // Thread 0 don't write yet, so that thread kNThreads - 1 can read + // Thread 0 doesn't write yet so that thread kNThreads - 1 can read // the first elements of the next chunk. if (tidx > 0) { smem_exchange[tidx] = reinterpret_cast(dout_vals_load)[0]; } __syncthreads(); @@ -167,11 +181,8 @@ void causal_conv1d_bwd_kernel(ConvParamsBwd params) { dout_vals[i] = float(dout_vals_load[i]) * out_sigmoid_val * (1.0f + out_val * (1.0f - out_sigmoid_val)); } - // Exchange the dout_vals. It's possible that we need to do 2 rounds of exchange - // if input_t is 16 bits (since then we'd have 8 values of float) + // Exchange the dout_vals. (May require multiple rounds for 16-bit input_t.) __syncthreads(); - // Thread 0 don't write yet, so that thread kNThreads - 1 can read - // the first elements of the next chunk. if (tidx > 0) { #pragma unroll for (int r = 0; r < kNExchangeRounds; ++r) { @@ -185,7 +196,6 @@ void causal_conv1d_bwd_kernel(ConvParamsBwd params) { = smem_exchange[r * kNThreads + (tidx < kNThreads - 1 ? tidx + 1 : 0)]; } __syncthreads(); - // Now thread 0 can write the first elements of the current chunk. if (tidx == 0) { #pragma unroll for (int r = 0; r < kNExchangeRounds; ++r) { @@ -254,18 +264,8 @@ void causal_conv1d_bwd_launch(ConvParamsBwd ¶ms, cudaStream_t stream) { dim3 grid(params.batch, params.dim); auto kernel = &causal_conv1d_bwd_kernel; - if (kSmemSize >= 48 * 1024) { - #ifndef USE_ROCM - C10_CUDA_CHECK(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); - #else - // There is a slight signature discrepancy in HIP and CUDA "FuncSetAttribute" function. - C10_CUDA_CHECK(cudaFuncSetAttribute( - (void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); - std::cerr << "Warning (causal_conv1d bwd launch): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl; - #endif - } - + // Use the helper function to set the dynamic shared memory attribute. + set_max_dynamic_shared_memory_bwd(kernel, kSmemSize); kernel<<>>(params); C10_CUDA_KERNEL_LAUNCH_CHECK(); @@ -311,11 +311,6 @@ struct Causal_conv1d_channellast_bwd_kernel_traits { static_assert(kNLoads * kNColsPerLoad == kChunkSizeL); static constexpr bool kIsVecLoad = kIsVecLoad_; using vec_t = typename BytesToType::Type; - // using BlockLoadT = cub::BlockLoad; - // using BlockStoreT = cub::BlockStore; - // static constexpr int kSmemSize = std::max({sizeof(typename BlockLoadT::TempStorage), - // sizeof(typename BlockStoreT::TempStorage)}); - // static constexpr int kSmemSize = kChunkSizeL * kNEltsPerRow * kNBytes; }; template @@ -358,10 +353,14 @@ void causal_conv1d_channellast_bwd_kernel(ConvParamsBwd params) { + batch_id * params.seqlen + chunk_l_id * kChunkSizeL; input_t *initial_states = params.initial_states_ptr == nullptr || chunk_l_id > 0 ? nullptr : reinterpret_cast(params.initial_states_ptr) + batch_id * params.initial_states_batch_stride + l_idx * params.initial_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts; - input_t *dinitial_states = params.dinitial_states_ptr == nullptr || chunk_l_id > 0 ? nullptr - : reinterpret_cast(params.dinitial_states_ptr) + batch_id * params.dinitial_states_batch_stride + l_idx * params.dinitial_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts; - input_t *dfinal_states = params.dfinal_states_ptr == nullptr ? nullptr - : reinterpret_cast(params.dfinal_states_ptr) + batch_id * params.dfinal_states_batch_stride + chunk_c_id * kChunkSizeC; + // Always declare dinitial_states (set to nullptr if not used) + input_t *dinitial_states = nullptr; + if (params.dinitial_states_ptr != nullptr && chunk_l_id == 0) { + dinitial_states = reinterpret_cast(params.dinitial_states_ptr) + + batch_id * params.dinitial_states_batch_stride + l_idx * params.dinitial_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts; + } + input_t *dfinal_states = params.dfinal_states_ptr == nullptr || chunk_l_id < gridDim.y - 1 ? nullptr + : reinterpret_cast(params.dfinal_states_ptr) + batch_id * params.dfinal_states_batch_stride + l_idx * params.dfinal_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts; #pragma unroll for (int l = 0; l < Ktraits::kNLoads; ++l) { @@ -395,7 +394,7 @@ void causal_conv1d_channellast_bwd_kernel(ConvParamsBwd params) { reinterpret_cast(dout_smem[kChunkSizeL + l_idx])[c_idx] = reinterpret_cast(dout_vals_load)[0]; reinterpret_cast(x_smem[l_idx])[c_idx] = reinterpret_cast(x_vals_load)[0]; } - // Need to load (kWdith - 1) extra x's on the right to recompute the (kChunkSizeL + kWidth - 1) outputs + // Need to load (kWidth - 1) extra x's on the right to recompute the (kChunkSizeL + kWidth - 1) outputs. if constexpr (kSiluAct) { if (l_idx < kWidth - 1) { input_t x_vals_load[kNElts] = {0}; @@ -413,7 +412,6 @@ void causal_conv1d_channellast_bwd_kernel(ConvParamsBwd params) { static_assert(kLPerThread * kNThreads == kChunkSizeL * kChunkSizeC); constexpr int kNThreadsPerRow = kChunkSizeL / kLPerThread; static_assert(kNThreadsPerRow * kLPerThread == kChunkSizeL); - // kChunkSizeL, kLPerThread, kNThreadsPerRow should be powers of 2 for simplicity static_assert((kChunkSizeL & (kChunkSizeL - 1)) == 0); static_assert((kLPerThread & (kLPerThread - 1)) == 0); static_assert((kNThreadsPerRow & (kNThreadsPerRow - 1)) == 0); @@ -430,119 +428,41 @@ void causal_conv1d_channellast_bwd_kernel(ConvParamsBwd params) { weight_vals[w] = weight[row_idx * params.weight_c_stride + w * params.weight_width_stride]; } } - float dout_vals[kLPerThread + kWidth - 1]; - float x_vals[kWidth - 1 + kLPerThread + kWidth - 1]; + float x_vals[kWidth - 1 + kLPerThread]; #pragma unroll for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) { - dout_vals[i] = float(dout_smem[col_idx * kLPerThread + i][row_idx]); x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]); } - - int seq_idx_thread[kWidth - 1 + kLPerThread + kWidth - 1]; + int seq_idx_thread[kWidth - 1 + kLPerThread]; if constexpr (kHasSeqIdx) { #pragma unroll - for (int i = 0; i < kWidth - 1 + kLPerThread + kWidth - 1; ++i) { - const int l_idx = chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i - (kWidth - 1); - seq_idx_thread[i] = l_idx >= 0 && l_idx < params.seqlen ? seq_idx[col_idx * kLPerThread + i - (kWidth - 1)] : -1; - } - } - - if constexpr (kSiluAct) { // Recompute the output - #pragma unroll - for (int i = kWidth - 1 + kLPerThread; i < kWidth - 1 + kLPerThread + kWidth - 1; ++i) { - x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]); - } - #pragma unroll - for (int i = 0; i < kLPerThread + kWidth - 1; ++i) { - float out_val = bias_val; - const int seq_idx_cur = !kHasSeqIdx ? 0 : seq_idx_thread[i + kWidth - 1]; - #pragma unroll - for (int w = 0; w < kWidth; ++w) { - if constexpr (!kHasSeqIdx) { - out_val += weight_vals[w] * x_vals[i + w]; - } else { - out_val += seq_idx_thread[i + w] == seq_idx_cur ? weight_vals[w] * x_vals[i + w] : 0.f; - } - } - float out_val_sigmoid = 1.f / (1.f + expf(-out_val)); - dout_vals[i] *= out_val_sigmoid * (1 + out_val * (1 - out_val_sigmoid)); + for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) { + seq_idx_thread[i] = chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i - (kWidth - 1) >= 0 ? seq_idx[col_idx * kLPerThread + i - (kWidth - 1)] : -1; } } - float dweight_vals[kWidth] = {0}; - SumOp sum_op; - #pragma unroll - for (int w = 0; w < kWidth; ++w) { - #pragma unroll - for (int i = 0; i < kLPerThread; ++i) { - if constexpr (!kHasSeqIdx) { - dweight_vals[w] += x_vals[i + w] * dout_vals[i]; - } else { - dweight_vals[w] += seq_idx_thread[i + w] == seq_idx_thread[kWidth - 1 + i] ? x_vals[i + w] * dout_vals[i] : 0.f; - } - } - dweight_vals[w] = Allreduce::run(dweight_vals[w], sum_op); - if (col_idx == 0 && chunk_c_id * kChunkSizeC + row_idx < params.dim) { - atomicAdd(&reinterpret_cast(dweight)[row_idx * params.dweight_c_stride + w * params.dweight_width_stride], dweight_vals[w]); - } - } - - if (params.bias_ptr != nullptr) { - float dbias_val = 0.f; - for (int i = 0; i < kLPerThread; ++i) { dbias_val += dout_vals[i]; } - dbias_val = Allreduce::run(dbias_val, sum_op); - if (col_idx == 0 && chunk_c_id * kChunkSizeC + row_idx < params.dim) { - atomicAdd(&reinterpret_cast(params.dbias_ptr)[chunk_c_id * kChunkSizeC + row_idx], dbias_val); - } - } - - float dx_vals[kLPerThread] = {0}; + float out_vals[kLPerThread]; #pragma unroll for (int i = 0; i < kLPerThread; ++i) { + out_vals[i] = bias_val; const int seq_idx_cur = !kHasSeqIdx ? 0 : seq_idx_thread[i + kWidth - 1]; #pragma unroll for (int w = 0; w < kWidth; ++w) { if constexpr (!kHasSeqIdx) { - dx_vals[i] += weight_vals[kWidth - 1 - w] * dout_vals[i + w]; + out_vals[i] += weight_vals[w] * x_vals[i + w]; } else { - dx_vals[i] += seq_idx_thread[kWidth - 1 + i + w] == seq_idx_cur ? weight_vals[kWidth - 1 - w] * dout_vals[i + w] : 0.f; + out_vals[i] += seq_idx_thread[i + w] == seq_idx_cur ? weight_vals[w] * x_vals[i + w] : 0.f; } } - // if (dfinal_states != nullptr) { - if constexpr (kHasDfinalStates) { - if (chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i >= params.seqlen - kWidth + 1 - && chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i < params.seqlen - && chunk_c_id * kChunkSizeC + row_idx < params.dim) { - dx_vals[i] += float(dfinal_states[((chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i) - (params.seqlen - kWidth + 1)) * params.dfinal_states_l_stride + row_idx * params.dfinal_states_c_stride]); - } - } - } - - float dxinit_vals[kWidth - 1] = {0}; - static_assert(kLPerThread >= kWidth - 1); // So only threads with col_idx == 0 need to handle dinitial_states - if (dinitial_states != nullptr && col_idx == 0) { - #pragma unroll - for (int i = 0; i < kWidth - 1; ++i) { - #pragma unroll - for (int w = 0; w < kWidth; ++w) { - dxinit_vals[i] += i + w - (kWidth - 1) >= 0 ? weight_vals[kWidth - 1 - w] * dout_vals[i + w - (kWidth - 1)] : 0.f; - } - // chunk_l_id must be 0 because dinitial_states != nullptr - // if (dfinal_states != nullptr) { - if constexpr (kHasDfinalStates) { - if (i >= params.seqlen) { - dxinit_vals[i] += float(dfinal_states[(i - params.seqlen) * params.dfinal_states_l_stride + row_idx * params.dfinal_states_c_stride]); - } - } + if (params.silu_activation) { + out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i])); } } __syncthreads(); #pragma unroll - for (int i = 0; i < kLPerThread; ++i) { x_smem[kWidth - 1 + col_idx * kLPerThread + i][row_idx] = dx_vals[i]; } - if (dinitial_states != nullptr && col_idx == 0) { - #pragma unroll - for (int i = 0; i < kWidth - 1; ++i) { x_smem[i][row_idx] = dxinit_vals[i]; } + for (int i = 0; i < kLPerThread; ++i) { + x_smem[col_idx * kLPerThread + i][row_idx] = out_vals[i]; } __syncthreads(); @@ -562,7 +482,6 @@ void causal_conv1d_channellast_bwd_kernel(ConvParamsBwd params) { reinterpret_cast(dxinit_vals_store)[0] = reinterpret_cast(x_smem[l_idx])[c_idx]; *reinterpret_cast(dinitial_states) = reinterpret_cast(dxinit_vals_store)[0]; } - } template @@ -571,7 +490,7 @@ void causal_conv1d_channellast_bwd_launch(ConvParamsBwd ¶ms, cudaStream_t st BOOL_SWITCH(params.seq_idx_ptr != nullptr, kHasSeqIdx, [&] { BOOL_SWITCH(params.dfinal_states_ptr != nullptr, kHasDfinalStates, [&] { BOOL_SWITCH(params.seqlen <= 128, kChunkSizeL64, [&] { - // kChunkSizeL = 128 is slightly faster than 64 when seqlen is larger + // kChunkSizeL = 128 is slightly faster than 64 when seqlen is larger. static constexpr int kChunk = kChunkSizeL64 ? 64 : 128; using Ktraits = Causal_conv1d_channellast_bwd_kernel_traits; // constexpr int kSmemSize = Ktraits::kSmemSize; @@ -624,4 +543,4 @@ template void causal_conv1d_channellast_bwd_cuda(ConvParamsB template void causal_conv1d_channellast_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); template void causal_conv1d_channellast_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); template void causal_conv1d_channellast_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); -template void causal_conv1d_channellast_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); \ No newline at end of file +template void causal_conv1d_channellast_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); diff --git a/csrc/causal_conv1d_fwd.cu b/csrc/causal_conv1d_fwd.cu index 29b5354..820dc28 100644 --- a/csrc/causal_conv1d_fwd.cu +++ b/csrc/causal_conv1d_fwd.cu @@ -18,6 +18,21 @@ #include "causal_conv1d_common.h" #include "static_switch.h" +// Helper function to set the maximum dynamic shared memory attribute. +// This function is defined at file scope so that the preprocessor directives +// are not embedded inside a lambda. +template +void set_max_dynamic_shared_memory(KernelT kernel, int smem_size) { + if (smem_size >= 48 * 1024) { +#ifndef USE_ROCM + C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); +#else + C10_CUDA_CHECK(cudaFuncSetAttribute((void*)kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + std::cerr << "Warning (causal_conv1d fwd launch): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior.\n" << std::endl; +#endif + } +} + template struct Causal_conv1d_fwd_kernel_traits { using input_t = input_t_; @@ -78,7 +93,9 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) { float weight_vals[kWidth]; #pragma unroll - for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); } + for (int i = 0; i < kWidth; ++i) { + weight_vals[i] = float(weight[i * params.weight_width_stride]); + } constexpr int kChunkSize = kNThreads * kNElts; const int n_chunks = (params.seqlen + kChunkSize - 1) / kChunkSize; @@ -94,16 +111,22 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) { __syncthreads(); // Thread kNThreads - 1 don't write yet, so that thread 0 can read // the last elements of the previous chunk. - if (tidx < kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast(x_vals_load)[1]; } + if (tidx < kNThreads - 1) { + smem_exchange[tidx] = reinterpret_cast(x_vals_load)[1]; + } __syncthreads(); reinterpret_cast(x_vals_load)[0] = smem_exchange[tidx > 0 ? tidx - 1 : kNThreads - 1]; __syncthreads(); // Now thread kNThreads - 1 can write the last elements of the current chunk. - if (tidx == kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast(x_vals_load)[1]; } + if (tidx == kNThreads - 1) { + smem_exchange[tidx] = reinterpret_cast(x_vals_load)[1]; + } float x_vals[2 * kNElts]; #pragma unroll - for (int i = 0; i < 2 * kNElts; ++i) { x_vals[i] = float(x_vals_load[i]); } + for (int i = 0; i < 2 * kNElts; ++i) { + x_vals[i] = float(x_vals_load[i]); + } float out_vals[kNElts]; #pragma unroll @@ -124,7 +147,9 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) { input_t out_vals_store[kNElts]; #pragma unroll - for (int i = 0; i < kNElts; ++i) { out_vals_store[i] = out_vals[i]; } + for (int i = 0; i < kNElts; ++i) { + out_vals_store[i] = out_vals[i]; + } if constexpr(kIsVecLoad) { typename Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast(out), reinterpret_cast(out_vals_store), (params.seqlen - chunk * kChunkSize) / kNElts); } else { @@ -141,20 +166,11 @@ void causal_conv1d_fwd_launch(ConvParamsBase ¶ms, cudaStream_t stream) { using Ktraits = Causal_conv1d_fwd_kernel_traits; constexpr int kSmemSize = Ktraits::kSmemSize; dim3 grid(params.batch, params.dim); - auto kernel = &causal_conv1d_fwd_kernel; - if (kSmemSize >= 48 * 1024) { - #ifndef USE_ROCM - C10_CUDA_CHECK(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); - #else - // There is a slight signature discrepancy in HIP and CUDA "FuncSetAttribute" function. - C10_CUDA_CHECK(cudaFuncSetAttribute( - (void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); - std::cerr << "Warning (causal_conv1d fwd launch): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl; - #endif - } + // Use the helper function to set the dynamic shared memory attribute. + set_max_dynamic_shared_memory(kernel, kSmemSize); + kernel<<>>(params); C10_CUDA_KERNEL_LAUNCH_CHECK(); @@ -325,12 +341,16 @@ void causal_conv1d_channellast_fwd_kernel(ConvParamsBase params) { out_vals[i] += seq_idx_thread[i + w] == seq_idx_cur ? weight_vals[w] * x_vals[i + w] : 0.f; } } - if (params.silu_activation) {out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i])); } + if (params.silu_activation) { + out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i])); + } } __syncthreads(); #pragma unroll - for (int i = 0; i < kLPerThread; ++i) { x_smem[col_idx * kLPerThread + i][row_idx] = out_vals[i]; } + for (int i = 0; i < kLPerThread; ++i) { + x_smem[col_idx * kLPerThread + i][row_idx] = out_vals[i]; + } __syncthreads(); #pragma unroll @@ -396,4 +416,4 @@ template void causal_conv1d_channellast_fwd_cuda(ConvParamsB template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); \ No newline at end of file +template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..98df2cc --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,48 @@ +[build-system] +requires = [ + "setuptools>=61.0", + "wheel", + "torch>=2.4.0", + "packaging", + "ninja" +] +build-backend = "setuptools.build_meta" + +[project] +name = "causal_conv1d" +dynamic = ["version"] +description = "Efficient 1D causal convolution kernel for PyTorch" +readme = "README.md" +requires-python = ">=3.8" +license = {text = "MIT"} +authors = [ + {name = "Tri Dao", email = "trid@cs.stanford.edu"} +] +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Topic :: Scientific/Engineering :: Artificial Intelligence", +] +dependencies = [ + "torch>=2.4.0", + "einops" +] + +[project.urls] +"Homepage" = "https://github.com/Dao-AILab/causal-conv1d" +"Bug Tracker" = "https://github.com/Dao-AILab/causal-conv1d/issues" + +[tool.setuptools] +include-package-data = true + +[tool.setuptools.dynamic] +version = {attr = "causal_conv1d.__version__"} + +[tool.setuptools.packages.find] +include = ["causal_conv1d*"] \ No newline at end of file diff --git a/setup.py b/setup.py index ce3fb62..da2d1ea 100644 --- a/setup.py +++ b/setup.py @@ -9,22 +9,19 @@ from pathlib import Path from packaging.version import parse, Version import platform - -from setuptools import setup, find_packages import subprocess - import urllib.request import urllib.error from wheel.bdist_wheel import bdist_wheel as _bdist_wheel +from setuptools import setup, find_packages + import torch from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME, HIP_HOME - with open("README.md", "r", encoding="utf-8") as fh: long_description = fh.read() - # ninja build does not work unless include_dirs are abs path this_dir = os.path.dirname(os.path.abspath(__file__)) @@ -56,32 +53,31 @@ def get_platform(): def get_cuda_bare_metal_version(cuda_dir): - raw_output = subprocess.check_output( - [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True - ) + # Build nvcc executable path in an OS-independent way. + nvcc_executable = os.path.join(cuda_dir, "bin", "nvcc") + if sys.platform == "win32": + nvcc_executable += ".exe" + print(f"nvcc_executable = {nvcc_executable}") + raw_output = subprocess.check_output([nvcc_executable, "-V"], universal_newlines=True) output = raw_output.split() release_idx = output.index("release") + 1 bare_metal_version = parse(output[release_idx].split(",")[0]) - return raw_output, bare_metal_version def get_hip_version(rocm_dir): - hipcc_bin = "hipcc" if rocm_dir is None else os.path.join(rocm_dir, "bin", "hipcc") try: - raw_output = subprocess.check_output( - [hipcc_bin, "--version"], universal_newlines=True - ) + raw_output = subprocess.check_output([hipcc_bin, "--version"], universal_newlines=True) except Exception as e: print( f"hip installation not found: {e} ROCM_PATH={os.environ.get('ROCM_PATH')}" ) return None, None - + for line in raw_output.split("\n"): if "HIP version" in line: - rocm_version = parse(line.split()[-1].replace("-", "+")) # local version is not parsed correctly + rocm_version = parse(line.split()[-1].replace("-", "+")) return line, rocm_version return None, None @@ -95,11 +91,8 @@ def get_torch_hip_version(): def check_if_hip_home_none(global_option: str) -> None: - if HIP_HOME is not None: return - # warn instead of error because user could be downloading prebuilt wheels, so hipcc won't be necessary - # in that case. warnings.warn( f"{global_option} was requested, but hipcc was not found. Are you sure your environment has hipcc available?" ) @@ -108,8 +101,6 @@ def check_if_hip_home_none(global_option: str) -> None: def check_if_cuda_home_none(global_option: str) -> None: if CUDA_HOME is not None: return - # warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary - # in that case. warnings.warn( f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? " "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, " @@ -124,7 +115,6 @@ def append_nvcc_threads(nvcc_extra_args): cmdclass = {} ext_modules = [] - HIP_BUILD = bool(torch.version.hip) if not SKIP_CUDA_BUILD: @@ -133,16 +123,13 @@ def append_nvcc_threads(nvcc_extra_args): TORCH_MAJOR = int(torch.__version__.split(".")[0]) TORCH_MINOR = int(torch.__version__.split(".")[1]) - cc_flag = [] if HIP_BUILD: check_if_hip_home_none(PACKAGE_NAME) - rocm_home = os.getenv("ROCM_PATH") _, hip_version = get_hip_version(rocm_home) - if HIP_HOME is not None: if hip_version < Version("6.0"): raise RuntimeError( @@ -157,11 +144,8 @@ def append_nvcc_threads(nvcc_extra_args): ) cc_flag.append("-DBUILD_PYTHON_PACKAGE") - else: check_if_cuda_home_none(PACKAGE_NAME) - # Check, if CUDA11 is installed for compute capability 8.0 - if CUDA_HOME is not None: _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) if bare_metal_version < Version("11.6"): @@ -169,7 +153,6 @@ def append_nvcc_threads(nvcc_extra_args): f"{PACKAGE_NAME} is only supported on CUDA 11.6 and above. " "Note: make sure nvcc has a supported version by running nvcc -V." ) - cc_flag.append("-gencode") cc_flag.append("arch=compute_53,code=sm_53") cc_flag.append("-gencode") @@ -182,33 +165,37 @@ def append_nvcc_threads(nvcc_extra_args): cc_flag.append("arch=compute_80,code=sm_80") cc_flag.append("-gencode") cc_flag.append("arch=compute_87,code=sm_87") - if bare_metal_version >= Version("11.8"): + if CUDA_HOME is not None and bare_metal_version >= Version("11.8"): cc_flag.append("-gencode") cc_flag.append("arch=compute_90,code=sm_90") + # Set the C++ optimization flag appropriately for the platform. + if sys.platform == "win32": + cxx_opt = ["/O2", "/Zc:__cplusplus", "/std:c++17", "/FIiso646.h"] + else: + cxx_opt = ["-O3"] + # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as # torch._C._GLIBCXX_USE_CXX11_ABI - # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920 if FORCE_CXX11_ABI: torch._C._GLIBCXX_USE_CXX11_ABI = True - if HIP_BUILD: extra_compile_args = { - "cxx": ["-O3", "-std=c++17"], + "cxx": cxx_opt + ["-std=c++17"], "nvcc": [ - "-O3", - "-std=c++17", - f"--offload-arch={os.getenv('HIP_ARCHITECTURES', 'native')}", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", - "-fgpu-flush-denormals-to-zero", - ] - + cc_flag, + "-O3", + "-std=c++17", + f"--offload-arch={os.getenv('HIP_ARCHITECTURES', 'native')}", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-fgpu-flush-denormals-to-zero", + "-allow-unsupported-compiler" + ] + cc_flag, } else: extra_compile_args = { - "cxx": ["-O3"], + "cxx": cxx_opt, "nvcc": append_nvcc_threads( [ "-O3", @@ -223,8 +210,8 @@ def append_nvcc_threads(nvcc_extra_args): "--use_fast_math", "--ptxas-options=-v", "-lineinfo", - ] - + cc_flag + "-allow-unsupported-compiler" + ] + cc_flag ), } @@ -235,10 +222,10 @@ def append_nvcc_threads(nvcc_extra_args): "csrc/causal_conv1d.cpp", "csrc/causal_conv1d_fwd.cu", "csrc/causal_conv1d_bwd.cu", - "csrc/causal_conv1d_update.cu", + "csrc/causal_conv1d_update.cu", ], extra_compile_args=extra_compile_args, - include_dirs=[Path(this_dir) / "csrc" / "causal_conv1d"], + include_dirs=["csrc/causal_conv1d"], ) ) @@ -255,23 +242,17 @@ def get_package_version(): def get_wheel_url(): - # Determine the version numbers that will be used to determine the correct wheel torch_version_raw = parse(torch.__version__) if HIP_BUILD: - # We're using the HIP version used to build torch, not the one currently installed torch_hip_version = get_torch_hip_version() hip_version = f"{torch_hip_version.major}{torch_hip_version.minor}" else: - # We're using the CUDA version used to build torch, not the one currently installed - # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME) torch_cuda_version = parse(torch.version.cuda) - # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.3 - # to save CI time. Minor versions should be compatible. torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.3") cuda_version = f"{torch_cuda_version.major}" - + gpu_compute_version = hip_version if HIP_BUILD else cuda_version cuda_or_hip = "hip" if HIP_BUILD else "cu" @@ -282,7 +263,6 @@ def get_wheel_url(): torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}" cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper() - # Determine wheel URL based on CUDA version, torch version, python version and OS wheel_filename = f"{PACKAGE_NAME}-{causal_conv1d_version}+{cuda_or_hip}{gpu_compute_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl" wheel_url = BASE_WHEEL_URL.format( @@ -298,7 +278,6 @@ class CachedWheelsCommand(_bdist_wheel): the environment parameters to detect whether there is already a pre-built version of a compatible wheel available and short-circuits the standard full build pipeline. """ - def run(self): if FORCE_BUILD: return super().run() @@ -353,15 +332,12 @@ def run(self): "Operating System :: Unix", ], ext_modules=ext_modules, - cmdclass={"bdist_wheel": CachedWheelsCommand, "build_ext": BuildExtension} - if ext_modules - else { - "bdist_wheel": CachedWheelsCommand, - }, + cmdclass={"bdist_wheel": CachedWheelsCommand, "build_ext": BuildExtension} if ext_modules else {"bdist_wheel": CachedWheelsCommand}, python_requires=">=3.9", install_requires=[ - "torch", + "torch>=2.4.0", + "einops", "packaging", - "ninja", - ], + "ninja" + ] )