From 29be2e114455f1dfac256156ed89e1826133d4d6 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 4 Feb 2026 17:26:24 +0000 Subject: [PATCH 1/2] Add support for bisecting JAX workflows --- .../src/culprit_finder/culprit_finder.py | 37 +++++++++++++++++ .../src/culprit_finder/github_client.py | 40 ++++++++++++++++++- 2 files changed, 76 insertions(+), 1 deletion(-) diff --git a/culprit_finder/src/culprit_finder/culprit_finder.py b/culprit_finder/src/culprit_finder/culprit_finder.py index 156927bf..d6a9e1fc 100644 --- a/culprit_finder/src/culprit_finder/culprit_finder.py +++ b/culprit_finder/src/culprit_finder/culprit_finder.py @@ -18,6 +18,23 @@ CULPRIT_FINDER_WORKFLOW_NAME = "culprit_finder.yml" +# Configuration for projects that require special handling for external dependencies. +# Some projects (e.g., JAX) depend on the HEAD of another repository (e.g., XLA). +# When bisecting historical commits, running them against the *current* HEAD of the dependency +# often causes build failures unrelated to the regression being investigated. +# +# This map defines how to "time-travel" for these dependencies: +# - dependency_repo: The external repository to look up. +# - input_name: The workflow input variable to set with the pinned commit hash. +# - workflows: The specific workflows where this logic should apply. 
+PROJECT_CONFIG = { + "jax-ml/jax": { + "dependency_repo": "openxla/xla", + "input_name": "xla-commit", + "workflows": ["wheel_tests_continuous.yml", "build_artifacts.yml"], + }, +} + class CulpritFinder: """Culprit finder class to find the culprit commit for a GitHub workflow.""" @@ -185,6 +202,26 @@ def _test_commit( ) previous_run_id = previous_run.id if previous_run else None + if self._repo in PROJECT_CONFIG and self._workflow_file in PROJECT_CONFIG[self._repo]["workflows"]: + config = PROJECT_CONFIG[self._repo] + logging.info("Project %s matched special case config", self._repo) + + # Get date of the commit we are testing + commit_details = self._gh_client.get_commit(commit_sha) + commit_date = commit_details.commit.committer.date + + # Find dependency commit at that time + dep_repo = config["dependency_repo"] + logging.info("Looking up dependency commit for %s at %s", dep_repo, commit_date) + dep_commit = self._gh_client.get_last_commit_before(dep_repo, commit_date) + + if dep_commit: + input_name = config["input_name"] + logging.info("Pinning %s to %s", input_name, dep_commit.sha) + inputs[input_name] = dep_commit.sha + else: + logging.warning("Could not find matching commit for %s at %s", dep_repo, commit_date) + self._gh_client.trigger_workflow( workflow_to_trigger, branch_name, diff --git a/culprit_finder/src/culprit_finder/github_client.py b/culprit_finder/src/culprit_finder/github_client.py index bc89275a..d3fb7687 100644 --- a/culprit_finder/src/culprit_finder/github_client.py +++ b/culprit_finder/src/culprit_finder/github_client.py @@ -29,7 +29,8 @@ def __init__(self, repo: str, token: str): repo: The GitHub repository in 'owner/repo' format. token: The GitHub access token for authentication. 
""" - self._repo = github.Github(auth=github.Auth.Token(token)).get_repo(repo, lazy=True) + self._gh = github.Github(auth=github.Auth.Token(token)) + self._repo = self._gh.get_repo(repo, lazy=True) def compare_commits(self, base_sha: str, head_sha: str) -> list[Commit]: """ @@ -336,6 +337,43 @@ def get_run_jobs(self, run_id: str | int) -> list[WorkflowJob]: run = self.get_run(str(run_id)) return list(run.jobs()) + def get_commit(self, sha: str) -> Commit: + """ + Gets details of a specific commit. + + Args: + sha: The SHA of the commit to retrieve. + + Returns: + A Commit object. + """ + return self._repo.get_commit(sha) + + def get_last_commit_before( + self, repo_name: str, date: str, branch: str = "main" + ) -> Commit | None: + """ + Finds the latest commit in a repository that is older than or equal to a given date. + + Args: + repo_name: The GitHub repository to search in 'owner/repo' format. + date: The ISO 8601 date string to search before (inclusive). + branch: The branch to search on. + + Returns: + The matching Commit object, or None if not found. + """ + target_repo = self._gh.get_repo(repo_name, lazy=True) + # until: "Only commits before this date will be returned." + # We want the *latest* commit before this date. + # get_commits returns in reverse chronological order (newest first). + # so the first one returned with `until=date` should be the one we want. + commits = target_repo.get_commits(sha=branch, until=date) + + if commits.totalCount > 0: + return commits[0] + return None + def get_github_token() -> str | None: """Retrieves the GitHub access token from the environment or from the the GitHub CLI if not present. 
From 852781af33287e54f7d3c2d0af5e529d0e24e003 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 11 Feb 2026 18:00:16 +0000 Subject: [PATCH 2/2] Add unit tests + minor performance/readability tweaks --- .../src/culprit_finder/culprit_finder.py | 7 +++ .../src/culprit_finder/github_client.py | 24 +++---- culprit_finder/tests/test_culprit_finder.py | 63 +++++++++++++++++++ culprit_finder/tests/test_github_client.py | 43 +++++++++++++ 4 files changed, 125 insertions(+), 12 deletions(-) diff --git a/culprit_finder/src/culprit_finder/culprit_finder.py b/culprit_finder/src/culprit_finder/culprit_finder.py index d6a9e1fc..627742f7 100644 --- a/culprit_finder/src/culprit_finder/culprit_finder.py +++ b/culprit_finder/src/culprit_finder/culprit_finder.py @@ -33,6 +33,11 @@ "input_name": "xla-commit", "workflows": ["wheel_tests_continuous.yml", "build_artifacts.yml"], }, + "google-ml-infra/jax-fork": { + "dependency_repo": "openxla/xla", + "input_name": "xla-commit", + "workflows": ["wheel_tests_continuous.yml", "build_artifacts.yml"], + }, } @@ -208,6 +213,8 @@ def _test_commit( # Get date of the commit we are testing commit_details = self._gh_client.get_commit(commit_sha) + # PyGithub returns this date as a UTC datetime object (naive or timezone-aware, depending on the PyGithub version). + # get_last_commit_before forwards it unchanged as the `until` argument of get_commits. commit_date = commit_details.commit.committer.date # Find dependency commit at that time diff --git a/culprit_finder/src/culprit_finder/github_client.py b/culprit_finder/src/culprit_finder/github_client.py index d3fb7687..bf58d31b 100644 --- a/culprit_finder/src/culprit_finder/github_client.py +++ b/culprit_finder/src/culprit_finder/github_client.py @@ -2,12 +2,13 @@ Module for interacting with the GitHub API via PyGithub.
""" -import logging import os import re -import time -from typing import Optional import subprocess +import logging +import datetime +from typing import Optional +import time import github from github.Commit import Commit @@ -350,29 +351,28 @@ def get_commit(self, sha: str) -> Commit: return self._repo.get_commit(sha) def get_last_commit_before( - self, repo_name: str, date: str, branch: str = "main" + self, repo_name: str, date: str | datetime.datetime, branch: str = "main" ) -> Commit | None: """ Finds the latest commit in a repository that is older than or equal to a given date. Args: repo_name: The GitHub repository to search in 'owner/repo' format. - date: The ISO 8601 date string to search before (inclusive). + date: The ISO 8601 date string or datetime object to search before (inclusive). branch: The branch to search on. Returns: The matching Commit object, or None if not found. """ target_repo = self._gh.get_repo(repo_name, lazy=True) - # until: "Only commits before this date will be returned." - # We want the *latest* commit before this date. - # get_commits returns in reverse chronological order (newest first). - # so the first one returned with `until=date` should be the one we want. + # PyGithub's get_commits returns a PaginatedList in reverse chronological order. + # We use `until` to filter commits before the date. commits = target_repo.get_commits(sha=branch, until=date) - if commits.totalCount > 0: - return commits[0] - return None + # Checking `totalCount` on a PaginatedList issues a separate API request just to + # learn the count, which is not needed to pick the newest matching commit. + # Instead, we just take the first item from the iterator, which only fetches the first page.
+ return next(iter(commits), None) def get_github_token() -> str | None: diff --git a/culprit_finder/tests/test_culprit_finder.py b/culprit_finder/tests/test_culprit_finder.py index db242e2b..f8caf620 100644 --- a/culprit_finder/tests/test_culprit_finder.py +++ b/culprit_finder/tests/test_culprit_finder.py @@ -125,6 +125,69 @@ def test_test_commit_success(mocker, finder, mock_gh_client): ) +@pytest.mark.parametrize("has_culprit_workflow", [True, False]) +def test_test_commit_with_project_config( + mocker, mock_gh_client, has_culprit_workflow, mock_state, mock_state_persister +): + """Tests that _test_commit injects the pinned dependency if the repo matches PROJECT_CONFIG.""" + repo_name = "jax-ml/jax" + workflow_file = "wheel_tests_continuous.yml" + branch = "test-branch" + commit_sha = "sha1" + dep_commit_sha = "xla_sha_123" + + # Create finder with specific repo and workflow + finder = culprit_finder.CulpritFinder( + repo=repo_name, + start_sha="start_sha", + end_sha="end_sha", + workflow_file=workflow_file, + has_culprit_finder_workflow=has_culprit_workflow, + gh_client=mock_gh_client, + state=mock_state, + state_persister=mock_state_persister, + ) + + # Mock completion + mock_wait = mocker.patch.object(finder, "_wait_for_workflow_completion") + mock_wait.return_value = factories.create_run( + mocker, head_sha=commit_sha, conclusion="success", status="completed" + ) + mock_gh_client.get_latest_run.return_value = None + + # Mock dependency lookup + mock_commit = mocker.Mock() + mock_commit.commit.committer.date = "2023-01-01T00:00:00Z" + mock_gh_client.get_commit.return_value = mock_commit + + mock_dep_commit = mocker.Mock() + mock_dep_commit.sha = dep_commit_sha + mock_gh_client.get_last_commit_before.return_value = mock_dep_commit + + is_good = finder._test_commit(commit_sha, branch) + + assert is_good is True + mock_gh_client.get_commit.assert_called_once_with(commit_sha) + mock_gh_client.get_last_commit_before.assert_called_once_with( + "openxla/xla", 
"2023-01-01T00:00:00Z" + ) + + # Determine expected arguments based on configuration + if has_culprit_workflow: + expected_workflow = CULPRIT_WORKFLOW + expected_inputs = {"workflow-to-debug": workflow_file, "xla-commit": dep_commit_sha} + else: + expected_workflow = workflow_file + expected_inputs = {"xla-commit": dep_commit_sha} + + mock_gh_client.trigger_workflow.assert_called_once_with( + expected_workflow, + branch, + expected_inputs, + ) + + + def test_test_commit_failure(mocker, finder, mock_gh_client): """Tests that _test_commit returns False if the workflow fails.""" mock_wait = mocker.patch.object(finder, "_wait_for_workflow_completion") diff --git a/culprit_finder/tests/test_github_client.py b/culprit_finder/tests/test_github_client.py index 1dab1c9e..a53928fc 100644 --- a/culprit_finder/tests/test_github_client.py +++ b/culprit_finder/tests/test_github_client.py @@ -295,3 +295,46 @@ def test_find_previous_successful_job_run_not_found(mocker): with pytest.raises(ValueError, match="No previous successful run found for job"): client.find_previous_successful_job_run(failed_run, "my_job") + + +def test_get_commit(mocker): + """Tests getting a specific commit.""" + client = github_client.GithubClient("owner/repo", token="test-token") + mock_repo = client._repo + expected_commit = mocker.Mock() + mock_repo.get_commit.return_value = expected_commit + + result = client.get_commit("mock_sha") + assert result == expected_commit + mock_repo.get_commit.assert_called_once_with("mock_sha") + + +def test_get_last_commit_before_found(mocker): + """Tests finding the last commit before a date when commits exist.""" + client = github_client.GithubClient("owner/repo", token="test-token") + mock_gh = client._gh + mock_target_repo = mocker.Mock() + mock_gh.get_repo.return_value = mock_target_repo + + expected_commit = mocker.Mock() + mock_commits = [expected_commit, mocker.Mock()] + mock_target_repo.get_commits.return_value = mock_commits + + result = 
client.get_last_commit_before("other/repo", "2023-01-01T00:00:00Z") + + assert result == expected_commit + mock_target_repo.get_commits.assert_called_once_with(sha="main", until="2023-01-01T00:00:00Z") + + +def test_get_last_commit_before_not_found(mocker): + """Tests getting last commit before a date when no commits exist.""" + client = github_client.GithubClient("owner/repo", token="test-token") + mock_gh = client._gh + mock_target_repo = mocker.Mock() + mock_gh.get_repo.return_value = mock_target_repo + + mock_target_repo.get_commits.return_value = [] + + result = client.get_last_commit_before("other/repo", "2023-01-01T00:00:00Z") + + assert result is None