From da3c2f4bb37d527ad00965fe5094b39af3617e02 Mon Sep 17 00:00:00 2001 From: Xiaodong Ye Date: Sat, 11 Oct 2025 15:44:24 +0800 Subject: [PATCH] Multi-backend support Signed-off-by: Xiaodong Ye --- .gitignore | 8 + .../benchmark_fast_hadamard_transform.py | 2 + csrc/fast_hadamard_transform.cpp | 26 +++ csrc/fast_hadamard_transform_common.h | 3 +- ...cuda.cu => fast_hadamard_transform_gpu.cu} | 29 ++- csrc/vendor.h | 37 ++++ setup.py | 203 ++++++++++++------ tests/test_fast_hadamard_transform.py | 2 + 8 files changed, 234 insertions(+), 76 deletions(-) create mode 100644 .gitignore rename csrc/{fast_hadamard_transform_cuda.cu => fast_hadamard_transform_gpu.cu} (96%) create mode 100644 csrc/vendor.h diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..066b428 --- /dev/null +++ b/.gitignore @@ -0,0 +1,8 @@ +__pycache__/ +build/ +dist/ +*.egg-info/ + +# hipified files +*.hip +*_hip.* diff --git a/benchmarks/benchmark_fast_hadamard_transform.py b/benchmarks/benchmark_fast_hadamard_transform.py index 15f2597..b2cbc54 100644 --- a/benchmarks/benchmark_fast_hadamard_transform.py +++ b/benchmarks/benchmark_fast_hadamard_transform.py @@ -14,6 +14,8 @@ dim = 16384 * 2 dtype = torch.float16 device = "cuda" +if hasattr(torch, "musa"): + device = "musa" torch.random.manual_seed(0) x = torch.randn(batch_size, seqlen, dim, dtype=dtype, device=device) diff --git a/csrc/fast_hadamard_transform.cpp b/csrc/fast_hadamard_transform.cpp index 512df8b..fbcec7d 100644 --- a/csrc/fast_hadamard_transform.cpp +++ b/csrc/fast_hadamard_transform.cpp @@ -2,11 +2,17 @@ * Copyright (c) 2023, Tri Dao. ******************************************************************************/ +#ifndef USE_MUSA #include #include +#else +#include "torch_musa/csrc/aten/musa/MUSAContext.h" +#include "torch_musa/csrc/core/MUSAGuard.h" +#endif #include #include +#include "vendor.h" #include "fast_hadamard_transform.h" #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") @@ -74,7 +80,11 @@ fast_hadamard_transform(at::Tensor &x, float scale) { auto input_type = x.scalar_type(); TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); +#ifndef USE_MUSA TORCH_CHECK(x.is_cuda()); +#else + TORCH_CHECK(x.is_privateuseone()); +#endif const auto shapes_og = x.sizes(); const int dim_og = x.size(-1); @@ -117,7 +127,11 @@ fast_hadamard_transform_12N(at::Tensor &x, float scale) { auto input_type = x.scalar_type(); TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); +#ifndef USE_MUSA TORCH_CHECK(x.is_cuda()); +#else + TORCH_CHECK(x.is_privateuseone()); +#endif const auto shapes_og = x.sizes(); const int dim_og = x.size(-1); @@ -160,7 +174,11 @@ fast_hadamard_transform_20N(at::Tensor &x, float scale) { auto input_type = x.scalar_type(); TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); +#ifndef USE_MUSA TORCH_CHECK(x.is_cuda()); +#else + TORCH_CHECK(x.is_privateuseone()); +#endif const auto shapes_og = x.sizes(); const int dim_og = x.size(-1); @@ -203,7 +221,11 @@ fast_hadamard_transform_28N(at::Tensor &x, float scale) { auto input_type = x.scalar_type(); TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); +#ifndef USE_MUSA TORCH_CHECK(x.is_cuda()); +#else + TORCH_CHECK(x.is_privateuseone()); +#endif const auto shapes_og = x.sizes(); const int dim_og = x.size(-1); @@ -246,7 +268,11 @@ fast_hadamard_transform_40N(at::Tensor &x, float scale) { auto input_type = x.scalar_type(); TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); +#ifndef USE_MUSA TORCH_CHECK(x.is_cuda()); +#else + TORCH_CHECK(x.is_privateuseone()); +#endif const auto shapes_og = x.sizes(); const int dim_og = x.size(-1); diff --git a/csrc/fast_hadamard_transform_common.h b/csrc/fast_hadamard_transform_common.h index f165320..ff92698 100644 --- a/csrc/fast_hadamard_transform_common.h +++ b/csrc/fast_hadamard_transform_common.h @@ -4,8 +4,7 @@ #pragma once -#include -#include +#include "vendor.h" #define FULL_MASK 0xffffffff diff --git a/csrc/fast_hadamard_transform_cuda.cu b/csrc/fast_hadamard_transform_gpu.cu similarity index 96% rename from csrc/fast_hadamard_transform_cuda.cu rename to csrc/fast_hadamard_transform_gpu.cu index ff64757..88df125 100644 --- a/csrc/fast_hadamard_transform_cuda.cu +++ b/csrc/fast_hadamard_transform_gpu.cu @@ -6,8 +6,13 @@ #include #include +#ifndef USE_MUSA #include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK +#else +#include "torch_musa/csrc/core/MUSAException.h" // For C10_MUSA_CHECK and C10_MUSA_KERNEL_LAUNCH_CHECK +#endif +#include "vendor.h" #include "fast_hadamard_transform.h" #include "fast_hadamard_transform_common.h" #include "fast_hadamard_transform_special.h" @@ -28,7 +33,7 @@ struct fast_hadamard_transform_kernel_traits { using vec_t = typename BytesToType::Type; static constexpr int kNChunks = N / (kNElts * kNThreads); // We don't want to use more than 32 KB of shared memory. - static constexpr int kSmemExchangeSize = std::min(N * 4, 32 * 1024); + static constexpr int kSmemExchangeSize = MIN(N * 4, 32 * 1024); static constexpr int kNExchangeRounds = N * 4 / kSmemExchangeSize; static_assert(kNExchangeRounds * kSmemExchangeSize == N * 4); static constexpr int kSmemSize = kSmemExchangeSize; @@ -51,7 +56,7 @@ struct fast_hadamard_transform_12N_kernel_traits { static constexpr int kNChunks = N / (kNElts * kNThreads); static_assert(kNChunks == 12); // We don't want to use more than 24 KB of shared memory. - static constexpr int kSmemExchangeSize = std::min(N * 4, 24 * 1024); + static constexpr int kSmemExchangeSize = MIN(N * 4, 24 * 1024); static constexpr int kNExchangeRounds = N * 4 / kSmemExchangeSize; static_assert(kNExchangeRounds * kSmemExchangeSize == N * 4); static constexpr int kSmemSize = kSmemExchangeSize; @@ -74,7 +79,7 @@ struct fast_hadamard_transform_20N_kernel_traits { static constexpr int kNChunks = N / (kNElts * kNThreads); static_assert(kNChunks == 20); // We don't want to use more than 40 KB of shared memory. - static constexpr int kSmemExchangeSize = std::min(N * 4, 40 * 1024); + static constexpr int kSmemExchangeSize = MIN(N * 4, 40 * 1024); static constexpr int kNExchangeRounds = N * 4 / kSmemExchangeSize; static_assert(kNExchangeRounds * kSmemExchangeSize == N * 4); static constexpr int kSmemSize = kSmemExchangeSize; @@ -97,7 +102,7 @@ struct fast_hadamard_transform_28N_kernel_traits { static constexpr int kNChunks = N / (kNElts * kNThreads); static_assert(kNChunks == 28); // We don't want to use more than 28 KB of shared memory. - static constexpr int kSmemExchangeSize = std::min(N * 4, 28 * 1024); + static constexpr int kSmemExchangeSize = MIN(N * 4, 28 * 1024); static constexpr int kNExchangeRounds = N * 4 / kSmemExchangeSize; static_assert(kNExchangeRounds * kSmemExchangeSize == N * 4); static constexpr int kSmemSize = kSmemExchangeSize; @@ -120,7 +125,7 @@ struct fast_hadamard_transform_40N_kernel_traits { static constexpr int kNChunks = N / (kNElts * kNThreads); static_assert(kNChunks == 40); // We don't want to use more than 40 KB of shared memory. - static constexpr int kSmemExchangeSize = std::min(N * 4, 40 * 1024); + static constexpr int kSmemExchangeSize = MIN(N * 4, 40 * 1024); static constexpr int kNExchangeRounds = N * 4 / kSmemExchangeSize; static_assert(kNExchangeRounds * kSmemExchangeSize == N * 4); static constexpr int kSmemSize = kSmemExchangeSize; @@ -163,7 +168,7 @@ void fast_hadamard_transform_kernel(HadamardParamsBase params) { constexpr int kLogNElts = cilog2(Ktraits::kNElts); static_assert(1 << kLogNElts == kNElts, "kNElts must be a power of 2"); - constexpr int kWarpSize = std::min(kNThreads, 32); + constexpr int kWarpSize = MIN(kNThreads, WARP_SIZE); constexpr int kLogWarpSize = cilog2(kWarpSize); static_assert(1 << kLogWarpSize == kWarpSize, "Warp size must be a power of 2"); constexpr int kNWarps = kNThreads / kWarpSize; @@ -234,10 +239,12 @@ void fast_hadamard_transform_launch(HadamardParamsBase ¶ms, cudaStream_t str constexpr int kSmemSize = Ktraits::kSmemSize; dim3 grid(params.batch); auto kernel = &fast_hadamard_transform_kernel; +#ifndef USE_ROCM if (kSmemSize >= 48 * 1024) { C10_CUDA_CHECK(cudaFuncSetAttribute( kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); } +#endif kernel<<>>(params); C10_CUDA_KERNEL_LAUNCH_CHECK(); } @@ -279,10 +286,12 @@ void fast_hadamard_transform_12N_launch(HadamardParamsBase ¶ms, cudaStream_t constexpr int kSmemSize = Ktraits::kSmemSize; dim3 grid(params.batch); auto kernel = &fast_hadamard_transform_kernel; +#ifndef USE_ROCM if (kSmemSize >= 48 * 1024) { C10_CUDA_CHECK(cudaFuncSetAttribute( kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); } +#endif kernel<<>>(params); C10_CUDA_KERNEL_LAUNCH_CHECK(); } @@ -307,7 +316,7 @@ void fast_hadamard_transform_12N_cuda(HadamardParamsBase ¶ms, cudaStream_t s fast_hadamard_transform_12N_launch<128, 9, input_t>(params, stream); } else if (params.log_N == 10) { fast_hadamard_transform_12N_launch<256, 10, input_t>(params, stream); - } + } } template @@ -316,10 +325,12 @@ void fast_hadamard_transform_20N_launch(HadamardParamsBase ¶ms, cudaStream_t constexpr int kSmemSize = Ktraits::kSmemSize; dim3 grid(params.batch); auto kernel = &fast_hadamard_transform_kernel; +#ifndef USE_ROCM if (kSmemSize >= 48 * 1024) { C10_CUDA_CHECK(cudaFuncSetAttribute( kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); } +#endif kernel<<>>(params); C10_CUDA_KERNEL_LAUNCH_CHECK(); } @@ -353,10 +364,12 @@ void fast_hadamard_transform_28N_launch(HadamardParamsBase ¶ms, cudaStream_t constexpr int kSmemSize = Ktraits::kSmemSize; dim3 grid(params.batch); auto kernel = &fast_hadamard_transform_kernel; +#ifndef USE_ROCM if (kSmemSize >= 48 * 1024) { C10_CUDA_CHECK(cudaFuncSetAttribute( kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); } +#endif kernel<<>>(params); C10_CUDA_KERNEL_LAUNCH_CHECK(); } @@ -390,10 +403,12 @@ void fast_hadamard_transform_40N_launch(HadamardParamsBase ¶ms, cudaStream_t constexpr int kSmemSize = Ktraits::kSmemSize; dim3 grid(params.batch); auto kernel = &fast_hadamard_transform_kernel; +#ifndef USE_ROCM if (kSmemSize >= 48 * 1024) { C10_CUDA_CHECK(cudaFuncSetAttribute( kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); } +#endif kernel<<>>(params); C10_CUDA_KERNEL_LAUNCH_CHECK(); } diff --git a/csrc/vendor.h b/csrc/vendor.h new file mode 100644 index 0000000..3e0a168 --- /dev/null +++ b/csrc/vendor.h @@ -0,0 +1,37 @@ +#pragma once + +#if !defined(USE_MUSA) && !defined(USE_ROCM) +#include +#include + +#define WARP_SIZE 32 +#define MIN(A, B) std::min((A), (B)) +#elif defined(USE_MUSA) +#include +#include + +#define WARP_SIZE 32 +#define MIN(A, B) std::min((A), (B)) +#define C10_CUDA_CHECK C10_MUSA_CHECK +#define C10_CUDA_KERNEL_LAUNCH_CHECK C10_MUSA_KERNEL_LAUNCH_CHECK +#define cudaFuncAttributeMaxDynamicSharedMemorySize musaFuncAttributeMaxDynamicSharedMemorySize +#define cudaFuncSetAttribute musaFuncSetAttribute +#define cudaStream_t musaStream_t + +#include "torch_musa/csrc/core/MUSAGuard.h" +#include "torch_musa/csrc/core/MUSAStream.h" +namespace at { +namespace cuda { +#ifdef USE_MUSA +using CUDAGuard = at::musa::MUSAGuard; +inline at::musa::MUSAStream getCurrentCUDAStream() { + return at::musa::getCurrentMUSAStream(); +} +#endif +} // namespace cuda +} // namespace at +#elif defined(USE_ROCM) +#define WARP_SIZE 64 +#define MIN(A, B) ((A) < (B) ? (A) : (B)) +#define __shfl_xor_sync(MASK, X, OFFSET) __shfl_xor(X, OFFSET) +#endif // !defined(USE_MUSA) && !defined(USE_ROCM) diff --git a/setup.py b/setup.py index 0e784f8..8c8207b 100644 --- a/setup.py +++ b/setup.py @@ -1,27 +1,54 @@ # Copyright (c) 2023, Tri Dao. -import sys -import warnings -import os -import re import ast -from pathlib import Path -from packaging.version import parse, Version +import os import platform - -from setuptools import setup, find_packages +import re import subprocess - -import urllib.request +import sys import urllib.error -from wheel.bdist_wheel import bdist_wheel as _bdist_wheel +import urllib.request +import warnings +from enum import Enum, auto +from pathlib import Path import torch -from torch.utils.cpp_extension import ( - BuildExtension, - CppExtension, - CUDAExtension, - CUDA_HOME, -) +from packaging.version import Version, parse +from setuptools import find_packages, setup +from wheel.bdist_wheel import bdist_wheel as _bdist_wheel + + +class Backend(Enum): + CUDA = auto() + HIP = auto() + MUSA = auto() + + +backend = Backend.CUDA + +if hasattr(torch, "cuda") and ( + torch.version.cuda is not None or torch.version.hip is not None +): + from torch.utils.cpp_extension import CUDA_HOME, BuildExtension, CUDAExtension + + if torch.version.hip: + backend = Backend.HIP + + +elif hasattr(torch, "musa"): + import torch_musa + from torch_musa.utils.musa_extension import MUSA_HOME as CUDA_HOME + from torch_musa.utils.musa_extension import BuildExtension as MUSABuildExtension + from torch_musa.utils.musa_extension import MUSAExtension as CUDAExtension + + class _CustomBuildExtension(MUSABuildExtension): + def build_extensions(self): + self.compiler.src_extensions += [".cu", ".cuh"] + + super().build_extensions() + + BuildExtension = _CustomBuildExtension + + backend = Backend.MUSA with open("README.md", "r", encoding="utf-8") as fh: @@ -38,9 +65,13 @@ # FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels # SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation FORCE_BUILD = os.getenv("FAST_HADAMARD_TRANSFORM_FORCE_BUILD", "FALSE") == "TRUE" -SKIP_CUDA_BUILD = os.getenv("FAST_HADAMARD_TRANSFORM_SKIP_CUDA_BUILD", "FALSE") == "TRUE" +SKIP_CUDA_BUILD = ( + os.getenv("FAST_HADAMARD_TRANSFORM_SKIP_CUDA_BUILD", "FALSE") == "TRUE" +) # For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI -FORCE_CXX11_ABI = os.getenv("FAST_HADAMARD_TRANSFORM_FORCE_CXX11_ABI", "FALSE") == "TRUE" +FORCE_CXX11_ABI = ( + os.getenv("FAST_HADAMARD_TRANSFORM_FORCE_CXX11_ABI", "FALSE") == "TRUE" +) def get_platform(): @@ -82,7 +113,11 @@ def check_if_cuda_home_none(global_option: str) -> None: def append_nvcc_threads(): - return ["--threads", os.getenv("NVCC_THREADS") or "4"] + if backend == Backend.CUDA: + return ["--threads", os.getenv("NVCC_THREADS") or "4"] + else: + return [] + cmdclass = {} ext_modules = [] @@ -92,29 +127,34 @@ def append_nvcc_threads(): TORCH_MAJOR = int(torch.__version__.split(".")[0]) TORCH_MINOR = int(torch.__version__.split(".")[1]) - check_if_cuda_home_none("fast_hadamard_transform") - # Check, if CUDA11 is installed for compute capability 8.0 cc_flag = [] - if CUDA_HOME is not None: - _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) - if bare_metal_version < Version("11.6"): - raise RuntimeError( - "fast_hadamard_transform is only supported on CUDA 11.6 and above. " - "Note: make sure nvcc has a supported version by running nvcc -V." - ) - - if bare_metal_version <= Version("12.9"): - cc_flag.append("-gencode") - cc_flag.append("arch=compute_70,code=sm_70") - cc_flag.append("-gencode") - cc_flag.append("arch=compute_80,code=sm_80") - if bare_metal_version >= Version("11.8"): + if backend == Backend.CUDA: + check_if_cuda_home_none("fast_hadamard_transform") + # Check, if CUDA11 is installed for compute capability 8.0 + if CUDA_HOME is not None: + _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) + if bare_metal_version < Version("11.6"): + raise RuntimeError( + "fast_hadamard_transform is only supported on CUDA 11.6 and above. " + "Note: make sure nvcc has a supported version by running nvcc -V." + ) + + if backend == Backend.CUDA: + if bare_metal_version <= Version("12.9"): + cc_flag.append("-gencode") + cc_flag.append("arch=compute_70,code=sm_70") cc_flag.append("-gencode") - cc_flag.append("arch=compute_90,code=sm_90") - if bare_metal_version >= Version("12.8"): - cc_flag.append("-gencode") - cc_flag.append("arch=compute_100,code=sm_100") - + cc_flag.append("arch=compute_80,code=sm_80") + if bare_metal_version >= Version("11.8"): + cc_flag.append("-gencode") + cc_flag.append("arch=compute_90,code=sm_90") + if bare_metal_version >= Version("12.8"): + cc_flag.append("-gencode") + cc_flag.append("arch=compute_100,code=sm_100") + elif backend == Backend.HIP: + cc_flag.append("-DUSE_ROCM=1") + elif backend == Backend.MUSA: + cc_flag.append("-DUSE_MUSA=1") # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as # torch._C._GLIBCXX_USE_CXX11_ABI @@ -122,32 +162,52 @@ def append_nvcc_threads(): if FORCE_CXX11_ABI: torch._C._GLIBCXX_USE_CXX11_ABI = True + cxx_flags = ["-O3"] if backend in (Backend.CUDA, Backend.HIP) else ["force_mcc"] + backend_cc = "nvcc" if backend in (Backend.CUDA, Backend.HIP) else "mcc" + backend_cc_flags = [] + if backend == Backend.CUDA or backend == Backend.HIP: + backend_cc_flags = [ + "-O3", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_BFLOAT16_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "-U__CUDA_NO_BFLOAT162_OPERATORS__", + "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", + ] + if backend == Backend.CUDA: + backend_cc_flags += [ + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + "--use_fast_math", + "--ptxas-options=-v", + "-lineinfo", + ] + elif backend == Backend.MUSA: + backend_cc_flags = [ + "-O3", + "-fPIC", + "-std=c++17", + "-x", + "musa", + "-mtgpu", + "--cuda-gpu-arch=mp_31", + "-fno-strict-aliasing", + "-ffast-math", + "-Od3", + "-fmusa-flush-denormals-to-zero", + ] + ext_modules.append( CUDAExtension( name="fast_hadamard_transform_cuda", sources=[ "csrc/fast_hadamard_transform.cpp", - "csrc/fast_hadamard_transform_cuda.cu", + "csrc/fast_hadamard_transform_gpu.cu", ], extra_compile_args={ - "cxx": ["-O3"], - "nvcc": - [ - "-O3", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", - "-U__CUDA_NO_BFLOAT16_OPERATORS__", - "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-U__CUDA_NO_BFLOAT162_OPERATORS__", - "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", - "--expt-relaxed-constexpr", - "--expt-extended-lambda", - "--use_fast_math", - "--ptxas-options=-v", - "-lineinfo", - ] - + append_nvcc_threads() - + cc_flag, + "cxx": cxx_flags, + backend_cc: backend_cc_flags + append_nvcc_threads() + cc_flag, }, include_dirs=[this_dir], ) @@ -169,11 +229,18 @@ def get_wheel_url(): # Determine the version numbers that will be used to determine the correct wheel # We're using the CUDA version used to build torch, not the one currently installed # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME) - torch_cuda_version = parse(torch.version.cuda) + backend_versions = { + Backend.CUDA: torch.version.cuda, + Backend.HIP: torch.version.hip, + Backend.MUSA: getattr(torch.version, "musa", None), + } + torch_cuda_version = parse(backend_versions[backend]) torch_version_raw = parse(torch.__version__) # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.2 # to save CI time. Minor versions should be compatible. - torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.2") + torch_cuda_version = ( + parse("11.8") if torch_cuda_version.major == 11 else parse("12.2") + ) python_version = f"cp{sys.version_info.major}{sys.version_info.minor}" platform_name = get_platform() fast_hadamard_transform_version = get_package_version() @@ -252,11 +319,13 @@ def run(self): "Operating System :: Unix", ], ext_modules=ext_modules, - cmdclass={"bdist_wheel": CachedWheelsCommand, "build_ext": BuildExtension} - if ext_modules - else { - "bdist_wheel": CachedWheelsCommand, - }, + cmdclass=( + {"bdist_wheel": CachedWheelsCommand, "build_ext": BuildExtension} + if ext_modules + else { + "bdist_wheel": CachedWheelsCommand, + } + ), python_requires=">=3.7", install_requires=[ "torch", diff --git a/tests/test_fast_hadamard_transform.py b/tests/test_fast_hadamard_transform.py index ea1e631..7db682a 100644 --- a/tests/test_fast_hadamard_transform.py +++ b/tests/test_fast_hadamard_transform.py @@ -16,6 +16,8 @@ # @pytest.mark.parametrize("dim", [256]) def test_fast_hadamard_transform(dim, dtype): device = "cuda" + if hasattr(torch, "musa"): + device = "musa" rtol, atol = (3e-4, 3e-3) if dtype == torch.float32 else (3e-3, 5e-3) if dtype == torch.bfloat16: rtol, atol = 1e-2, 5e-2