Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 60 additions & 29 deletions plugins/gh/src/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import os
import shutil
import sys
import tempfile

try:
import yaml
Expand All @@ -16,7 +17,17 @@


def get_config_path() -> str:
return os.path.join(os.environ.get("APPDATA", ""), "GitHub CLI", "config.yml")
appdata = os.environ.get("APPDATA")

if appdata:
return os.path.join(appdata, "GitHub CLI", "config.yml")

return os.path.join(
os.path.expanduser("~"),
".config",
"gh",
"config.yml",
)


def log(message: str) -> None:
Expand All @@ -32,24 +43,43 @@ def read_yaml(file_path: str) -> dict:
return {}

with open(file_path, "r", encoding="utf-8") as file_handle:
data = yaml.safe_load(file_handle)
return data if isinstance(data, dict) else {}
data = yaml.safe_load(file_handle) or {}
if not isinstance(data, dict):
return {}
return data


def write_yaml(file_path: str, data: dict) -> None:
if yaml is None:
raise RuntimeError("PyYAML is required to read or write gh config")

os.makedirs(os.path.dirname(file_path), exist_ok=True)
with open(file_path, "w", encoding="utf-8") as file_handle:
yaml.dump(data, file_handle, default_flow_style=False, sort_keys=False)

dir_name = os.path.dirname(file_path) or "."
fd, tmp_path = tempfile.mkstemp(dir=dir_name, suffix=".yml")

try:
with os.fdopen(fd, "w", encoding="utf-8") as f:
yaml.dump(
data,
f,
default_flow_style=False,
sort_keys=False,
)

os.replace(tmp_path, file_path)

except Exception:
if os.path.exists(tmp_path):
os.unlink(tmp_path)
raise


def merge_settings(target: dict, source: dict) -> bool:
changed = False

for key, value in source.items():
if value == "":
if value is None:
continue

current_value = target.get(key)
Expand All @@ -75,19 +105,18 @@ def get_config_target(config: dict) -> dict:
return config


def check_installed(request_id: str) -> dict:
installed = shutil.which("gh") is not None or shutil.which("gh.exe") is not None
return {
"requestId": request_id,
"success": True,
"changed": False,
"data": installed,
}
def check_installed() -> bool:
return ( shutil.which("gh") is not None or shutil.which("gh.exe") is not None )


def apply_config(request_id: str, args: dict, context: dict) -> dict:
dry_run = bool(context.get("dryRun", False))
updates = {key: value for key, value in args.items() if key != "dry_run"}
def apply_config(request_id: str, args: dict) -> dict:
dry_run = bool(args.get("dryRun", False))
settings = args.get("settings", {})
if not isinstance(settings, dict):
return { "requestId": request_id,
"error": "settings must be a dictionary",
}
updates = {key: value for key, value in settings.items()}

config_path = get_config_path()
if yaml is None:
Expand All @@ -104,7 +133,6 @@ def apply_config(request_id: str, args: dict, context: dict) -> dict:
log(f"dry_run: no changes for {config_path}")
return {
"requestId": request_id,
"success": True,
"changed": changed,
}

Expand All @@ -113,32 +141,37 @@ def apply_config(request_id: str, args: dict, context: dict) -> dict:

return {
"requestId": request_id,
"success": True,
"changed": changed,
}


def handle(request: dict) -> dict:
request_id = request.get("requestId", "unknown")
request_id = request.get("requestId") or "unknown"
command = request.get("command")
args = request.get("args", {})
context = request.get("context", {})

if command == "check_installed":
return check_installed(request_id)
installed = check_installed()
return { "requestId": request_id, "installed": installed, }
if command == "apply":
if not isinstance(args, dict):
raise ValueError("args must be an object")
if not isinstance(context, dict):
raise ValueError("context must be an object")
return apply_config(request_id, args, context)
return { "requestId": request_id,
"error": "args must be a dictionary",
}
return apply_config(request_id, args)

raise ValueError(f"Unknown command: {command}")
return {
"requestId": request_id,
"error": f"Unknown command: {command}",
}


def main() -> None:
raw = sys.stdin.read()
if not raw:
sys.stdout.write(
json.dumps({ "requestId": "unknown", "error": "No input received", }) + "\n" )
sys.stdout.flush()
return

try:
Expand All @@ -149,8 +182,6 @@ def main() -> None:
"requestId": request.get("requestId", "unknown")
if "request" in locals() and isinstance(request, dict)
else "unknown",
"success": False,
"changed": False,
"error": str(error),
}

