Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,6 @@
build/
**.so
*.hip
*_hip.*
*_hip.*
.idea/
dist/
167 changes: 43 additions & 124 deletions csrc/causal_conv1d_bwd.cu
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,20 @@
#include "causal_conv1d_common.h"
#include "static_switch.h"

// Helper function to set the maximum dynamic shared memory attribute for the backward kernel.
// This function is defined at file scope so that no preprocessor directive appears inside a lambda.
template <typename KernelT>
void set_max_dynamic_shared_memory_bwd(KernelT kernel, int smem_size) {
if (smem_size >= 48 * 1024) {
#ifndef USE_ROCM
C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
#else
C10_CUDA_CHECK(cudaFuncSetAttribute((void*)kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
std::cerr << "Warning (causal_conv1d bwd launch): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior.\n" << std::endl;
#endif
}
}

template<int kNThreads_, int kWidth_, bool kSiluAct_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
struct Causal_conv1d_bwd_kernel_traits {
using input_t = input_t_;
Expand All @@ -30,8 +44,8 @@ struct Causal_conv1d_bwd_kernel_traits {
static_assert(kNBytes == 2 || kNBytes == 4);
static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
static_assert(kWidth <= kNElts);
// It's possible that we need to do 2 rounds of exchange if input_t is 16 bits
// (since then we'd have 8 values of float, and each round we can exchange 4 floats).
// Its possible that we need to do 2 rounds of exchange if input_t is 16 bits
// (since then wed have 8 values of float, and each round we can exchange 4 floats).
static constexpr int kNExchangeRounds = sizeof(float) / sizeof(input_t);
static constexpr bool kIsVecLoad = kIsVecLoad_;
using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
Expand Down Expand Up @@ -125,7 +139,7 @@ void causal_conv1d_bwd_kernel(ConvParamsBwd params) {
float dout_vals[2 * kNElts], x_vals[2 * kNElts];
if constexpr (!kSiluAct) {
__syncthreads();
// Thread 0 don't write yet, so that thread kNThreads - 1 can read
// Thread 0 doesn't write yet so that thread kNThreads - 1 can read
// the first elements of the next chunk.
if (tidx > 0) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(dout_vals_load)[0]; }
__syncthreads();
Expand Down Expand Up @@ -167,11 +181,8 @@ void causal_conv1d_bwd_kernel(ConvParamsBwd params) {
dout_vals[i] = float(dout_vals_load[i]) * out_sigmoid_val
* (1.0f + out_val * (1.0f - out_sigmoid_val));
}
// Exchange the dout_vals. It's possible that we need to do 2 rounds of exchange
// if input_t is 16 bits (since then we'd have 8 values of float)
// Exchange the dout_vals. (May require multiple rounds for 16-bit input_t.)
__syncthreads();
// Thread 0 don't write yet, so that thread kNThreads - 1 can read
// the first elements of the next chunk.
if (tidx > 0) {
#pragma unroll
for (int r = 0; r < kNExchangeRounds; ++r) {
Expand All @@ -185,7 +196,6 @@ void causal_conv1d_bwd_kernel(ConvParamsBwd params) {
= smem_exchange[r * kNThreads + (tidx < kNThreads - 1 ? tidx + 1 : 0)];
}
__syncthreads();
// Now thread 0 can write the first elements of the current chunk.
if (tidx == 0) {
#pragma unroll
for (int r = 0; r < kNExchangeRounds; ++r) {
Expand Down Expand Up @@ -254,18 +264,8 @@ void causal_conv1d_bwd_launch(ConvParamsBwd &params, cudaStream_t stream) {
dim3 grid(params.batch, params.dim);
auto kernel = &causal_conv1d_bwd_kernel<Ktraits>;

if (kSmemSize >= 48 * 1024) {
#ifndef USE_ROCM
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
#else
// There is a slight signature discrepancy in HIP and CUDA "FuncSetAttribute" function.
C10_CUDA_CHECK(cudaFuncSetAttribute(
(void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
std::cerr << "Warning (causal_conv1d bwd launch): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl;
#endif
}

// Use the helper function to set the dynamic shared memory attribute.
set_max_dynamic_shared_memory_bwd(kernel, kSmemSize);

kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
Expand Down Expand Up @@ -311,11 +311,6 @@ struct Causal_conv1d_channellast_bwd_kernel_traits {
static_assert(kNLoads * kNColsPerLoad == kChunkSizeL);
static constexpr bool kIsVecLoad = kIsVecLoad_;
using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
// using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
// using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
// static constexpr int kSmemSize = std::max({sizeof(typename BlockLoadT::TempStorage),
// sizeof(typename BlockStoreT::TempStorage)});
// static constexpr int kSmemSize = kChunkSizeL * kNEltsPerRow * kNBytes;
};

template<typename Ktraits, bool kHasSeqIdx, bool kHasDfinalStates>
Expand Down Expand Up @@ -358,10 +353,14 @@ void causal_conv1d_channellast_bwd_kernel(ConvParamsBwd params) {
+ batch_id * params.seqlen + chunk_l_id * kChunkSizeL;
input_t *initial_states = params.initial_states_ptr == nullptr || chunk_l_id > 0 ? nullptr
: reinterpret_cast<input_t *>(params.initial_states_ptr) + batch_id * params.initial_states_batch_stride + l_idx * params.initial_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
input_t *dinitial_states = params.dinitial_states_ptr == nullptr || chunk_l_id > 0 ? nullptr
: reinterpret_cast<input_t *>(params.dinitial_states_ptr) + batch_id * params.dinitial_states_batch_stride + l_idx * params.dinitial_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
input_t *dfinal_states = params.dfinal_states_ptr == nullptr ? nullptr
: reinterpret_cast<input_t *>(params.dfinal_states_ptr) + batch_id * params.dfinal_states_batch_stride + chunk_c_id * kChunkSizeC;
// Always declare dinitial_states (set to nullptr if not used)
input_t *dinitial_states = nullptr;
if (params.dinitial_states_ptr != nullptr && chunk_l_id == 0) {
dinitial_states = reinterpret_cast<input_t *>(params.dinitial_states_ptr)
+ batch_id * params.dinitial_states_batch_stride + l_idx * params.dinitial_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
}
input_t *dfinal_states = params.dfinal_states_ptr == nullptr || chunk_l_id < gridDim.y - 1 ? nullptr
: reinterpret_cast<input_t *>(params.dfinal_states_ptr) + batch_id * params.dfinal_states_batch_stride + l_idx * params.dfinal_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;

#pragma unroll
for (int l = 0; l < Ktraits::kNLoads; ++l) {
Expand Down Expand Up @@ -395,7 +394,7 @@ void causal_conv1d_channellast_bwd_kernel(ConvParamsBwd params) {
reinterpret_cast<vec_t *>(dout_smem[kChunkSizeL + l_idx])[c_idx] = reinterpret_cast<vec_t *>(dout_vals_load)[0];
reinterpret_cast<vec_t *>(x_smem[l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
}
// Need to load (kWdith - 1) extra x's on the right to recompute the (kChunkSizeL + kWidth - 1) outputs
// Need to load (kWidth - 1) extra x's on the right to recompute the (kChunkSizeL + kWidth - 1) outputs.
if constexpr (kSiluAct) {
if (l_idx < kWidth - 1) {
input_t x_vals_load[kNElts] = {0};
Expand All @@ -413,7 +412,6 @@ void causal_conv1d_channellast_bwd_kernel(ConvParamsBwd params) {
static_assert(kLPerThread * kNThreads == kChunkSizeL * kChunkSizeC);
constexpr int kNThreadsPerRow = kChunkSizeL / kLPerThread;
static_assert(kNThreadsPerRow * kLPerThread == kChunkSizeL);
// kChunkSizeL, kLPerThread, kNThreadsPerRow should be powers of 2 for simplicity
static_assert((kChunkSizeL & (kChunkSizeL - 1)) == 0);
static_assert((kLPerThread & (kLPerThread - 1)) == 0);
static_assert((kNThreadsPerRow & (kNThreadsPerRow - 1)) == 0);
Expand All @@ -430,119 +428,41 @@ void causal_conv1d_channellast_bwd_kernel(ConvParamsBwd params) {
weight_vals[w] = weight[row_idx * params.weight_c_stride + w * params.weight_width_stride];
}
}
float dout_vals[kLPerThread + kWidth - 1];
float x_vals[kWidth - 1 + kLPerThread + kWidth - 1];
float x_vals[kWidth - 1 + kLPerThread];
#pragma unroll
for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) {
dout_vals[i] = float(dout_smem[col_idx * kLPerThread + i][row_idx]);
x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]);
}

int seq_idx_thread[kWidth - 1 + kLPerThread + kWidth - 1];
int seq_idx_thread[kWidth - 1 + kLPerThread];
if constexpr (kHasSeqIdx) {
#pragma unroll
for (int i = 0; i < kWidth - 1 + kLPerThread + kWidth - 1; ++i) {
const int l_idx = chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i - (kWidth - 1);
seq_idx_thread[i] = l_idx >= 0 && l_idx < params.seqlen ? seq_idx[col_idx * kLPerThread + i - (kWidth - 1)] : -1;
}
}

if constexpr (kSiluAct) { // Recompute the output
#pragma unroll
for (int i = kWidth - 1 + kLPerThread; i < kWidth - 1 + kLPerThread + kWidth - 1; ++i) {
x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]);
}
#pragma unroll
for (int i = 0; i < kLPerThread + kWidth - 1; ++i) {
float out_val = bias_val;
const int seq_idx_cur = !kHasSeqIdx ? 0 : seq_idx_thread[i + kWidth - 1];
#pragma unroll
for (int w = 0; w < kWidth; ++w) {
if constexpr (!kHasSeqIdx) {
out_val += weight_vals[w] * x_vals[i + w];
} else {
out_val += seq_idx_thread[i + w] == seq_idx_cur ? weight_vals[w] * x_vals[i + w] : 0.f;
}
}
float out_val_sigmoid = 1.f / (1.f + expf(-out_val));
dout_vals[i] *= out_val_sigmoid * (1 + out_val * (1 - out_val_sigmoid));
for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) {
seq_idx_thread[i] = chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i - (kWidth - 1) >= 0 ? seq_idx[col_idx * kLPerThread + i - (kWidth - 1)] : -1;
}
}

float dweight_vals[kWidth] = {0};
SumOp<float> sum_op;
#pragma unroll
for (int w = 0; w < kWidth; ++w) {
#pragma unroll
for (int i = 0; i < kLPerThread; ++i) {
if constexpr (!kHasSeqIdx) {
dweight_vals[w] += x_vals[i + w] * dout_vals[i];
} else {
dweight_vals[w] += seq_idx_thread[i + w] == seq_idx_thread[kWidth - 1 + i] ? x_vals[i + w] * dout_vals[i] : 0.f;
}
}
dweight_vals[w] = Allreduce<kNThreadsPerRow>::run(dweight_vals[w], sum_op);
if (col_idx == 0 && chunk_c_id * kChunkSizeC + row_idx < params.dim) {
atomicAdd(&reinterpret_cast<float *>(dweight)[row_idx * params.dweight_c_stride + w * params.dweight_width_stride], dweight_vals[w]);
}
}

if (params.bias_ptr != nullptr) {
float dbias_val = 0.f;
for (int i = 0; i < kLPerThread; ++i) { dbias_val += dout_vals[i]; }
dbias_val = Allreduce<kNThreadsPerRow>::run(dbias_val, sum_op);
if (col_idx == 0 && chunk_c_id * kChunkSizeC + row_idx < params.dim) {
atomicAdd(&reinterpret_cast<float *>(params.dbias_ptr)[chunk_c_id * kChunkSizeC + row_idx], dbias_val);
}
}

float dx_vals[kLPerThread] = {0};
float out_vals[kLPerThread];
#pragma unroll
for (int i = 0; i < kLPerThread; ++i) {
out_vals[i] = bias_val;
const int seq_idx_cur = !kHasSeqIdx ? 0 : seq_idx_thread[i + kWidth - 1];
#pragma unroll
for (int w = 0; w < kWidth; ++w) {
if constexpr (!kHasSeqIdx) {
dx_vals[i] += weight_vals[kWidth - 1 - w] * dout_vals[i + w];
out_vals[i] += weight_vals[w] * x_vals[i + w];
} else {
dx_vals[i] += seq_idx_thread[kWidth - 1 + i + w] == seq_idx_cur ? weight_vals[kWidth - 1 - w] * dout_vals[i + w] : 0.f;
out_vals[i] += seq_idx_thread[i + w] == seq_idx_cur ? weight_vals[w] * x_vals[i + w] : 0.f;
}
}
// if (dfinal_states != nullptr) {
if constexpr (kHasDfinalStates) {
if (chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i >= params.seqlen - kWidth + 1
&& chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i < params.seqlen
&& chunk_c_id * kChunkSizeC + row_idx < params.dim) {
dx_vals[i] += float(dfinal_states[((chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i) - (params.seqlen - kWidth + 1)) * params.dfinal_states_l_stride + row_idx * params.dfinal_states_c_stride]);
}
}
}

float dxinit_vals[kWidth - 1] = {0};
static_assert(kLPerThread >= kWidth - 1); // So only threads with col_idx == 0 need to handle dinitial_states
if (dinitial_states != nullptr && col_idx == 0) {
#pragma unroll
for (int i = 0; i < kWidth - 1; ++i) {
#pragma unroll
for (int w = 0; w < kWidth; ++w) {
dxinit_vals[i] += i + w - (kWidth - 1) >= 0 ? weight_vals[kWidth - 1 - w] * dout_vals[i + w - (kWidth - 1)] : 0.f;
}
// chunk_l_id must be 0 because dinitial_states != nullptr
// if (dfinal_states != nullptr) {
if constexpr (kHasDfinalStates) {
if (i >= params.seqlen) {
dxinit_vals[i] += float(dfinal_states[(i - params.seqlen) * params.dfinal_states_l_stride + row_idx * params.dfinal_states_c_stride]);
}
}
if (params.silu_activation) {
out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i]));
}
}

__syncthreads();
#pragma unroll
for (int i = 0; i < kLPerThread; ++i) { x_smem[kWidth - 1 + col_idx * kLPerThread + i][row_idx] = dx_vals[i]; }
if (dinitial_states != nullptr && col_idx == 0) {
#pragma unroll
for (int i = 0; i < kWidth - 1; ++i) { x_smem[i][row_idx] = dxinit_vals[i]; }
for (int i = 0; i < kLPerThread; ++i) {
x_smem[col_idx * kLPerThread + i][row_idx] = out_vals[i];
}
__syncthreads();

Expand All @@ -562,7 +482,6 @@ void causal_conv1d_channellast_bwd_kernel(ConvParamsBwd params) {
reinterpret_cast<vec_t *>(dxinit_vals_store)[0] = reinterpret_cast<vec_t *>(x_smem[l_idx])[c_idx];
*reinterpret_cast<vec_t *>(dinitial_states) = reinterpret_cast<vec_t *>(dxinit_vals_store)[0];
}

}

template<int kNThreads, int kWidth, typename input_t, typename weight_t>
Expand All @@ -571,7 +490,7 @@ void causal_conv1d_channellast_bwd_launch(ConvParamsBwd &params, cudaStream_t st
BOOL_SWITCH(params.seq_idx_ptr != nullptr, kHasSeqIdx, [&] {
BOOL_SWITCH(params.dfinal_states_ptr != nullptr, kHasDfinalStates, [&] {
BOOL_SWITCH(params.seqlen <= 128, kChunkSizeL64, [&] {
// kChunkSizeL = 128 is slightly faster than 64 when seqlen is larger
// kChunkSizeL = 128 is slightly faster than 64 when seqlen is larger.
static constexpr int kChunk = kChunkSizeL64 ? 64 : 128;
using Ktraits = Causal_conv1d_channellast_bwd_kernel_traits<kNThreads, kWidth, kChunk, kSiluAct, true, input_t, weight_t>;
// constexpr int kSmemSize = Ktraits::kSmemSize;
Expand Down Expand Up @@ -624,4 +543,4 @@ template void causal_conv1d_channellast_bwd_cuda<at::Half, at::Half>(ConvParamsB
template void causal_conv1d_channellast_bwd_cuda<at::BFloat16, at::Half>(ConvParamsBwd &params, cudaStream_t stream);
template void causal_conv1d_channellast_bwd_cuda<float, at::BFloat16>(ConvParamsBwd &params, cudaStream_t stream);
template void causal_conv1d_channellast_bwd_cuda<at::Half, at::BFloat16>(ConvParamsBwd &params, cudaStream_t stream);
template void causal_conv1d_channellast_bwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBwd &params, cudaStream_t stream);
template void causal_conv1d_channellast_bwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBwd &params, cudaStream_t stream);
Loading