Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions tests/test_maca_env_doctor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import importlib.util
import unittest
from pathlib import Path


# Load the doctor script directly from its file path under tools/.
DOCTOR_PATH = Path(__file__).resolve().parents[1] / "tools" / "maca_env_doctor.py"
spec = importlib.util.spec_from_file_location("maca_env_doctor", DOCTOR_PATH)
# spec_from_file_location() may return None, and a spec may carry no loader.
# Check both BEFORE using the spec, and raise instead of asserting so the
# guard still runs under ``python -O``.
if spec is None or spec.loader is None:
    raise ImportError(f"cannot load maca_env_doctor from {DOCTOR_PATH}")
maca_env_doctor = importlib.util.module_from_spec(spec)
spec.loader.exec_module(maca_env_doctor)
Comment on lines +7 to +10

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

如果 importlib.util.spec_from_file_location 无法找到或加载模块规范,它可能会返回 None。在这种情况下,直接调用 module_from_spec(spec) 会在到达第 9 行的断言之前抛出 AttributeError(因为会访问 None 的 loader 属性)。我们应该在尝试加载模块之前,先断言 spec 不为 None。

Suggested change
spec = importlib.util.spec_from_file_location("maca_env_doctor", DOCTOR_PATH)
maca_env_doctor = importlib.util.module_from_spec(spec)
assert spec.loader is not None
spec.loader.exec_module(maca_env_doctor)
spec = importlib.util.spec_from_file_location("maca_env_doctor", DOCTOR_PATH)
assert spec is not None and spec.loader is not None
maca_env_doctor = importlib.util.module_from_spec(spec)
spec.loader.exec_module(maca_env_doctor)



class MacaEnvDoctorTest(unittest.TestCase):
    """Smoke test for ``maca_env_doctor.collect_report``."""

    def test_collect_report_marks_missing_environment(self):
        """An empty environment mapping must yield a failing report."""
        result = maca_env_doctor.collect_report({})

        self.assertFalse(result["ok"])
        self.assertIn("environment", result)
        # The MACA_PATH check must be listed among the report's checks.
        has_maca_path_check = any(
            check["name"] == "MACA_PATH" for check in result["checks"]
        )
        self.assertTrue(has_maca_path_check)


# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
122 changes: 122 additions & 0 deletions tools/maca_env_doctor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
#!/usr/bin/env python3
import argparse
import json
import os
import shutil
import subprocess
import sys
from pathlib import Path
from typing import Any

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

为了保持与 Python 3.7 和 3.8 的兼容性(如 setup.py 中指定的 python_requires=">=3.7"),我们应该从 typing 导入 Dict、List 和 Optional,而不是使用 PEP 585/604 风格的类型注解(如 dict[...]、list[...] 和 |),因为这些特性分别在 Python 3.9+ 和 3.10+ 中才被支持。

Suggested change
from typing import Any
from typing import Any, Dict, List, Optional



def _path_status(path: str | None) -> dict[str, Any]:

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

更新类型注解以使用 Optional 和 Dict,从而确保与 Python 3.7 和 3.8 的兼容性。

Suggested change
def _path_status(path: str | None) -> dict[str, Any]:
def _path_status(path: Optional[str]) -> Dict[str, Any]:

if not path:
return {"path": None, "exists": False, "is_dir": False}
resolved = Path(path).expanduser()
return {
"path": str(resolved),
"exists": resolved.exists(),
"is_dir": resolved.is_dir(),
}


def _command_output(command: list[str]) -> dict[str, Any]:

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

更新类型注解以使用 List 和 Dict,从而确保与 Python 3.7 和 3.8 的兼容性。

Suggested change
def _command_output(command: list[str]) -> dict[str, Any]:
def _command_output(command: List[str]) -> Dict[str, Any]:

try:
result = subprocess.run(
command,
check=False,
capture_output=True,
text=True,
timeout=10,
)
return {
"command": command,
"returncode": result.returncode,
"stdout": result.stdout.strip(),
"stderr": result.stderr.strip(),
}
except Exception as exc:
return {"command": command, "error": f"{type(exc).__name__}: {exc}"}


