diff --git a/culprit_finder/src/culprit_finder/culprit_finder.py b/culprit_finder/src/culprit_finder/culprit_finder.py index 2115b307..47219004 100644 --- a/culprit_finder/src/culprit_finder/culprit_finder.py +++ b/culprit_finder/src/culprit_finder/culprit_finder.py @@ -18,6 +18,28 @@ CULPRIT_FINDER_WORKFLOW_NAME = "culprit_finder.yml" +# Configuration for projects that require special handling for external dependencies. +# Some projects (e.g., JAX) depend on the HEAD of another repository (e.g., XLA). +# When bisecting historical commits, running them against the *current* HEAD of the dependency +# often causes build failures unrelated to the regression being investigated. +# +# This map defines how to "time-travel" for these dependencies: +# - dependency_repo: The external repository to look up. +# - input_name: The workflow input variable to set with the pinned commit hash. +# - workflows: The specific workflows where this logic should apply. +PROJECT_CONFIG = { + "jax-ml/jax": { + "dependency_repo": "openxla/xla", + "input_name": "xla-commit", + "workflows": ["wheel_tests_continuous.yml", "build_artifacts.yml"], + }, + "google-ml-infra/jax-fork": { + "dependency_repo": "openxla/xla", + "input_name": "xla-commit", + "workflows": ["wheel_tests_continuous.yml", "build_artifacts.yml"], + }, +} + class CulpritFinder: """Culprit finder class to find the culprit commit for a GitHub workflow.""" @@ -196,6 +218,28 @@ def _test_commit( ) previous_run_id = previous_run.id if previous_run else None + if self._repo in PROJECT_CONFIG and self._workflow_file in PROJECT_CONFIG[self._repo]["workflows"]: + config = PROJECT_CONFIG[self._repo] + logging.info("Project %s matched special case config", self._repo) + + # Get date of the commit we are testing + commit_details = self._gh_client.get_commit(commit_sha) + # PyGithub returns naive datetime objects in UTC. + # get_last_commit_before's `until` parameter natively handles this datetime object. + commit_date = commit_details.commit.committer.date + + # Find dependency commit at that time + dep_repo = config["dependency_repo"] + logging.info("Looking up dependency commit for %s at %s", dep_repo, commit_date) + dep_commit = self._gh_client.get_last_commit_before(dep_repo, commit_date) + + if dep_commit: + input_name = config["input_name"] + logging.info("Pinning %s to %s", input_name, dep_commit.sha) + inputs[input_name] = dep_commit.sha + else: + logging.warning("Could not find matching commit for %s at %s", dep_repo, commit_date) + self._gh_client.trigger_workflow( workflow_to_trigger, branch_name, diff --git a/culprit_finder/src/culprit_finder/github_client.py b/culprit_finder/src/culprit_finder/github_client.py index a72477cb..502552ba 100644 --- a/culprit_finder/src/culprit_finder/github_client.py +++ b/culprit_finder/src/culprit_finder/github_client.py @@ -2,12 +2,13 @@ Module for interacting with the GitHub API via PyGithub. """ -import logging import os import re -import time -from typing import Optional import subprocess +import logging +import datetime +from typing import Optional +import time import github from github.Commit import Commit @@ -29,7 +30,8 @@ def __init__(self, repo: str, token: str): repo: The GitHub repository in 'owner/repo' format. token: The GitHub access token for authentication. """ - self._repo = github.Github(auth=github.Auth.Token(token)).get_repo(repo, lazy=True) + self._gh = github.Github(auth=github.Auth.Token(token)) + self._repo = self._gh.get_repo(repo, lazy=True) def compare_commits(self, base_sha: str, head_sha: str) -> list[Commit]: """ @@ -341,6 +343,42 @@ def get_run_jobs(self, run_id: str | int) -> list[WorkflowJob]: run = self.get_run(str(run_id)) return list(run.jobs()) + def get_commit(self, sha: str) -> Commit: + """ + Gets details of a specific commit. + + Args: + sha: The SHA of the commit to retrieve. + + Returns: + A Commit object. + """ + return self._repo.get_commit(sha) + + def get_last_commit_before( + self, repo_name: str, date: str | datetime.datetime, branch: str = "main" + ) -> Commit | None: + """ + Finds the latest commit in a repository that is older than or equal to a given date. + + Args: + repo_name: The GitHub repository to search in 'owner/repo' format. + date: The ISO 8601 date string or datetime object to search before (inclusive). + branch: The branch to search on. + + Returns: + The matching Commit object, or None if not found. + """ + target_repo = self._gh.get_repo(repo_name, lazy=True) + # PyGithub's get_commits returns a PaginatedList in reverse chronological order. + # We use `until` to filter commits before the date. + commits = target_repo.get_commits(sha=branch, until=date) + + # Calling `totalCount` on a PaginatedList forces PyGithub to fetch the entire + # pagination graph just to get the count. This is a massive API overhead. + # Instead, we just take the first item from the iterator, which only fetches the first page. + return next(iter(commits), None) + def get_github_token() -> str | None: """Retrieves the GitHub access token from the environment or from the the GitHub CLI if not present. diff --git a/culprit_finder/tests/test_culprit_finder.py b/culprit_finder/tests/test_culprit_finder.py index c0b00983..a333bb59 100644 --- a/culprit_finder/tests/test_culprit_finder.py +++ b/culprit_finder/tests/test_culprit_finder.py @@ -170,6 +170,79 @@ def test_test_commit_outcomes( ) +@pytest.mark.parametrize("has_culprit_workflow", [True, False]) +def test_test_commit_with_project_config( + mocker, mock_gh_client, has_culprit_workflow, mock_state, mock_state_persister +): + """Tests that _test_commit injects the pinned dependency if the repo matches PROJECT_CONFIG.""" + repo_name = "jax-ml/jax" + workflow_file = "wheel_tests_continuous.yml" + branch = "test-branch" + commit_sha = "sha1" + dep_commit_sha = "xla_sha_123" + + # Create finder with specific repo and workflow + finder = culprit_finder.CulpritFinder( + repo=repo_name, + start_sha="start_sha", + end_sha="end_sha", + workflow_file=workflow_file, + has_culprit_finder_workflow=has_culprit_workflow, + gh_client=mock_gh_client, + state=mock_state, + state_persister=mock_state_persister, + ) + + # Mock completion + mock_wait = mocker.patch.object(finder, "_wait_for_workflow_completion") + mock_wait.return_value = factories.create_run( + mocker, head_sha=commit_sha, conclusion="success", status="completed" + ) + mock_gh_client.get_latest_run.return_value = None + + # Mock dependency lookup + mock_commit = mocker.Mock() + mock_commit.commit.committer.date = "2023-01-01T00:00:00Z" + mock_gh_client.get_commit.return_value = mock_commit + + mock_dep_commit = mocker.Mock() + mock_dep_commit.sha = dep_commit_sha + mock_gh_client.get_last_commit_before.return_value = mock_dep_commit + + is_good = finder._test_commit(commit_sha, branch) + + assert is_good is True + mock_gh_client.get_commit.assert_called_once_with(commit_sha) + mock_gh_client.get_last_commit_before.assert_called_once_with( + "openxla/xla", "2023-01-01T00:00:00Z" + ) + + # Determine expected arguments based on configuration + if has_culprit_workflow: + expected_workflow = CULPRIT_WORKFLOW + expected_inputs = {"workflow-to-debug": workflow_file, "xla-commit": dep_commit_sha} + else: + expected_workflow = workflow_file + expected_inputs = {"xla-commit": dep_commit_sha} + + mock_gh_client.trigger_workflow.assert_called_once_with( + expected_workflow, + branch, + expected_inputs, + ) + + +def test_test_commit_failure(mocker, finder, mock_gh_client): + """Tests that _test_commit returns False if the workflow fails.""" + mock_wait = mocker.patch.object(finder, "_wait_for_workflow_completion") + mock_wait.return_value = factories.create_run(mocker, "sha", "completed", "failure") + + # Mock get_latest_run to return None for the "previous run" check + mock_gh_client.get_latest_run.return_value = None + + assert finder._test_commit("sha", "branch") is False + + @pytest.mark.parametrize("has_culprit_workflow", [True, False]) def test_test_commit_with_specific_job( mocker, finder_factory, mock_gh_client, has_culprit_workflow diff --git a/culprit_finder/tests/test_github_client.py b/culprit_finder/tests/test_github_client.py index 1dab1c9e..a53928fc 100644 --- a/culprit_finder/tests/test_github_client.py +++ b/culprit_finder/tests/test_github_client.py @@ -295,3 +295,46 @@ def test_find_previous_successful_job_run_not_found(mocker): with pytest.raises(ValueError, match="No previous successful run found for job"): client.find_previous_successful_job_run(failed_run, "my_job") + + +def test_get_commit(mocker): + """Tests getting a specific commit.""" + client = github_client.GithubClient("owner/repo", token="test-token") + mock_repo = client._repo + expected_commit = mocker.Mock() + mock_repo.get_commit.return_value = expected_commit + + result = client.get_commit("mock_sha") + assert result == expected_commit + mock_repo.get_commit.assert_called_once_with("mock_sha") + + +def test_get_last_commit_before_found(mocker): + """Tests finding the last commit before a date when commits exist.""" + client = github_client.GithubClient("owner/repo", token="test-token") + mock_gh = client._gh + mock_target_repo = mocker.Mock() + mock_gh.get_repo.return_value = mock_target_repo + + expected_commit = mocker.Mock() + mock_commits = [expected_commit, mocker.Mock()] + mock_target_repo.get_commits.return_value = mock_commits + + result = client.get_last_commit_before("other/repo", "2023-01-01T00:00:00Z") + + assert result == expected_commit + mock_target_repo.get_commits.assert_called_once_with(sha="main", until="2023-01-01T00:00:00Z") + + +def test_get_last_commit_before_not_found(mocker): + """Tests getting last commit before a date when no commits exist.""" + client = github_client.GithubClient("owner/repo", token="test-token") + mock_gh = client._gh + mock_target_repo = mocker.Mock() + mock_gh.get_repo.return_value = mock_target_repo + + mock_target_repo.get_commits.return_value = [] + + result = client.get_last_commit_before("other/repo", "2023-01-01T00:00:00Z") + + assert result is None