I'm performing static analysis on CUDA programs and have identified several potential integer overflow issues in rpe_index.cu.
at::Tensor rpe_index_forward_gpu(torch::Tensor input, torch::Tensor index) {
...
const index_t B = input.size(0);
const index_t H = input.size(1);
const index_t num_buckets = input.size(3);
const index_t L_query = index.size(0);
const index_t L_key = index.size(1);
const index_t L_qk = L_query * L_key;
at::Tensor Y = at::empty({B, H, L_query, L_key}, input.options());
const index_t numel = Y.numel();
}
Based on the model code, the input tensor has shape [B, H, L, D] and the index tensor has shape [L, L]. L is embed_dim, which is 768 by default; D is embed_dim // num_heads, which is 64 by default.
When index_t is a 32-bit integer (int32_t), the following assignments narrow a 64-bit value to 32 bits (size() and numel() return int64_t), so a large batch size or head count can cause truncation or integer overflow:
- index_t B = input.size(0)
- index_t H = input.size(1)
- index_t numel = Y.numel()
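
To illustrate the narrowing risk listed above, here is a minimal host-side sketch. The helpers fits_in_int32, numel_4d, and checked_narrow are hypothetical names introduced for this example and are not part of rpe_index.cu. The sketch computes the element count in 64 bits and only narrows after a range check.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>

// Hypothetical helper (not in rpe_index.cu): true when a 64-bit value
// can be stored in a 32-bit index_t without losing information.
inline bool fits_in_int32(std::int64_t value) {
    return value >= std::numeric_limits<std::int32_t>::min() &&
           value <= std::numeric_limits<std::int32_t>::max();
}

// Element count of a [B, H, L_query, L_key] tensor, computed entirely
// in 64 bits, mirroring what Y.numel() would return.
inline std::int64_t numel_4d(std::int64_t B, std::int64_t H,
                             std::int64_t L_query, std::int64_t L_key) {
    return B * H * L_query * L_key;
}

// Narrow only after checking the range. A plain assert is used here to
// keep the sketch dependency-free.
inline std::int32_t checked_narrow(std::int64_t value) {
    assert(fits_in_int32(value) && "value does not fit in int32_t");
    return static_cast<std::int32_t>(value);
}
```

With the default L = 768 and H = 12, numel_4d(64, 12, 768, 768) is 452,984,832 and still fits in int32_t, while numel_4d(512, 12, 768, 768) is 3,623,878,656 and does not. In the actual extension, a TORCH_CHECK at the same point, or dispatching to a 64-bit index_t, would be the more natural guard.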
__global__ void rpe_index_forward_gpu_kernel(
...
const index_t ind = bi * s0 + hi * s1 + qi * s2 + p_index[i % L_qk] * s3;
Similarly, in the kernel's computation of ind shown above, the final addition may also overflow when index_t is 32-bit.
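
One possible mitigation, sketched below on the host side, is to widen each operand before multiplying so that both the products and the final sum are evaluated in 64 bits. The function flat_index_64 is a hypothetical stand-in for the kernel expression, with p playing the role of p_index[i % L_qk]; the stride values used in the note after the block are illustrative assumptions.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for the kernel expression
//   bi * s0 + hi * s1 + qi * s2 + p * s3
// Each product is widened to 64 bits before it is evaluated, so neither
// the multiplications nor the additions are carried out in 32 bits.
inline std::int64_t flat_index_64(std::int32_t bi, std::int32_t s0,
                                  std::int32_t hi, std::int32_t s1,
                                  std::int32_t qi, std::int32_t s2,
                                  std::int32_t p,  std::int32_t s3) {
    // Batch, head, and query positions are assumed to be non-negative.
    assert(bi >= 0 && hi >= 0 && qi >= 0);
    return static_cast<std::int64_t>(bi) * s0 +
           static_cast<std::int64_t>(hi) * s1 +
           static_cast<std::int64_t>(qi) * s2 +
           static_cast<std::int64_t>(p)  * s3;
}
```

For example, with an assumed batch stride of 7,077,888 (12 * 768 * 768), flat_index_64(511, 7077888, 0, 0, 0, 0, 0, 0) yields 3,616,800,768, which exceeds the int32_t maximum of 2,147,483,647. The same expression evaluated in int32_t would overflow.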
In addition, if embed_dim and num_heads are set to values much larger than their defaults, computations such as index_t L_query = index.size(0), index_t L_qk = L_query * L_key, and bi * s0 are also vulnerable.
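
A product such as L_query * L_key can be guarded by multiplying in 64 bits and checking the range before storing it in a 32-bit index_t. The helper checked_mul_i32 below is a hypothetical name used only for this sketch, not something taken from the repository.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>

// Hypothetical helper: multiply two 32-bit values in 64 bits and store
// the product in *out only when it fits in int32_t.
// Returns false (leaving *out untouched) when the product would overflow.
inline bool checked_mul_i32(std::int32_t a, std::int32_t b,
                            std::int32_t* out) {
    assert(out != nullptr);
    const std::int64_t wide = static_cast<std::int64_t>(a) * b;
    if (wide < std::numeric_limits<std::int32_t>::min() ||
        wide > std::numeric_limits<std::int32_t>::max()) {
        return false;
    }
    *out = static_cast<std::int32_t>(wide);
    return true;
}
```

With the default L = 768 the product is 589,824 and the check passes. With L = 65,536 the product is 2^32, so the helper reports failure instead of silently wrapping.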