Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions cmd/merge_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (

"github.com/github/gh-stack/internal/config"
"github.com/github/gh-stack/internal/git"
"github.com/github/gh-stack/internal/github"
"github.com/github/gh-stack/internal/stack"
"github.com/stretchr/testify/assert"
)
Expand All @@ -32,6 +33,7 @@ func TestMerge_NoPullRequest(t *testing.T) {
defer restore()

cfg, _, errR := config.NewTestConfig()
cfg.GitHubClientOverride = &github.MockClient{}
cmd := MergeCmd(cfg)
cmd.SetOut(io.Discard)
cmd.SetErr(io.Discard)
Expand Down Expand Up @@ -65,6 +67,7 @@ func TestMerge_AlreadyMerged(t *testing.T) {
defer restore()

cfg, _, errR := config.NewTestConfig()
cfg.GitHubClientOverride = &github.MockClient{}
cmd := MergeCmd(cfg)
cmd.SetOut(io.Discard)
cmd.SetErr(io.Discard)
Expand Down Expand Up @@ -104,6 +107,7 @@ func TestMerge_FullyMergedStack(t *testing.T) {
defer restore()

cfg, _, errR := config.NewTestConfig()
cfg.GitHubClientOverride = &github.MockClient{}
cmd := MergeCmd(cfg)
cmd.SetOut(io.Discard)
cmd.SetErr(io.Discard)
Expand Down Expand Up @@ -136,6 +140,7 @@ func TestMerge_OnTrunk(t *testing.T) {
defer restore()

cfg, _, errR := config.NewTestConfig()
cfg.GitHubClientOverride = &github.MockClient{}
cmd := MergeCmd(cfg)
cmd.SetOut(io.Discard)
cmd.SetErr(io.Discard)
Expand Down Expand Up @@ -168,6 +173,19 @@ func TestMerge_NonInteractive_PrintsURL(t *testing.T) {

// NewTestConfig is non-interactive (piped output), so no confirm prompt.
cfg, _, errR := config.NewTestConfig()
cfg.GitHubClientOverride = &github.MockClient{
FindPRByNumberFn: func(number int) (*github.PullRequest, error) {
if number == 42 {
return &github.PullRequest{
Number: 42,
ID: "PR_42",
URL: "https://github.com/owner/repo/pull/42",
State: "OPEN",
}, nil
}
return nil, nil
},
}
cmd := MergeCmd(cfg)
cmd.SetOut(io.Discard)
cmd.SetErr(io.Discard)
Expand Down Expand Up @@ -196,6 +214,7 @@ func TestMerge_NoArgs(t *testing.T) {
defer restore()

cfg, _, _ := config.NewTestConfig()
cfg.GitHubClientOverride = &github.MockClient{}
cmd := MergeCmd(cfg)
cmd.SetOut(io.Discard)
cmd.SetErr(io.Discard)
Expand Down Expand Up @@ -229,6 +248,17 @@ func TestMerge_ByPRNumber(t *testing.T) {
defer restore()

cfg, _, errR := config.NewTestConfig()
cfg.GitHubClientOverride = &github.MockClient{
FindPRByNumberFn: func(number int) (*github.PullRequest, error) {
switch number {
case 42:
return &github.PullRequest{Number: 42, URL: "https://github.com/owner/repo/pull/42", State: "OPEN"}, nil
case 43:
return &github.PullRequest{Number: 43, URL: "https://github.com/owner/repo/pull/43", State: "OPEN"}, nil
}
return nil, nil
},
}
cmd := MergeCmd(cfg)
cmd.SetOut(io.Discard)
cmd.SetErr(io.Discard)
Expand Down Expand Up @@ -261,6 +291,14 @@ func TestMerge_ByPRURL(t *testing.T) {
defer restore()

cfg, _, errR := config.NewTestConfig()
cfg.GitHubClientOverride = &github.MockClient{
FindPRByNumberFn: func(number int) (*github.PullRequest, error) {
if number == 42 {
return &github.PullRequest{Number: 42, URL: "https://github.com/owner/repo/pull/42", State: "OPEN"}, nil
}
return nil, nil
},
}
cmd := MergeCmd(cfg)
cmd.SetOut(io.Discard)
cmd.SetErr(io.Discard)
Expand Down Expand Up @@ -293,6 +331,14 @@ func TestMerge_ByBranchName(t *testing.T) {
defer restore()

cfg, _, errR := config.NewTestConfig()
cfg.GitHubClientOverride = &github.MockClient{
FindPRByNumberFn: func(number int) (*github.PullRequest, error) {
if number == 42 {
return &github.PullRequest{Number: 42, URL: "https://github.com/owner/repo/pull/42", State: "OPEN"}, nil
}
return nil, nil
},
}
cmd := MergeCmd(cfg)
cmd.SetOut(io.Discard)
cmd.SetErr(io.Discard)
Expand Down
12 changes: 11 additions & 1 deletion cmd/push_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,17 @@ func TestPush_NoSubmitHintWhenPRsExist(t *testing.T) {
defer restore()

cfg, _, errR := config.NewTestConfig()
cfg.GitHubClientOverride = &github.MockClient{}
cfg.GitHubClientOverride = &github.MockClient{
FindPRByNumberFn: func(number int) (*github.PullRequest, error) {
switch number {
case 10:
return &github.PullRequest{Number: 10, State: "OPEN", HeadRefName: "b1"}, nil
case 11:
return &github.PullRequest{Number: 11, State: "OPEN", HeadRefName: "b2"}, nil
}
return nil, nil
},
}
cmd := PushCmd(cfg)
cmd.SetOut(io.Discard)
cmd.SetErr(io.Discard)
Expand Down
27 changes: 22 additions & 5 deletions cmd/submit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,12 @@ func TestSubmit_SkipsMergedBranches(t *testing.T) {
cfg, _, errR := config.NewTestConfig()
cfg.GitHubClientOverride = &github.MockClient{
FindPRForBranchFn: func(branch string) (*github.PullRequest, error) {
return &github.PullRequest{Number: 2, URL: "https://github.com/owner/repo/pull/2"}, nil
// Only return an OPEN PR for the active branch (b2).
// Merged branches (b1, b3) should have no open PR.
if branch == "b2" {
return &github.PullRequest{Number: 2, URL: "https://github.com/owner/repo/pull/2", State: "OPEN"}, nil
}
return nil, nil
},
}
cmd := SubmitCmd(cfg)
Expand Down Expand Up @@ -1026,7 +1031,7 @@ func TestSubmit_PreflightCheck_SkippedWhenStackIDSet(t *testing.T) {
tmpDir := t.TempDir()
writeStackFile(t, tmpDir, s)

listStacksCalled := false
listStacksCallCount := 0
mock := newSubmitMock(tmpDir, "b1")
mock.PushFn = func(string, []string, bool, bool) error { return nil }
restore := git.SetOps(mock)
Expand All @@ -1035,8 +1040,17 @@ func TestSubmit_PreflightCheck_SkippedWhenStackIDSet(t *testing.T) {
cfg, _, errR := config.NewTestConfig()
cfg.GitHubClientOverride = &github.MockClient{
ListStacksFn: func() ([]github.RemoteStack, error) {
listStacksCalled = true
return nil, &api.HTTPError{StatusCode: 404, Message: "Not Found"}
listStacksCallCount++
return []github.RemoteStack{{ID: 42, PullRequests: []int{10, 11}}}, nil
},
FindPRByNumberFn: func(number int) (*github.PullRequest, error) {
switch number {
case 10:
return &github.PullRequest{Number: 10, URL: "https://github.com/o/r/pull/10", HeadRefName: "b1", State: "OPEN"}, nil
case 11:
return &github.PullRequest{Number: 11, URL: "https://github.com/o/r/pull/11", HeadRefName: "b2", State: "OPEN"}, nil
}
return nil, nil
},
FindPRForBranchFn: func(string) (*github.PullRequest, error) {
return &github.PullRequest{Number: 10, URL: "https://github.com/o/r/pull/10"}, nil
Expand All @@ -1054,5 +1068,8 @@ func TestSubmit_PreflightCheck_SkippedWhenStackIDSet(t *testing.T) {
_, _ = io.ReadAll(errR)

assert.NoError(t, err)
assert.False(t, listStacksCalled, "ListStacks should not be called when stack ID already exists")
// ListStacks is called by syncStackPRs (remote sync), but NOT by the
// preflight check. Two syncStackPRs calls happen in submit (before and
// after PR creation), so expect exactly 2 ListStacks calls.
assert.Equal(t, 2, listStacksCallCount, "ListStacks should only be called by syncStackPRs, not by the preflight check")
}
108 changes: 105 additions & 3 deletions cmd/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/cli/go-gh/v2/pkg/prompter"
"github.com/github/gh-stack/internal/config"
"github.com/github/gh-stack/internal/git"
"github.com/github/gh-stack/internal/github"
"github.com/github/gh-stack/internal/stack"
)

Expand Down Expand Up @@ -231,27 +232,126 @@ func resolveStack(sf *stack.StackFile, branch string, cfg *config.Config) (*stac
}

// syncStackPRs discovers and updates pull request metadata for branches in a stack.
// For each branch, it queries GitHub for the most recent PR and updates the
// PullRequestRef including merge status. Branches with already-merged PRs are skipped.
//
// When the stack has a remote ID, the stack API is the source of truth: the
// authoritative PR list is fetched from the server and matched to local
// branches by head branch name. PRs remain associated even if closed.
//
// When no remote stack exists, branch-name-based discovery is used:
//
// 1. No tracked PR — look for an OPEN PR by head branch name.
// 2. Tracked PR (not merged) — refresh status by number; if closed,
// clear the association and fall through to path 1.
// 3. Tracked PR (merged) — skip; the merged state is final.
//
// The transient Queued flag is also populated from the API response.
func syncStackPRs(cfg *config.Config, s *stack.Stack) {
client, err := cfg.GitHubClient()
if err != nil {
return
}

// When the stack has a remote ID, the stack API is the source of truth.
if s.ID != "" {
if syncStackPRsFromRemote(client, s) {
return
}
}

// No remote stack (or remote sync failed) — local discovery.
for i := range s.Branches {
b := &s.Branches[i]

if b.IsMerged() {
continue
}

pr, err := client.FindAnyPRForBranch(b.Branch)
if b.PullRequest != nil && b.PullRequest.Number != 0 {
// Tracked PR — refresh its state.
pr, err := client.FindPRByNumber(b.PullRequest.Number)
if err != nil {
continue // API error — keep existing tracked PR
}
if pr == nil {
// PR not found — clear stale ref and fall through
// to the open-PR lookup below.
b.PullRequest = nil
b.Queued = false
} else {
b.PullRequest = &stack.PullRequestRef{
Number: pr.Number,
ID: pr.ID,
URL: pr.URL,
Merged: pr.Merged,
}
b.Queued = pr.IsQueued()

// If the PR was closed (not merged), remove the association
// so we fall through to the open-PR lookup below.
if pr.State == "CLOSED" {
b.PullRequest = nil
b.Queued = false
} else {
continue
}
}
}

// No tracked PR (or just cleared) — only adopt OPEN PRs to avoid
// picking up stale merged/closed PRs from a previous use of this
// branch name.
pr, err := client.FindPRForBranch(b.Branch)
if err != nil || pr == nil {
continue
}
b.PullRequest = &stack.PullRequestRef{
Number: pr.Number,
ID: pr.ID,
URL: pr.URL,
}
b.Queued = pr.IsQueued()
}
}

// syncStackPRsFromRemote uses the stack API to sync PR state. The remote
// stack's PR list is the source of truth — PRs stay associated even if
// closed. Returns true if the sync succeeded, false if we should fall
// back to local discovery (e.g. stack not found remotely, API error).
func syncStackPRsFromRemote(client github.ClientOps, s *stack.Stack) bool {
stacks, err := client.ListStacks()
if err != nil {
return false
}

// Find our stack in the remote list.
var remotePRNumbers []int
for _, rs := range stacks {
if strconv.Itoa(rs.ID) == s.ID {
remotePRNumbers = rs.PullRequests
break
}
}
if remotePRNumbers == nil {
return false
}

// Fetch each remote PR's details and index by head branch name.
prByBranch := make(map[string]*github.PullRequest, len(remotePRNumbers))
for _, num := range remotePRNumbers {
pr, err := client.FindPRByNumber(num)
if err != nil || pr == nil {
continue
}
prByBranch[pr.HeadRefName] = pr
}

// Match remote PRs to local branches.
for i := range s.Branches {
b := &s.Branches[i]
pr, ok := prByBranch[b.Branch]
if !ok {
continue
}
b.PullRequest = &stack.PullRequestRef{
Number: pr.Number,
ID: pr.ID,
Expand All @@ -260,6 +360,8 @@ func syncStackPRs(cfg *config.Config, s *stack.Stack) {
}
b.Queued = pr.IsQueued()
}

return true
}

// updateBaseSHAs refreshes the Base and Head SHAs for all active branches
Expand Down
Loading