Expand Down
54 changes: 21 additions & 33 deletions plugins/gh/test/test_gh.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import yaml

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "src")))
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "src")))
import plugin


Expand All @@ -35,18 +35,16 @@ def test_check_installed_returns_true_when_gh_is_found(self):
response = self.run_main({"requestId": "req-1", "command": "check_installed", "args": {}})

self.assertEqual(response["requestId"], "req-1")
self.assertTrue(response["success"])
self.assertFalse(response["changed"])
self.assertTrue(response["data"])
self.assertTrue(response["installed"])

def test_check_installed_returns_false_when_gh_is_missing(self):
with patch("plugin.shutil.which", return_value=None):
response = self.run_main({"requestId": "req-2", "command": "check_installed", "args": {}})

self.assertEqual(response["requestId"], "req-2")
self.assertTrue(response["success"])
self.assertFalse(response["changed"])
self.assertFalse(response["data"])
self.assertFalse(response["installed"])

def test_apply_writes_merged_config_and_returns_changed_true(self):
with tempfile.TemporaryDirectory() as tmp_dir:
Expand All @@ -67,19 +65,18 @@ def test_apply_writes_merged_config_and_returns_changed_true(self):
"requestId": "req-3",
"command": "apply",
"args": {
"git_protocol": "https",
"editor": "code --wait",
"prompt": "enabled",
"pager": "less",
"http_unix_socket": "",
"browser": "",
},
"context": {"dryRun": False},
"dryRun": False,
"settings": {
"git_protocol": "https",
"editor": "code --wait",
"prompt": "enabled",
"pager": "less",
},
},
}
)

self.assertEqual(response["requestId"], "req-3")
self.assertTrue(response["success"])
self.assertTrue(response["changed"])

with open(config_path, "r", encoding="utf-8") as file_handle:
Expand Down Expand Up @@ -118,19 +115,17 @@ def test_apply_with_no_changes_returns_changed_false(self):
"requestId": "req-4",
"command": "apply",
"args": {
"git_protocol": "https",
"editor": "code --wait",
"prompt": "enabled",
"pager": "less",
"http_unix_socket": "",
"browser": "",
"dryRun": False,
"settings": {
"git_protocol": "https",
"editor": "code --wait",
"prompt": "enabled",
"pager": "less",
},
},
"context": {"dryRun": False},
}
)

self.assertEqual(response["requestId"], "req-4")
self.assertTrue(response["success"])
self.assertFalse(response["changed"])

with open(config_path, "r", encoding="utf-8") as file_handle:
Expand All @@ -147,13 +142,11 @@ def test_apply_with_dry_run_does_not_write_file(self):
{
"requestId": "req-5",
"command": "apply",
"args": {"git_protocol": "https", "dry_run": True},
"context": {"dryRun": True},
"args": { "dryRun": True, "settings": { "git_protocol": "https" } },
}
)

self.assertEqual(response["requestId"], "req-5")
self.assertTrue(response["success"])
self.assertTrue(response["changed"])
self.assertFalse(os.path.exists(config_path))

Expand All @@ -167,13 +160,11 @@ def test_apply_creates_missing_directory(self):
{
"requestId": "req-6",
"command": "apply",
"args": {"git_protocol": "https"},
"context": {"dryRun": False},
"args": { "dryRun": False, "settings": {"git_protocol": "https" },},
}
)

self.assertEqual(response["requestId"], "req-6")
self.assertTrue(response["success"])
self.assertTrue(response["changed"])
self.assertTrue(os.path.isdir(os.path.dirname(config_path)))
self.assertTrue(os.path.exists(config_path))
Expand All @@ -184,13 +175,11 @@ def test_apply_returns_error_when_pyyaml_is_missing(self):
{
"requestId": "req-7",
"command": "apply",
"args": {"git_protocol": "https"},
"context": {"dryRun": False},
"args": { "dryRun": False, "settings": { "git_protocol": "https" }, },
}
)

self.assertEqual(response["requestId"], "req-7")
self.assertFalse(response["success"])
self.assertFalse(response["changed"])
self.assertIn("PyYAML", response["error"])

Expand All @@ -204,7 +193,6 @@ def test_unknown_command_returns_error(self):
}
)
self.assertEqual(response["requestId"], "req-8")
self.assertFalse(response["success"])
self.assertFalse(response["changed"])
self.assertIn("Unknown command", response["error"])

Expand Down
Loading