-
Notifications
You must be signed in to change notification settings - Fork 3
检测真实 cucc 版本 #17
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
检测真实 cucc 版本 #17
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| """Build-time helpers for FlashMLA.""" |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,48 @@ | ||||||||||||||||||
| import re | ||||||||||||||||||
| import subprocess | ||||||||||||||||||
| from pathlib import Path | ||||||||||||||||||
| from typing import Callable | ||||||||||||||||||
|
|
||||||||||||||||||
| from packaging.version import Version | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| def parse_cuda_release_version(output: str) -> Version: | ||||||||||||||||||
| release_match = re.search(r"release\s+(\d+\.\d+)", output) | ||||||||||||||||||
| if release_match: | ||||||||||||||||||
| return Version(release_match.group(1)) | ||||||||||||||||||
|
|
||||||||||||||||||
| version_match = re.search(r"\bV(\d+\.\d+)(?:\.\d+)?\b", output) | ||||||||||||||||||
| if version_match: | ||||||||||||||||||
| return Version(version_match.group(1)) | ||||||||||||||||||
|
|
||||||||||||||||||
| raise ValueError(f"Cannot parse compiler release version from output: {output!r}") | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| def _compiler_candidates(cuda_dir: str | Path) -> list[Path]: | ||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||||||||||||||
| bin_dir = Path(cuda_dir) / "bin" | ||||||||||||||||||
| return [bin_dir / "cucc", bin_dir / "nvcc"] | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| def get_cuda_bare_metal_version( | ||||||||||||||||||
| cuda_dir: str | Path, | ||||||||||||||||||
| check_output: Callable[..., str] = subprocess.check_output, | ||||||||||||||||||
| ) -> tuple[str, Version]: | ||||||||||||||||||
|
Comment on lines
+26
to
+29
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 使用
Suggested change
|
||||||||||||||||||
| errors: list[str] = [] | ||||||||||||||||||
| for compiler in _compiler_candidates(cuda_dir): | ||||||||||||||||||
| if not compiler.is_file(): | ||||||||||||||||||
| errors.append(f"{compiler} does not exist") | ||||||||||||||||||
| continue | ||||||||||||||||||
| try: | ||||||||||||||||||
| raw_output = check_output( | ||||||||||||||||||
| [str(compiler), "-V"], | ||||||||||||||||||
| stderr=subprocess.STDOUT, | ||||||||||||||||||
| universal_newlines=True, | ||||||||||||||||||
| ) | ||||||||||||||||||
| return raw_output, parse_cuda_release_version(raw_output) | ||||||||||||||||||
| except Exception as err: | ||||||||||||||||||
| errors.append(f"{compiler}: {err}") | ||||||||||||||||||
|
|
||||||||||||||||||
| details = "\n".join(f"- {error}" for error in errors) | ||||||||||||||||||
| raise RuntimeError( | ||||||||||||||||||
| f"Cannot detect CUDA-compatible compiler version under {cuda_dir}.\n{details}" | ||||||||||||||||||
| ) | ||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -26,6 +26,8 @@ | |
| CUDA_HOME, | ||
| ) | ||
|
|
||
| from build_tools.compiler_version import get_cuda_bare_metal_version | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 新引入的 这会污染用户的 Python 环境,并且由于 建议在 packages=find_packages(
exclude=(
"build",
"build_tools", # 新增排除
"csrc",
...
)
),由于 |
||
|
|
||
|
|
||
| with open("README.md", "r", encoding="utf-8") as fh: | ||
| long_description = fh.read() | ||
|
|
@@ -63,16 +65,6 @@ def get_platform(): | |
| raise ValueError("Unsupported platform: {}".format(sys.platform)) | ||
|
|
||
|
|
||
| def get_cuda_bare_metal_version(cuda_dir): | ||
| # raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) | ||
| # output = raw_output.split() | ||
| # release_idx = output.index("release") + 1 | ||
| # bare_metal_version = parse(output[release_idx].split(",")[0]) | ||
| raw_output = "nvcc: NVIDIA (R) Cuda compiler driver Copyright (c) 2005-2023 NVIDIA Corporation Built on Mon_Apr__3_17:16:06_PDT_2023 Cuda compilation tools, release 12.1, V12.1.105 Build cuda_12.1.r12.1/compiler.32688072_0" | ||
| bare_metal_version = Version("12.1") | ||
| return raw_output, bare_metal_version | ||
|
|
||
|
|
||
| def check_if_cuda_home_none(global_option: str) -> None: | ||
| if CUDA_HOME is not None: | ||
| return | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,50 @@ | ||
| import unittest | ||
| from pathlib import Path | ||
| from tempfile import TemporaryDirectory | ||
|
|
||
| from packaging.version import Version | ||
|
|
||
| from build_tools.compiler_version import ( | ||
| get_cuda_bare_metal_version, | ||
| parse_cuda_release_version, | ||
| ) | ||
|
|
||
|
|
||
| class CompilerVersionTest(unittest.TestCase): | ||
| def test_parse_nvcc_release_version(self): | ||
| output = "Cuda compilation tools, release 12.1, V12.1.105" | ||
|
|
||
| self.assertEqual(parse_cuda_release_version(output), Version("12.1")) | ||
|
|
||
| def test_parse_cucc_version_without_release_token(self): | ||
| output = "cucc compiler driver V12.2.91" | ||
|
|
||
| self.assertEqual(parse_cuda_release_version(output), Version("12.2")) | ||
|
|
||
| def test_get_cuda_bare_metal_version_prefers_cucc(self): | ||
| with TemporaryDirectory() as tmp_dir: | ||
| bin_dir = Path(tmp_dir) / "bin" | ||
| bin_dir.mkdir() | ||
| cucc = bin_dir / "cucc" | ||
| nvcc = bin_dir / "nvcc" | ||
| cucc.write_text("#!/bin/sh\n", encoding="utf-8") | ||
| nvcc.write_text("#!/bin/sh\n", encoding="utf-8") | ||
|
|
||
| commands: list[list[str]] = [] | ||
|
|
||
| def fake_check_output(command, **_kwargs): | ||
| commands.append(command) | ||
| return "cucc compiler driver V12.3.0" | ||
|
|
||
| raw_output, version = get_cuda_bare_metal_version( | ||
| tmp_dir, | ||
| check_output=fake_check_output, | ||
| ) | ||
|
|
||
| self.assertEqual(commands, [[str(cucc), "-V"]]) | ||
| self.assertIn("cucc", raw_output) | ||
| self.assertEqual(version, Version("12.3")) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| unittest.main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
由于项目支持 Python 3.7+,直接使用
str | Path语法(PEP 604)在 Python 3.10 以下版本导入时会抛出TypeError。建议在此处导入Union以支持低版本 Python。