Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions culprit_finder/src/culprit_finder/culprit_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,28 @@

CULPRIT_FINDER_WORKFLOW_NAME = "culprit_finder.yml"

# Configuration for projects that require special handling for external dependencies.
# Some projects (e.g., JAX) depend on the HEAD of another repository (e.g., XLA).
# When bisecting historical commits, running them against the *current* HEAD of the dependency
# often causes build failures unrelated to the regression being investigated.
#
# This map defines how to "time-travel" for these dependencies:
# - dependency_repo: The external repository to look up.
# - input_name: The workflow input variable to set with the pinned commit hash.
# - workflows: The specific workflows where this logic should apply.
PROJECT_CONFIG = {
"jax-ml/jax": {
"dependency_repo": "openxla/xla",
"input_name": "xla-commit",
"workflows": ["wheel_tests_continuous.yml", "build_artifacts.yml"],
},
"google-ml-infra/jax-fork": {
"dependency_repo": "openxla/xla",
"input_name": "xla-commit",
"workflows": ["wheel_tests_continuous.yml", "build_artifacts.yml"],
},
}


class CulpritFinder:
"""Culprit finder class to find the culprit commit for a GitHub workflow."""
Expand Down Expand Up @@ -196,6 +218,28 @@ def _test_commit(
)
previous_run_id = previous_run.id if previous_run else None

if self._repo in PROJECT_CONFIG and self._workflow_file in PROJECT_CONFIG[self._repo]["workflows"]:
config = PROJECT_CONFIG[self._repo]
logging.info("Project %s matched special case config", self._repo)

# Get date of the commit we are testing
commit_details = self._gh_client.get_commit(commit_sha)
# PyGithub returns naive datetime objects in UTC.
# get_last_commit_before's `until` parameter natively handles this datetime object.
commit_date = commit_details.commit.committer.date

# Find dependency commit at that time
dep_repo = config["dependency_repo"]
logging.info("Looking up dependency commit for %s at %s", dep_repo, commit_date)
dep_commit = self._gh_client.get_last_commit_before(dep_repo, commit_date)

if dep_commit:
input_name = config["input_name"]
logging.info("Pinning %s to %s", input_name, dep_commit.sha)
inputs[input_name] = dep_commit.sha
else:
logging.warning("Could not find matching commit for %s at %s", dep_repo, commit_date)

