diff --git a/plugins/gh/src/plugin.py b/plugins/gh/src/plugin.py index c93021c7..14acee22 100644 --- a/plugins/gh/src/plugin.py +++ b/plugins/gh/src/plugin.py @@ -8,6 +8,7 @@ import os import shutil import sys +import tempfile try: import yaml @@ -16,7 +17,17 @@ def get_config_path() -> str: - return os.path.join(os.environ.get("APPDATA", ""), "GitHub CLI", "config.yml") + appdata = os.environ.get("APPDATA") + + if appdata: + return os.path.join(appdata, "GitHub CLI", "config.yml") + + return os.path.join( + os.path.expanduser("~"), + ".config", + "gh", + "config.yml", + ) def log(message: str) -> None: @@ -32,8 +43,10 @@ def read_yaml(file_path: str) -> dict: return {} with open(file_path, "r", encoding="utf-8") as file_handle: - data = yaml.safe_load(file_handle) - return data if isinstance(data, dict) else {} + data = yaml.safe_load(file_handle) or {} + if not isinstance(data, dict): + return {} + return data def write_yaml(file_path: str, data: dict) -> None: @@ -41,15 +54,32 @@ def write_yaml(file_path: str, data: dict) -> None: raise RuntimeError("PyYAML is required to read or write gh config") os.makedirs(os.path.dirname(file_path), exist_ok=True) - with open(file_path, "w", encoding="utf-8") as file_handle: - yaml.dump(data, file_handle, default_flow_style=False, sort_keys=False) + + dir_name = os.path.dirname(file_path) or "." + fd, tmp_path = tempfile.mkstemp(dir=dir_name, suffix=".yml") + + try: + with os.fdopen(fd, "w", encoding="utf-8") as f: + yaml.dump( + data, + f, + default_flow_style=False, + sort_keys=False, + ) + + os.replace(tmp_path, file_path) + + except Exception: + if os.path.exists(tmp_path): + os.unlink(tmp_path) + raise def merge_settings(target: dict, source: dict) -> bool: changed = False for key, value in source.items(): - if value == "": + if value is None: continue current_value = target.get(key) @@ -75,19 +105,18 @@ def get_config_target(config: dict) -> dict: return config -def check_installed(request_id: str) -> dict: - installed = shutil.which("gh") is not None or shutil.which("gh.exe") is not None - return { - "requestId": request_id, - "success": True, - "changed": False, - "data": installed, - } +def check_installed() -> bool: + return ( shutil.which("gh") is not None or shutil.which("gh.exe") is not None ) -def apply_config(request_id: str, args: dict, context: dict) -> dict: - dry_run = bool(context.get("dryRun", False)) - updates = {key: value for key, value in args.items() if key != "dry_run"} +def apply_config(request_id: str, args: dict) -> dict: + dry_run = bool(args.get("dryRun", False)) + settings = args.get("settings", {}) + if not isinstance(settings, dict): + return { "requestId": request_id, + "error": "settings must be a dictionary", + } + updates = {key: value for key, value in settings.items()} config_path = get_config_path() if yaml is None: @@ -104,7 +133,6 @@ def apply_config(request_id: str, args: dict, context: dict) -> dict: log(f"dry_run: no changes for {config_path}") return { "requestId": request_id, - "success": True, "changed": changed, } @@ -113,32 +141,37 @@ def apply_config(request_id: str, args: dict, context: dict) -> dict: return { "requestId": request_id, - "success": True, "changed": changed, } def handle(request: dict) -> dict: - request_id = request.get("requestId", "unknown") + request_id = request.get("requestId") or "unknown" command = request.get("command") args = request.get("args", {}) - context = request.get("context", {}) if command == "check_installed": - return check_installed(request_id) + installed = check_installed() + return { "requestId": request_id, "installed": installed, } if command == "apply": if not isinstance(args, dict): - raise ValueError("args must be an object") - if not isinstance(context, dict): - raise ValueError("context must be an object") - return apply_config(request_id, args, context) + return { "requestId": request_id, + "error": "args must be a dictionary", + } + return apply_config(request_id, args) - raise ValueError(f"Unknown command: {command}") + return { + "requestId": request_id, + "error": f"Unknown command: {command}", + } def main() -> None: raw = sys.stdin.read() if not raw: + sys.stdout.write( + json.dumps({ "requestId": "unknown", "error": "No input received", }) + "\n" ) + sys.stdout.flush() return try: @@ -149,8 +182,6 @@ def main() -> None: "requestId": request.get("requestId", "unknown") if "request" in locals() and isinstance(request, dict) else "unknown", - "success": False, - "changed": False, "error": str(error), } diff --git a/plugins/gh/test/test_gh.py b/plugins/gh/test/test_gh.py index 1243e1a9..8ee5f933 100644 --- a/plugins/gh/test/test_gh.py +++ b/plugins/gh/test/test_gh.py @@ -15,7 +15,7 @@ import yaml -sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "src"))) +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "src"))) import plugin @@ -35,18 +35,16 @@ def test_check_installed_returns_true_when_gh_is_found(self): response = self.run_main({"requestId": "req-1", "command": "check_installed", "args": {}}) self.assertEqual(response["requestId"], "req-1") - self.assertTrue(response["success"]) self.assertFalse(response["changed"]) - self.assertTrue(response["data"]) + self.assertTrue(response["installed"]) def test_check_installed_returns_false_when_gh_is_missing(self): with patch("plugin.shutil.which", return_value=None): response = self.run_main({"requestId": "req-2", "command": "check_installed", "args": {}}) self.assertEqual(response["requestId"], "req-2") - self.assertTrue(response["success"]) self.assertFalse(response["changed"]) - self.assertFalse(response["data"]) + self.assertFalse(response["installed"]) def test_apply_writes_merged_config_and_returns_changed_true(self): with tempfile.TemporaryDirectory() as tmp_dir: @@ -67,19 +65,18 @@ def test_apply_writes_merged_config_and_returns_changed_true(self): "requestId": "req-3", "command": "apply", "args": { - "git_protocol": "https", - "editor": "code --wait", - "prompt": "enabled", - "pager": "less", - "http_unix_socket": "", - "browser": "", - }, - "context": {"dryRun": False}, + "dryRun": False, + "settings": { + "git_protocol": "https", + "editor": "code --wait", + "prompt": "enabled", + "pager": "less", + }, + }, } ) self.assertEqual(response["requestId"], "req-3") - self.assertTrue(response["success"]) self.assertTrue(response["changed"]) with open(config_path, "r", encoding="utf-8") as file_handle: @@ -118,19 +115,17 @@ def test_apply_with_no_changes_returns_changed_false(self): "requestId": "req-4", "command": "apply", "args": { - "git_protocol": "https", - "editor": "code --wait", - "prompt": "enabled", - "pager": "less", - "http_unix_socket": "", - "browser": "", + "dryRun": False, + "settings": { + "git_protocol": "https", + "editor": "code --wait", + "prompt": "enabled", + "pager": "less", + }, }, - "context": {"dryRun": False}, } ) - self.assertEqual(response["requestId"], "req-4") - self.assertTrue(response["success"]) self.assertFalse(response["changed"]) with open(config_path, "r", encoding="utf-8") as file_handle: @@ -147,13 +142,11 @@ def test_apply_with_dry_run_does_not_write_file(self): { "requestId": "req-5", "command": "apply", - "args": {"git_protocol": "https", "dry_run": True}, - "context": {"dryRun": True}, + "args": { "dryRun": True, "settings": { "git_protocol": "https" } }, } ) self.assertEqual(response["requestId"], "req-5") - self.assertTrue(response["success"]) self.assertTrue(response["changed"]) self.assertFalse(os.path.exists(config_path)) @@ -167,13 +160,11 @@ def test_apply_creates_missing_directory(self): { "requestId": "req-6", "command": "apply", - "args": {"git_protocol": "https"}, - "context": {"dryRun": False}, + "args": { "dryRun": False, "settings": {"git_protocol": "https" },}, } ) self.assertEqual(response["requestId"], "req-6") - self.assertTrue(response["success"]) self.assertTrue(response["changed"]) self.assertTrue(os.path.isdir(os.path.dirname(config_path))) self.assertTrue(os.path.exists(config_path)) @@ -184,13 +175,11 @@ def test_apply_returns_error_when_pyyaml_is_missing(self): { "requestId": "req-7", "command": "apply", - "args": {"git_protocol": "https"}, - "context": {"dryRun": False}, + "args": { "dryRun": False, "settings": { "git_protocol": "https" }, }, } ) self.assertEqual(response["requestId"], "req-7") - self.assertFalse(response["success"]) self.assertFalse(response["changed"]) self.assertIn("PyYAML", response["error"]) @@ -204,7 +193,6 @@ def test_unknown_command_returns_error(self): } ) self.assertEqual(response["requestId"], "req-8") - self.assertFalse(response["success"]) self.assertFalse(response["changed"]) self.assertIn("Unknown command", response["error"])