Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions build_tools/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Build-time helpers for FlashMLA."""
48 changes: 48 additions & 0 deletions build_tools/compiler_version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import re
import subprocess
from pathlib import Path
from typing import Callable
Comment on lines +3 to +4

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

由于项目支持 Python 3.7+,直接使用 str | Path 语法(PEP 604)在 Python 3.10 以下版本导入时会抛出 TypeError。建议在此处导入 Union 以支持低版本 Python。

Suggested change
from pathlib import Path
from typing import Callable
from pathlib import Path
from typing import Callable, Union


from packaging.version import Version


def parse_cuda_release_version(output: str) -> Version:
    """Extract the ``major.minor`` version from a compiler's ``-V`` banner.

    Two patterns are tried in order and the first hit wins:

    1. a ``release <major>.<minor>`` token, e.g. ``release 12.1``;
    2. a standalone ``V<major>.<minor>[.<patch>]`` tag, e.g. ``V12.2.91``
       (the patch component, if any, is discarded).

    Args:
        output: Text printed by the compiler.

    Returns:
        The ``major.minor`` version as a ``Version``.

    Raises:
        ValueError: If neither pattern occurs in ``output``.
    """
    # Order matters: the explicit "release" token takes precedence over the
    # bare "V..." tag when both are present.
    patterns = (
        r"release\s+(\d+\.\d+)",
        r"\bV(\d+\.\d+)(?:\.\d+)?\b",
    )
    for pattern in patterns:
        found = re.search(pattern, output)
        if found is not None:
            return Version(found.group(1))

    raise ValueError(f"Cannot parse compiler release version from output: {output!r}")


def _compiler_candidates(cuda_dir: "str | Path") -> "list[Path]":
    """Return the compiler executables to probe under ``cuda_dir``, in priority order.

    The result is ``<cuda_dir>/bin/cucc`` followed by ``<cuda_dir>/bin/nvcc``.
    No check is made here that either file exists.

    The annotations are deliberately written as strings.  An unquoted
    ``str | Path`` (PEP 604) is evaluated when the ``def`` statement runs and
    raises ``TypeError`` before Python 3.10; an unquoted ``list[Path]``
    (PEP 585) does the same before Python 3.9.  Quoting them keeps this module
    importable on those interpreters without an extra ``typing`` import.

    Args:
        cuda_dir: Toolkit root directory, as a string or ``Path``.

    Returns:
        A two-element list of candidate paths, ``cucc`` first.
    """
    bin_dir = Path(cuda_dir) / "bin"
    return [bin_dir / "cucc", bin_dir / "nvcc"]


def get_cuda_bare_metal_version(
    cuda_dir: "str | Path",
    check_output: Callable[..., str] = subprocess.check_output,
) -> "tuple[str, Version]":
    """Detect the toolkit version by running a compiler found under ``cuda_dir``.

    Every path returned by ``_compiler_candidates(cuda_dir)`` is tried in
    order.  The first candidate that exists as a file, can be run with ``-V``
    and prints parseable output wins; later candidates are not invoked.

    The ``str | Path`` and ``tuple[str, Version]`` annotations are quoted on
    purpose: unquoted, they are evaluated when the ``def`` statement runs and
    raise ``TypeError`` before Python 3.10 (PEP 604 unions) and before
    Python 3.9 (PEP 585 builtin generics) respectively.

    Args:
        cuda_dir: Toolkit root directory, as a string or ``Path``.
        check_output: Callable used to run the compiler.  It receives the
            argument list plus ``stderr=subprocess.STDOUT`` and
            ``universal_newlines=True`` and must return the captured text.
            Defaults to ``subprocess.check_output``; replaceable in tests.

    Returns:
        ``(raw_output, version)``: the compiler's unmodified output and the
        version obtained from it via ``parse_cuda_release_version``.

    Raises:
        RuntimeError: If no candidate yielded a version.  The message carries
            one ``- ...`` line per candidate explaining why it was rejected.
    """
    errors: list[str] = []  # local annotations are never evaluated at runtime
    for compiler in _compiler_candidates(cuda_dir):
        if not compiler.is_file():
            errors.append(f"{compiler} does not exist")
            continue
        try:
            raw_output = check_output(
                [str(compiler), "-V"],
                stderr=subprocess.STDOUT,
                universal_newlines=True,
            )
            return raw_output, parse_cuda_release_version(raw_output)
        except Exception as err:
            # Deliberately broad: a launch failure, a non-zero exit status or
            # unparseable output must all fall through to the next candidate
            # and be reported together at the end.
            errors.append(f"{compiler}: {err}")

    details = "\n".join(f"- {error}" for error in errors)
    raise RuntimeError(
        f"Cannot detect CUDA-compatible compiler version under {cuda_dir}.\n{details}"
    )
12 changes: 2 additions & 10 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
CUDA_HOME,
)

from build_tools.compiler_version import get_cuda_bare_metal_version

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

新引入的 build_tools 目录中包含了 __init__.py,这使其被识别为一个 Python 包。由于 setup.py 中的 find_packages 调用没有将 "build_tools" 排除,这会导致在构建和安装 flash_mla 时,build_tools 会作为一个顶级的包被安装到用户的 site-packages 目录中。

这会污染用户的 Python 环境,并且由于 build_tools 是一个非常通用的名称,极易与其他库产生命名冲突。

建议在 setup.py 的 find_packages(exclude=(...)) 中将 "build_tools" 排除。例如:

    packages=find_packages(
        exclude=(
            "build",
            "build_tools",  # 新增排除
            "csrc",
            ...
        )
    ),

由于 setup.py 在运行时处于项目根目录下,即使排除了该包,构建时依然可以正常导入并使用它。



with open("README.md", "r", encoding="utf-8") as fh:
long_description = fh.read()
Expand Down Expand Up @@ -63,16 +65,6 @@ def get_platform():
raise ValueError("Unsupported platform: {}".format(sys.platform))


def get_cuda_bare_metal_version(cuda_dir):
    """Return a canned ``nvcc -V`` banner together with CUDA version 12.1.

    NOTE: this is a stub.  ``cuda_dir`` is ignored and no compiler is run;
    both the banner text and the version are hard-coded, so the result does
    not reflect whichever toolkit is actually installed.  The earlier
    implementation, which ran ``<cuda_dir>/bin/nvcc -V`` and parsed the token
    following ``release``, is disabled.

    Args:
        cuda_dir: Unused.

    Returns:
        ``(raw_output, bare_metal_version)``: the fixed banner string and
        ``Version("12.1")``.
    """
    canned_banner = (
        "nvcc: NVIDIA (R) Cuda compiler driver "
        "Copyright (c) 2005-2023 NVIDIA Corporation "
        "Built on Mon_Apr__3_17:16:06_PDT_2023 "
        "Cuda compilation tools, release 12.1, V12.1.105 "
        "Build cuda_12.1.r12.1/compiler.32688072_0"
    )
    return canned_banner, Version("12.1")


def check_if_cuda_home_none(global_option: str) -> None:
if CUDA_HOME is not None:
return
Expand Down
50 changes: 50 additions & 0 deletions tests/test_compiler_version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import unittest
from pathlib import Path
from tempfile import TemporaryDirectory

from packaging.version import Version

from build_tools.compiler_version import (
get_cuda_bare_metal_version,
parse_cuda_release_version,
)


class CompilerVersionTest(unittest.TestCase):
    """Tests for the version helpers imported from ``build_tools.compiler_version``."""

    def test_parse_nvcc_release_version(self):
        """A ``release X.Y`` token in the banner determines the version."""
        banner = "Cuda compilation tools, release 12.1, V12.1.105"

        parsed = parse_cuda_release_version(banner)

        self.assertEqual(parsed, Version("12.1"))

    def test_parse_cucc_version_without_release_token(self):
        """Without a ``release`` token, the ``V<major>.<minor>`` tag is used."""
        banner = "cucc compiler driver V12.2.91"

        parsed = parse_cuda_release_version(banner)

        self.assertEqual(parsed, Version("12.2"))

    def test_get_cuda_bare_metal_version_prefers_cucc(self):
        """With both compilers present, only ``cucc`` is invoked."""
        with TemporaryDirectory() as tmp_dir:
            bin_dir = Path(tmp_dir) / "bin"
            bin_dir.mkdir()
            # Create placeholder files for both compilers, cucc first.
            for compiler_name in ("cucc", "nvcc"):
                (bin_dir / compiler_name).write_text("#!/bin/sh\n", encoding="utf-8")
            expected_command = [str(bin_dir / "cucc"), "-V"]

            invoked = []

            def fake_check_output(command, **_kwargs):
                # Record the command instead of spawning a process.
                invoked.append(command)
                return "cucc compiler driver V12.3.0"

            raw_output, version = get_cuda_bare_metal_version(
                tmp_dir,
                check_output=fake_check_output,
            )

            # Exactly one invocation, and it targeted cucc rather than nvcc.
            self.assertEqual(invoked, [expected_command])
            self.assertIn("cucc", raw_output)
            self.assertEqual(version, Version("12.3"))


# Run the test cases when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()