Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pr_agent/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def set_parser():
parser.add_argument('--version', action='version', version=f'pr-agent {get_version()}')
parser.add_argument('--pr_url', type=str, help='The URL of the PR to review', default=None)
parser.add_argument('--issue_url', type=str, help='The URL of the Issue to review', default=None)
parser.add_argument('--config-branch', type=str, help='Git branch to load .pr_agent.toml from', default=None)
parser.add_argument('command', type=str, help='The', choices=commands, default='review')
parser.add_argument('rest', nargs=argparse.REMAINDER, default=[])
return parser
Expand All @@ -76,6 +77,9 @@ def run(inargs=None, args=None):

command = args.command.lower()
get_settings().set("CONFIG.CLI_MODE", True)
config_branch = (args.config_branch or os.environ.get("PR_AGENT_CONFIG_BRANCH") or "").strip()
if config_branch:
get_settings().set("CONFIG.CONFIG_BRANCH", config_branch)

async def inner():
if args.issue_url:
Expand Down
22 changes: 18 additions & 4 deletions pr_agent/git_providers/github_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
import difflib
import hashlib
import itertools
import json
import os
import re
import time
import traceback
import json
from datetime import datetime
from typing import Optional, Tuple
from urllib.parse import urlparse
Expand Down Expand Up @@ -731,10 +732,23 @@ def get_issue_comments(self):
return self.pr.get_issue_comments()

def get_repo_settings(self):
config_branch = get_settings().get("CONFIG.CONFIG_BRANCH", None) or os.environ.get("PR_AGENT_CONFIG_BRANCH")
if isinstance(config_branch, str):
config_branch = config_branch.strip()
if config_branch:
try:
return self.repo_obj.get_contents(".pr_agent.toml", ref=config_branch).decoded_content
except GithubException as e:
get_logger().warning(
f"Failed to load .pr_agent.toml from branch '{config_branch}', falling back to default branch",
artifact={"status": e.status, "error": str(e)},
)
except Exception as e:
get_logger().warning(
f"Failed to load .pr_agent.toml from branch '{config_branch}', falling back to default branch",
artifact={"error": str(e)},
)
try:
# contents = self.repo_obj.get_contents(".pr_agent.toml", ref=self.pr.head.sha).decoded_content

# more logical to take 'pr_agent.toml' from the default branch
contents = self.repo_obj.get_contents(".pr_agent.toml").decoded_content
return contents
except Exception:
Expand Down
48 changes: 48 additions & 0 deletions tests/unittest/test_cli_config_branch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from types import SimpleNamespace
from unittest.mock import MagicMock, patch

from pr_agent import cli


def test_set_parser_supports_config_branch_flag():
args = cli.set_parser().parse_args(["--pr_url=https://github.com/a/b/pull/1", "--config-branch", "feature", "review"])
assert args.config_branch == "feature"


def test_run_sets_config_branch_from_cli_flag():
fake_settings = SimpleNamespace(
litellm={},
set=MagicMock(),
)

async def fake_handle_request(*_args, **_kwargs):
return True

with patch("pr_agent.cli.get_settings", return_value=fake_settings), patch(
"pr_agent.cli.PRAgent",
return_value=SimpleNamespace(handle_request=fake_handle_request),
):
cli.run(inargs=["--pr_url=https://github.com/a/b/pull/1", "--config-branch", "feature", "review"])

fake_settings.set.assert_any_call("CONFIG.CONFIG_BRANCH", "feature")


def test_run_sets_config_branch_from_env_var():
fake_settings = SimpleNamespace(
litellm={},
set=MagicMock(),
)

async def fake_handle_request(*_args, **_kwargs):
return True

with patch.dict("os.environ", {"PR_AGENT_CONFIG_BRANCH": "env-branch"}, clear=False), patch(
"pr_agent.cli.get_settings",
return_value=fake_settings,
), patch(
"pr_agent.cli.PRAgent",
return_value=SimpleNamespace(handle_request=fake_handle_request),
):
cli.run(inargs=["--pr_url=https://github.com/a/b/pull/1", "review"])

fake_settings.set.assert_any_call("CONFIG.CONFIG_BRANCH", "env-branch")
59 changes: 59 additions & 0 deletions tests/unittest/test_github_provider_repo_settings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from types import SimpleNamespace
from unittest.mock import MagicMock, patch

from github import GithubException

from pr_agent.git_providers.github_provider import GithubProvider


def _provider_with_repo(repo_obj):
provider = GithubProvider.__new__(GithubProvider)
provider.repo_obj = repo_obj
return provider


def test_get_repo_settings_uses_config_branch_from_settings():
repo_obj = MagicMock()
repo_obj.get_contents.return_value = SimpleNamespace(decoded_content=b"[config]\nmodel='x'")
provider = _provider_with_repo(repo_obj)

with patch("pr_agent.git_providers.github_provider.get_settings") as mock_settings:
mock_settings.return_value.get.return_value = "feature-config"
settings = provider.get_repo_settings()

assert settings == b"[config]\nmodel='x'"
repo_obj.get_contents.assert_called_once_with(".pr_agent.toml", ref="feature-config")


def test_get_repo_settings_falls_back_to_default_branch_on_missing_file_in_config_branch():
repo_obj = MagicMock()
repo_obj.get_contents.side_effect = [
GithubException(404, {"message": "Not Found"}, None),
SimpleNamespace(decoded_content=b"[config]\nmodel='default'"),
]
provider = _provider_with_repo(repo_obj)

with patch("pr_agent.git_providers.github_provider.get_settings") as mock_settings:
mock_settings.return_value.get.return_value = "feature-config"
settings = provider.get_repo_settings()

assert settings == b"[config]\nmodel='default'"
assert repo_obj.get_contents.call_args_list[0].kwargs == {"ref": "feature-config"}
assert repo_obj.get_contents.call_args_list[1].kwargs == {}


def test_get_repo_settings_uses_env_var_when_settings_are_missing():
repo_obj = MagicMock()
repo_obj.get_contents.return_value = SimpleNamespace(decoded_content=b"[config]\nmodel='env'")
provider = _provider_with_repo(repo_obj)

with patch("pr_agent.git_providers.github_provider.get_settings") as mock_settings, patch.dict(
"os.environ",
{"PR_AGENT_CONFIG_BRANCH": "env-branch"},
clear=False,
):
mock_settings.return_value.get.return_value = None
settings = provider.get_repo_settings()

assert settings == b"[config]\nmodel='env'"
repo_obj.get_contents.assert_called_once_with(".pr_agent.toml", ref="env-branch")
Loading