-
Notifications
You must be signed in to change notification settings - Fork 3
增加 MACA 环境诊断工具 #12
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
增加 MACA 环境诊断工具 #12
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,23 @@ | ||
| import importlib.util | ||
| import unittest | ||
| from pathlib import Path | ||
|
|
||
|
|
||
# Load tools/maca_env_doctor.py straight from its file path so the test does
# not require the tools directory to be an importable package.
DOCTOR_PATH = Path(__file__).resolve().parents[1] / "tools" / "maca_env_doctor.py"
spec = importlib.util.spec_from_file_location("maca_env_doctor", DOCTOR_PATH)
# spec_from_file_location may return None; validate the spec and its loader
# before module_from_spec dereferences them, so a failure is reported clearly.
if spec is None or spec.loader is None:
    raise ImportError(f"Cannot build an import spec for {DOCTOR_PATH}")
maca_env_doctor = importlib.util.module_from_spec(spec)
spec.loader.exec_module(maca_env_doctor)
|
|
||
|
|
||
class MacaEnvDoctorTest(unittest.TestCase):
    """Checks on the report built by ``maca_env_doctor.collect_report``."""

    def test_collect_report_marks_missing_environment(self):
        """An empty environment mapping yields a failing report with a MACA_PATH check."""
        result = maca_env_doctor.collect_report({})
        check_names = [entry["name"] for entry in result["checks"]]

        self.assertFalse(result["ok"])
        self.assertIn("environment", result)
        self.assertIn("MACA_PATH", check_names)


# Allow running this test file directly as a script.
if __name__ == "__main__":
    unittest.main()
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,122 @@ | ||
| #!/usr/bin/env python3 | ||
| import argparse | ||
| import json | ||
| import os | ||
| import shutil | ||
| import subprocess | ||
| import sys | ||
| from pathlib import Path | ||
| from typing import Any | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||
|
|
||
|
|
||
| def _path_status(path: str | None) -> dict[str, Any]: | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| if not path: | ||
| return {"path": None, "exists": False, "is_dir": False} | ||
| resolved = Path(path).expanduser() | ||
| return { | ||
| "path": str(resolved), | ||
| "exists": resolved.exists(), | ||
| "is_dir": resolved.is_dir(), | ||
| } | ||
|
|
||
|
|
||
def _command_output(command: list[str]) -> dict[str, Any]:
    """Run *command* and capture its result without ever raising.

    Args:
        command: Program and arguments, passed to ``subprocess.run`` as a list.

    Returns:
        On completion, a mapping with ``command``, ``returncode`` and the
        stripped ``stdout`` / ``stderr`` text. If anything raises (including
        the 10-second timeout), a mapping with ``command`` and an ``error``
        string of the form ``"ExceptionName: message"``.
    """
    run_options: dict[str, Any] = {
        "check": False,  # a non-zero exit status is reported, not raised
        "capture_output": True,
        "text": True,
        "timeout": 10,
    }
    try:
        completed = subprocess.run(command, **run_options)
        outcome: dict[str, Any] = {
            "command": command,
            "returncode": completed.returncode,
            "stdout": completed.stdout.strip(),
            "stderr": completed.stderr.strip(),
        }
    except Exception as exc:
        # Best-effort diagnostics: record the failure instead of aborting.
        outcome = {"command": command, "error": f"{type(exc).__name__}: {exc}"}
    return outcome
|
|
||
|
|
||
def _torch_report() -> dict[str, Any]:
    """Summarize the importable torch build, or record why it could not be probed."""
    try:
        import torch

        # Query availability once so the flag, the device count and the
        # device-name lookup are all based on the same answer.
        cuda_available = bool(torch.cuda.is_available())
        summary: dict[str, Any] = {
            "version": torch.__version__,
            "cuda_version": getattr(torch.version, "cuda", None),
            "maca_version": getattr(torch.version, "maca", None),
            "cuda_available": cuda_available,
            "device_count": torch.cuda.device_count() if cuda_available else 0,
        }
        if cuda_available:
            summary["device_name_0"] = torch.cuda.get_device_name(0)
        return summary
    except Exception as exc:
        # Best-effort: a missing or broken torch install is reported, not raised.
        return {"error": f"{type(exc).__name__}: {exc}"}


