From 78ff44007d782e1218f3ba50a5f6c78f9e0acb35 Mon Sep 17 00:00:00 2001 From: papertager <2567587994@qq.com> Date: Thu, 4 Jun 2026 21:36:45 +0800 Subject: [PATCH] Detect real cucc compiler version --- build_tools/__init__.py | 1 + build_tools/compiler_version.py | 48 +++++++++++++++++++++++++++++++ setup.py | 12 ++------ tests/test_compiler_version.py | 50 +++++++++++++++++++++++++++++++++ 4 files changed, 101 insertions(+), 10 deletions(-) create mode 100644 build_tools/__init__.py create mode 100644 build_tools/compiler_version.py create mode 100644 tests/test_compiler_version.py diff --git a/build_tools/__init__.py b/build_tools/__init__.py new file mode 100644 index 00000000..5cb0db32 --- /dev/null +++ b/build_tools/__init__.py @@ -0,0 +1 @@ +"""Build-time helpers for FlashMLA.""" diff --git a/build_tools/compiler_version.py b/build_tools/compiler_version.py new file mode 100644 index 00000000..b8f3e4cd --- /dev/null +++ b/build_tools/compiler_version.py @@ -0,0 +1,48 @@ +import re +import subprocess +from pathlib import Path +from typing import Callable + +from packaging.version import Version + + +def parse_cuda_release_version(output: str) -> Version: + release_match = re.search(r"release\s+(\d+\.\d+)", output) + if release_match: + return Version(release_match.group(1)) + + version_match = re.search(r"\bV(\d+\.\d+)(?:\.\d+)?\b", output) + if version_match: + return Version(version_match.group(1)) + + raise ValueError(f"Cannot parse compiler release version from output: {output!r}") + + +def _compiler_candidates(cuda_dir: str | Path) -> list[Path]: + bin_dir = Path(cuda_dir) / "bin" + return [bin_dir / "cucc", bin_dir / "nvcc"] + + +def get_cuda_bare_metal_version( + cuda_dir: str | Path, + check_output: Callable[..., str] = subprocess.check_output, +) -> tuple[str, Version]: + errors: list[str] = [] + for compiler in _compiler_candidates(cuda_dir): + if not compiler.is_file(): + errors.append(f"{compiler} does not exist") + continue + try: + raw_output = check_output( + [str(compiler), "-V"], + stderr=subprocess.STDOUT, + universal_newlines=True, + ) + return raw_output, parse_cuda_release_version(raw_output) + except Exception as err: + errors.append(f"{compiler}: {err}") + + details = "\n".join(f"- {error}" for error in errors) + raise RuntimeError( + f"Cannot detect CUDA-compatible compiler version under {cuda_dir}.\n{details}" + ) diff --git a/setup.py b/setup.py index deb687eb..7a80351f 100644 --- a/setup.py +++ b/setup.py @@ -26,6 +26,8 @@ CUDA_HOME, ) +from build_tools.compiler_version import get_cuda_bare_metal_version + with open("README.md", "r", encoding="utf-8") as fh: long_description = fh.read() @@ -63,16 +65,6 @@ def get_platform(): raise ValueError("Unsupported platform: {}".format(sys.platform)) -def get_cuda_bare_metal_version(cuda_dir): - # raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) - # output = raw_output.split() - # release_idx = output.index("release") + 1 - # bare_metal_version = parse(output[release_idx].split(",")[0]) - raw_output = "nvcc: NVIDIA (R) Cuda compiler driver Copyright (c) 2005-2023 NVIDIA Corporation Built on Mon_Apr__3_17:16:06_PDT_2023 Cuda compilation tools, release 12.1, V12.1.105 Build cuda_12.1.r12.1/compiler.32688072_0" - bare_metal_version = Version("12.1") - return raw_output, bare_metal_version - - def check_if_cuda_home_none(global_option: str) -> None: if CUDA_HOME is not None: return diff --git a/tests/test_compiler_version.py b/tests/test_compiler_version.py new file mode 100644 index 00000000..6e236087 --- /dev/null +++ b/tests/test_compiler_version.py @@ -0,0 +1,50 @@ +import unittest +from pathlib import Path +from tempfile import TemporaryDirectory + +from packaging.version import Version + +from build_tools.compiler_version import ( + get_cuda_bare_metal_version, + parse_cuda_release_version, +) + + +class CompilerVersionTest(unittest.TestCase): + def test_parse_nvcc_release_version(self): + output = "Cuda compilation tools, release 12.1, V12.1.105" + + self.assertEqual(parse_cuda_release_version(output), Version("12.1")) + + def test_parse_cucc_version_without_release_token(self): + output = "cucc compiler driver V12.2.91" + + self.assertEqual(parse_cuda_release_version(output), Version("12.2")) + + def test_get_cuda_bare_metal_version_prefers_cucc(self): + with TemporaryDirectory() as tmp_dir: + bin_dir = Path(tmp_dir) / "bin" + bin_dir.mkdir() + cucc = bin_dir / "cucc" + nvcc = bin_dir / "nvcc" + cucc.write_text("#!/bin/sh\n", encoding="utf-8") + nvcc.write_text("#!/bin/sh\n", encoding="utf-8") + + commands: list[list[str]] = [] + + def fake_check_output(command, **_kwargs): + commands.append(command) + return "cucc compiler driver V12.3.0" + + raw_output, version = get_cuda_bare_metal_version( + tmp_dir, + check_output=fake_check_output, + ) + + self.assertEqual(commands, [[str(cucc), "-V"]]) + self.assertIn("cucc", raw_output) + self.assertEqual(version, Version("12.3")) + + +if __name__ == "__main__": + unittest.main()