def collect_report(env: dict[str, str] | None = None) -> dict[str, Any]:

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

更新类型注解以使用 Optional 和 Dict,从而确保与 Python 3.7 和 3.8 的兼容性。

Suggested change
def collect_report(env: dict[str, str] | None = None) -> dict[str, Any]:
def collect_report(env: Optional[Dict[str, str]] = None) -> Dict[str, Any]:

env = os.environ if env is None else env
maca_path = env.get("MACA_PATH")
cuda_path = env.get("CUDA_HOME") or env.get("CUDA_PATH")
if not cuda_path and maca_path:
cuda_path = str(Path(maca_path) / "tools" / "cu-bridge")
clang_path = env.get("MACA_CLANG_PATH")
if not clang_path and maca_path:
clang_path = str(Path(maca_path) / "mxgpu_llvm" / "bin")

search_path = env.get("PATH", "")
cucc = shutil.which("cucc", path=search_path)
nvcc = shutil.which("nvcc", path=search_path)
report: dict[str, Any] = {

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

更新变量类型注解以使用 Dict,从而确保与 Python 3.7 和 3.8 的兼容性。

Suggested change
report: dict[str, Any] = {
report: Dict[str, Any] = {

"python": sys.version.split()[0],
"environment": {
"MACA_PATH": _path_status(maca_path),
"CUDA_HOME_OR_PATH": _path_status(cuda_path),
"MACA_CLANG_PATH": _path_status(clang_path),
"LD_LIBRARY_PATH_SET": bool(env.get("LD_LIBRARY_PATH")),
},
"executables": {
"cucc": cucc,
"nvcc": nvcc,
},
"libraries": {},
"torch": {},
"checks": [],
}

if maca_path:
maca_root = Path(maca_path)
report["libraries"] = {
"mcruntime": _path_status(str(maca_root / "lib" / "libmcruntime.so")),
"mcblas": _path_status(str(maca_root / "lib" / "libmcblas.so")),
}

if cucc:
report["cucc_version"] = _command_output([cucc, "-V"])
elif nvcc:
report["nvcc_version"] = _command_output([nvcc, "-V"])

try:
import torch

report["torch"] = {
"version": torch.__version__,
"cuda_version": getattr(torch.version, "cuda", None),
"maca_version": getattr(torch.version, "maca", None),
"cuda_available": bool(torch.cuda.is_available()),
"device_count": torch.cuda.device_count() if torch.cuda.is_available() else 0,
}
if torch.cuda.is_available():
report["torch"]["device_name_0"] = torch.cuda.get_device_name(0)
except Exception as exc:
report["torch"] = {"error": f"{type(exc).__name__}: {exc}"}

required = [
("MACA_PATH", report["environment"]["MACA_PATH"]["is_dir"]),
("CUDA_HOME_OR_PATH", report["environment"]["CUDA_HOME_OR_PATH"]["is_dir"]),
("MACA_CLANG_PATH", report["environment"]["MACA_CLANG_PATH"]["is_dir"]),
("cucc_or_nvcc", bool(cucc or nvcc)),
]
for name, ok in required:
report["checks"].append({"name": name, "ok": ok})
report["ok"] = all(item["ok"] for item in report["checks"])
return report


def main() -> int:
    """Print the environment report as JSON and return an exit status.

    Accepts one optional flag, ``--pretty``, which switches the JSON output
    to a 2-space indented layout. Keys are always sorted.

    Returns:
        0 when the report's ``"ok"`` entry is truthy, otherwise 1.
    """
    cli = argparse.ArgumentParser(description="Collect FlashMLA MACA environment diagnostics.")
    cli.add_argument("--pretty", action="store_true", help="Print indented JSON.")
    options = cli.parse_args()

    if options.pretty:
        indent = 2
    else:
        indent = None

    diagnostics = collect_report()
    print(json.dumps(diagnostics, indent=indent, sort_keys=True))
    if diagnostics["ok"]:
        return 0
    return 1


# Script entry point: exit the process with the status returned by main().
if __name__ == "__main__":
    raise SystemExit(main())