def collect_report(env: dict[str, str] | None = None) -> dict[str, Any]:
    """Collect MACA toolchain diagnostics into a JSON-serializable report.

    Args:
        env: Environment mapping to inspect. Defaults to ``os.environ``.

    Returns:
        A mapping with the keys ``python``, ``environment``, ``executables``,
        ``libraries``, ``torch``, ``checks`` and ``ok``. ``cucc_version`` (or,
        when only nvcc is found, ``nvcc_version``) is added when a compiler is
        on ``PATH``. ``ok`` is true only when every entry in ``checks`` passed.
    """
    env = os.environ if env is None else env
    maca_path = env.get("MACA_PATH")

    # When not set explicitly, the CUDA and clang locations fall back to
    # sub-directories of MACA_PATH.
    cuda_path = env.get("CUDA_HOME") or env.get("CUDA_PATH")
    if not cuda_path and maca_path:
        cuda_path = str(Path(maca_path) / "tools" / "cu-bridge")
    clang_path = env.get("MACA_CLANG_PATH")
    if not clang_path and maca_path:
        clang_path = str(Path(maca_path) / "mxgpu_llvm" / "bin")

    # Look compilers up on the PATH of the inspected environment, which may
    # differ from the PATH of the current process.
    search_path = env.get("PATH", "")
    cucc = shutil.which("cucc", path=search_path)
    nvcc = shutil.which("nvcc", path=search_path)

    report: dict[str, Any] = {
        "python": sys.version.split()[0],
        "environment": {
            "MACA_PATH": _path_status(maca_path),
            "CUDA_HOME_OR_PATH": _path_status(cuda_path),
            "MACA_CLANG_PATH": _path_status(clang_path),
            "LD_LIBRARY_PATH_SET": bool(env.get("LD_LIBRARY_PATH")),
        },
        "executables": {
            "cucc": cucc,
            "nvcc": nvcc,
        },
        "libraries": {},
        "torch": {},
        "checks": [],
    }

    if maca_path:
        maca_root = Path(maca_path)
        report["libraries"] = {
            "mcruntime": _path_status(str(maca_root / "lib" / "libmcruntime.so")),
            "mcblas": _path_status(str(maca_root / "lib" / "libmcblas.so")),
        }

    # cucc takes precedence; nvcc is only queried when cucc is absent.
    if cucc:
        report["cucc_version"] = _command_output([cucc, "-V"])
    elif nvcc:
        report["nvcc_version"] = _command_output([nvcc, "-V"])

    report["torch"] = _torch_report()

    environment = report["environment"]
    required = [
        ("MACA_PATH", environment["MACA_PATH"]["is_dir"]),
        ("CUDA_HOME_OR_PATH", environment["CUDA_HOME_OR_PATH"]["is_dir"]),
        ("MACA_CLANG_PATH", environment["MACA_CLANG_PATH"]["is_dir"]),
        ("cucc_or_nvcc", bool(cucc or nvcc)),
    ]
    report["checks"].extend({"name": name, "ok": ok} for name, ok in required)
    report["ok"] = all(item["ok"] for item in report["checks"])
    return report
|
|
||
|
|
||
def main() -> int:
    """Print the diagnostics report as JSON and return a process exit status.

    Returns:
        ``0`` when the report's ``ok`` flag is true, ``1`` otherwise.
    """
    parser = argparse.ArgumentParser(description="Collect FlashMLA MACA environment diagnostics.")
    parser.add_argument("--pretty", action="store_true", help="Print indented JSON.")
    options = parser.parse_args()

    report = collect_report()
    indent = 2 if options.pretty else None
    print(json.dumps(report, indent=indent, sort_keys=True))
    if report["ok"]:
        return 0
    return 1


# Script entry point: propagate main()'s status as the process exit code.
if __name__ == "__main__":
    raise SystemExit(main())
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
如果 `importlib.util.spec_from_file_location` 无法找到或加载模块规范，它可能会返回 `None`。在这种情况下，直接调用 `module_from_spec(spec)` 会在到达第 9 行的断言之前就抛出异常（访问 `spec.loader` 时触发 `AttributeError`）。我们应该在尝试加载模块之前，先断言 `spec` 不为 `None`。