-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsetup.py
More file actions
34 lines (32 loc) · 1012 Bytes
/
setup.py
File metadata and controls
34 lines (32 loc) · 1012 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
# setup.py
import os

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
import torch
# nvcc targets compiled by default. Each entry emits native code (sm_XX) for
# one architecture; no PTX is embedded by these flags.
DEFAULT_NVCC_ARCH_FLAGS = [
    '-gencode=arch=compute_70,code=sm_70',
    '-gencode=arch=compute_75,code=sm_75',
    '-gencode=arch=compute_80,code=sm_80',
    '-gencode=arch=compute_86,code=sm_86',
]


def should_build_cuda() -> bool:
    """Return True when the CUDA extension should be compiled.

    By default the decision follows ``torch.cuda.is_available()``, which
    reports whether CUDA can be used in the *current* process. Build hosts
    such as CI runners or ``docker build`` frequently have the CUDA toolkit
    but no GPU attached, so setting the environment variable
    ``FORCE_CUDA=1`` requests the CUDA build regardless of that check.
    """
    if os.environ.get('FORCE_CUDA', '0') == '1':
        return True
    return torch.cuda.is_available()


def nvcc_arch_flags() -> list:
    """Return the ``-gencode`` flags to hand to nvcc.

    When ``TORCH_CUDA_ARCH_LIST`` is set to a non-empty value, an empty list
    is returned: leaving explicit ``-gencode`` flags out lets PyTorch's
    extension build machinery generate them from that variable, so users
    can target architectures other than the defaults without editing this
    file. Otherwise a copy of ``DEFAULT_NVCC_ARCH_FLAGS`` is returned.
    """
    if os.environ.get('TORCH_CUDA_ARCH_LIST'):
        return []
    return list(DEFAULT_NVCC_ARCH_FLAGS)


# Select the extension modules to build: the CUDA extension when requested
# or available, otherwise nothing (the package is installed without it).
if should_build_cuda():
    ext_modules = [
        CUDAExtension(
            name='fastmqa_cuda',
            sources=['kernels/mqa_extension.cpp', 'kernels/mqa_kernel.cu'],
            extra_compile_args={
                'cxx': ['-O3'],
                'nvcc': ['-O3', '--use_fast_math'] + nvcc_arch_flags(),
            },
        )
    ]
else:
    print("Warning: CUDA not available. Building CPU-only version.")
    ext_modules = []
# Register the package with setuptools. BuildExtension is installed as the
# ``build_ext`` command so that whatever is listed in ``ext_modules`` is
# compiled through PyTorch's extension build support.
package_config = {
    'name': 'fastmqa',
    'ext_modules': ext_modules,
    'cmdclass': {'build_ext': BuildExtension},
    'install_requires': ['torch>=2.0.0'],
}
setup(**package_config)