self._gh_client.trigger_workflow(
workflow_to_trigger,
branch_name,
Expand Down
46 changes: 42 additions & 4 deletions culprit_finder/src/culprit_finder/github_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
Module for interacting with the GitHub API via PyGithub.
"""

import logging
import os
import re
import time
from typing import Optional
import subprocess
import logging
import datetime
from typing import Optional
import time

import github
from github.Commit import Commit
Expand All @@ -29,7 +30,8 @@ def __init__(self, repo: str, token: str):
repo: The GitHub repository in 'owner/repo' format.
token: The GitHub access token for authentication.
"""
self._repo = github.Github(auth=github.Auth.Token(token)).get_repo(repo, lazy=True)
self._gh = github.Github(auth=github.Auth.Token(token))
self._repo = self._gh.get_repo(repo, lazy=True)

def compare_commits(self, base_sha: str, head_sha: str) -> list[Commit]:
"""
Expand Down Expand Up @@ -341,6 +343,42 @@ def get_run_jobs(self, run_id: str | int) -> list[WorkflowJob]:
run = self.get_run(str(run_id))
return list(run.jobs())

def get_commit(self, sha: str) -> Commit:
"""
Gets details of a specific commit.

Args:
sha: The SHA of the commit to retrieve.

Returns:
A Commit object.
"""
return self._repo.get_commit(sha)

def get_last_commit_before(
self, repo_name: str, date: str | datetime.datetime, branch: str = "main"
) -> Commit | None:
"""
Finds the latest commit in a repository that is older than or equal to a given date.

Args:
repo_name: The GitHub repository to search in 'owner/repo' format.
date: The ISO 8601 date string or datetime object to search before (inclusive).
branch: The branch to search on.

Returns:
The matching Commit object, or None if not found.
"""
target_repo = self._gh.get_repo(repo_name, lazy=True)
# PyGithub's get_commits returns a PaginatedList in reverse chronological order.
# We use `until` to filter commits before the date.
commits = target_repo.get_commits(sha=branch, until=date)

# Calling `totalCount` on a PaginatedList forces PyGithub to fetch the entire
# pagination graph just to get the count. This is a massive API overhead.
# Instead, we just take the first item from the iterator, which only fetches the first page.
return next(iter(commits), None)


def get_github_token() -> str | None:
"""Retrieves the GitHub access token from the environment or from the the GitHub CLI if not present.
Expand Down
73 changes: 73 additions & 0 deletions culprit_finder/tests/test_culprit_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,79 @@ def test_test_commit_outcomes(
)


@pytest.mark.parametrize("has_culprit_workflow", [True, False])
def test_test_commit_with_project_config(
mocker, mock_gh_client, has_culprit_workflow, mock_state, mock_state_persister
):
"""Tests that _test_commit injects the pinned dependency if the repo matches PROJECT_CONFIG."""
repo_name = "jax-ml/jax"
workflow_file = "wheel_tests_continuous.yml"
branch = "test-branch"
commit_sha = "sha1"
dep_commit_sha = "xla_sha_123"

# Create finder with specific repo and workflow
finder = culprit_finder.CulpritFinder(
repo=repo_name,
start_sha="start_sha",
end_sha="end_sha",
workflow_file=workflow_file,
has_culprit_finder_workflow=has_culprit_workflow,
gh_client=mock_gh_client,
state=mock_state,
state_persister=mock_state_persister,
)

# Mock completion
mock_wait = mocker.patch.object(finder, "_wait_for_workflow_completion")
mock_wait.return_value = factories.create_run(
mocker, head_sha=commit_sha, conclusion="success", status="completed"
)
mock_gh_client.get_latest_run.return_value = None

# Mock dependency lookup
mock_commit = mocker.Mock()
mock_commit.commit.committer.date = "2023-01-01T00:00:00Z"
mock_gh_client.get_commit.return_value = mock_commit

mock_dep_commit = mocker.Mock()
mock_dep_commit.sha = dep_commit_sha
mock_gh_client.get_last_commit_before.return_value = mock_dep_commit

is_good = finder._test_commit(commit_sha, branch)

assert is_good is True
mock_gh_client.get_commit.assert_called_once_with(commit_sha)
mock_gh_client.get_last_commit_before.assert_called_once_with(
"openxla/xla", "2023-01-01T00:00:00Z"
)

# Determine expected arguments based on configuration
if has_culprit_workflow:
expected_workflow = CULPRIT_WORKFLOW
expected_inputs = {"workflow-to-debug": workflow_file, "xla-commit": dep_commit_sha}
else:
expected_workflow = workflow_file
expected_inputs = {"xla-commit": dep_commit_sha}

mock_gh_client.trigger_workflow.assert_called_once_with(
expected_workflow,
branch,
expected_inputs,
)


def test_test_commit_failure(mocker, finder, mock_gh_client):
"""Tests that _test_commit returns False if the workflow fails."""
mock_wait = mocker.patch.object(finder, "_wait_for_workflow_completion")
mock_wait.return_value = factories.create_run(mocker, "sha", "completed", "failure")

# Mock get_latest_run to return None for the "previous run" check
mock_gh_client.get_latest_run.return_value = None

assert finder._test_commit("sha", "branch") is False


@pytest.mark.parametrize("has_culprit_workflow", [True, False])
def test_test_commit_with_specific_job(
mocker, finder_factory, mock_gh_client, has_culprit_workflow
Expand Down
43 changes: 43 additions & 0 deletions culprit_finder/tests/test_github_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,3 +295,46 @@ def test_find_previous_successful_job_run_not_found(mocker):

with pytest.raises(ValueError, match="No previous successful run found for job"):
client.find_previous_successful_job_run(failed_run, "my_job")


def test_get_commit(mocker):
"""Tests getting a specific commit."""
client = github_client.GithubClient("owner/repo", token="test-token")
mock_repo = client._repo
expected_commit = mocker.Mock()
mock_repo.get_commit.return_value = expected_commit

result = client.get_commit("mock_sha")
assert result == expected_commit
mock_repo.get_commit.assert_called_once_with("mock_sha")


def test_get_last_commit_before_found(mocker):
"""Tests finding the last commit before a date when commits exist."""
client = github_client.GithubClient("owner/repo", token="test-token")
mock_gh = client._gh
mock_target_repo = mocker.Mock()
mock_gh.get_repo.return_value = mock_target_repo

expected_commit = mocker.Mock()
mock_commits = [expected_commit, mocker.Mock()]
mock_target_repo.get_commits.return_value = mock_commits

result = client.get_last_commit_before("other/repo", "2023-01-01T00:00:00Z")

assert result == expected_commit
mock_target_repo.get_commits.assert_called_once_with(sha="main", until="2023-01-01T00:00:00Z")


def test_get_last_commit_before_not_found(mocker):
"""Tests getting last commit before a date when no commits exist."""
client = github_client.GithubClient("owner/repo", token="test-token")
mock_gh = client._gh
mock_target_repo = mocker.Mock()
mock_gh.get_repo.return_value = mock_target_repo

mock_target_repo.get_commits.return_value = []

result = client.get_last_commit_before("other/repo", "2023-01-01T00:00:00Z")

assert result is None