diff --git a/cmd/merge_test.go b/cmd/merge_test.go index bb59b94..3a5e282 100644 --- a/cmd/merge_test.go +++ b/cmd/merge_test.go @@ -6,6 +6,7 @@ import ( "github.com/github/gh-stack/internal/config" "github.com/github/gh-stack/internal/git" + "github.com/github/gh-stack/internal/github" "github.com/github/gh-stack/internal/stack" "github.com/stretchr/testify/assert" ) @@ -32,6 +33,7 @@ func TestMerge_NoPullRequest(t *testing.T) { defer restore() cfg, _, errR := config.NewTestConfig() + cfg.GitHubClientOverride = &github.MockClient{} cmd := MergeCmd(cfg) cmd.SetOut(io.Discard) cmd.SetErr(io.Discard) @@ -65,6 +67,7 @@ func TestMerge_AlreadyMerged(t *testing.T) { defer restore() cfg, _, errR := config.NewTestConfig() + cfg.GitHubClientOverride = &github.MockClient{} cmd := MergeCmd(cfg) cmd.SetOut(io.Discard) cmd.SetErr(io.Discard) @@ -104,6 +107,7 @@ func TestMerge_FullyMergedStack(t *testing.T) { defer restore() cfg, _, errR := config.NewTestConfig() + cfg.GitHubClientOverride = &github.MockClient{} cmd := MergeCmd(cfg) cmd.SetOut(io.Discard) cmd.SetErr(io.Discard) @@ -136,6 +140,7 @@ func TestMerge_OnTrunk(t *testing.T) { defer restore() cfg, _, errR := config.NewTestConfig() + cfg.GitHubClientOverride = &github.MockClient{} cmd := MergeCmd(cfg) cmd.SetOut(io.Discard) cmd.SetErr(io.Discard) @@ -168,6 +173,19 @@ func TestMerge_NonInteractive_PrintsURL(t *testing.T) { // NewTestConfig is non-interactive (piped output), so no confirm prompt. cfg, _, errR := config.NewTestConfig() + cfg.GitHubClientOverride = &github.MockClient{ + FindPRByNumberFn: func(number int) (*github.PullRequest, error) { + if number == 42 { + return &github.PullRequest{ + Number: 42, + ID: "PR_42", + URL: "https://github.com/owner/repo/pull/42", + State: "OPEN", + }, nil + } + return nil, nil + }, + } cmd := MergeCmd(cfg) cmd.SetOut(io.Discard) cmd.SetErr(io.Discard) @@ -196,6 +214,7 @@ func TestMerge_NoArgs(t *testing.T) { defer restore() cfg, _, _ := config.NewTestConfig() + cfg.GitHubClientOverride = &github.MockClient{} cmd := MergeCmd(cfg) cmd.SetOut(io.Discard) cmd.SetErr(io.Discard) @@ -229,6 +248,17 @@ func TestMerge_ByPRNumber(t *testing.T) { defer restore() cfg, _, errR := config.NewTestConfig() + cfg.GitHubClientOverride = &github.MockClient{ + FindPRByNumberFn: func(number int) (*github.PullRequest, error) { + switch number { + case 42: + return &github.PullRequest{Number: 42, URL: "https://github.com/owner/repo/pull/42", State: "OPEN"}, nil + case 43: + return &github.PullRequest{Number: 43, URL: "https://github.com/owner/repo/pull/43", State: "OPEN"}, nil + } + return nil, nil + }, + } cmd := MergeCmd(cfg) cmd.SetOut(io.Discard) cmd.SetErr(io.Discard) @@ -261,6 +291,14 @@ func TestMerge_ByPRURL(t *testing.T) { defer restore() cfg, _, errR := config.NewTestConfig() + cfg.GitHubClientOverride = &github.MockClient{ + FindPRByNumberFn: func(number int) (*github.PullRequest, error) { + if number == 42 { + return &github.PullRequest{Number: 42, URL: "https://github.com/owner/repo/pull/42", State: "OPEN"}, nil + } + return nil, nil + }, + } cmd := MergeCmd(cfg) cmd.SetOut(io.Discard) cmd.SetErr(io.Discard) @@ -293,6 +331,14 @@ func TestMerge_ByBranchName(t *testing.T) { defer restore() cfg, _, errR := config.NewTestConfig() + cfg.GitHubClientOverride = &github.MockClient{ + FindPRByNumberFn: func(number int) (*github.PullRequest, error) { + if number == 42 { + return &github.PullRequest{Number: 42, URL: "https://github.com/owner/repo/pull/42", State: "OPEN"}, nil + } + return nil, nil + }, + } cmd := MergeCmd(cfg) cmd.SetOut(io.Discard) cmd.SetErr(io.Discard) diff --git a/cmd/push_test.go b/cmd/push_test.go index 371b857..ff70e57 100644 --- a/cmd/push_test.go +++ b/cmd/push_test.go @@ -84,7 +84,17 @@ func TestPush_NoSubmitHintWhenPRsExist(t *testing.T) { defer restore() cfg, _, errR := config.NewTestConfig() - cfg.GitHubClientOverride = &github.MockClient{} + cfg.GitHubClientOverride = &github.MockClient{ + FindPRByNumberFn: func(number int) (*github.PullRequest, error) { + switch number { + case 10: + return &github.PullRequest{Number: 10, State: "OPEN", HeadRefName: "b1"}, nil + case 11: + return &github.PullRequest{Number: 11, State: "OPEN", HeadRefName: "b2"}, nil + } + return nil, nil + }, + } cmd := PushCmd(cfg) cmd.SetOut(io.Discard) cmd.SetErr(io.Discard) diff --git a/cmd/submit_test.go b/cmd/submit_test.go index 570baf9..07d67d4 100644 --- a/cmd/submit_test.go +++ b/cmd/submit_test.go @@ -196,7 +196,12 @@ func TestSubmit_SkipsMergedBranches(t *testing.T) { cfg, _, errR := config.NewTestConfig() cfg.GitHubClientOverride = &github.MockClient{ FindPRForBranchFn: func(branch string) (*github.PullRequest, error) { - return &github.PullRequest{Number: 2, URL: "https://github.com/owner/repo/pull/2"}, nil + // Only return an OPEN PR for the active branch (b2). + // Merged branches (b1, b3) should have no open PR. + if branch == "b2" { + return &github.PullRequest{Number: 2, URL: "https://github.com/owner/repo/pull/2", State: "OPEN"}, nil + } + return nil, nil }, } cmd := SubmitCmd(cfg) @@ -1026,7 +1031,7 @@ func TestSubmit_PreflightCheck_SkippedWhenStackIDSet(t *testing.T) { tmpDir := t.TempDir() writeStackFile(t, tmpDir, s) - listStacksCalled := false + listStacksCallCount := 0 mock := newSubmitMock(tmpDir, "b1") mock.PushFn = func(string, []string, bool, bool) error { return nil } restore := git.SetOps(mock) @@ -1035,8 +1040,17 @@ func TestSubmit_PreflightCheck_SkippedWhenStackIDSet(t *testing.T) { cfg, _, errR := config.NewTestConfig() cfg.GitHubClientOverride = &github.MockClient{ ListStacksFn: func() ([]github.RemoteStack, error) { - listStacksCalled = true - return nil, &api.HTTPError{StatusCode: 404, Message: "Not Found"} + listStacksCallCount++ + return []github.RemoteStack{{ID: 42, PullRequests: []int{10, 11}}}, nil + }, + FindPRByNumberFn: func(number int) (*github.PullRequest, error) { + switch number { + case 10: + return &github.PullRequest{Number: 10, URL: "https://github.com/o/r/pull/10", HeadRefName: "b1", State: "OPEN"}, nil + case 11: + return &github.PullRequest{Number: 11, URL: "https://github.com/o/r/pull/11", HeadRefName: "b2", State: "OPEN"}, nil + } + return nil, nil }, FindPRForBranchFn: func(string) (*github.PullRequest, error) { return &github.PullRequest{Number: 10, URL: "https://github.com/o/r/pull/10"}, nil @@ -1054,5 +1068,8 @@ func TestSubmit_PreflightCheck_SkippedWhenStackIDSet(t *testing.T) { _, _ = io.ReadAll(errR) assert.NoError(t, err) - assert.False(t, listStacksCalled, "ListStacks should not be called when stack ID already exists") + // ListStacks is called by syncStackPRs (remote sync), but NOT by the + // preflight check. Two syncStackPRs calls happen in submit (before and + // after PR creation), so expect exactly 2 ListStacks calls. + assert.Equal(t, 2, listStacksCallCount, "ListStacks should only be called by syncStackPRs, not by the preflight check") } diff --git a/cmd/utils.go b/cmd/utils.go index 9a09e48..dc107ea 100644 --- a/cmd/utils.go +++ b/cmd/utils.go @@ -11,6 +11,7 @@ import ( "github.com/cli/go-gh/v2/pkg/prompter" "github.com/github/gh-stack/internal/config" "github.com/github/gh-stack/internal/git" + "github.com/github/gh-stack/internal/github" "github.com/github/gh-stack/internal/stack" ) @@ -231,8 +232,18 @@ func resolveStack(sf *stack.StackFile, branch string, cfg *config.Config) (*stac } // syncStackPRs discovers and updates pull request metadata for branches in a stack. -// For each branch, it queries GitHub for the most recent PR and updates the -// PullRequestRef including merge status. Branches with already-merged PRs are skipped. +// +// When the stack has a remote ID, the stack API is the source of truth: the +// authoritative PR list is fetched from the server and matched to local +// branches by head branch name. PRs remain associated even if closed. +// +// When no remote stack exists, branch-name-based discovery is used: +// +// 1. No tracked PR — look for an OPEN PR by head branch name. +// 2. Tracked PR (not merged) — refresh status by number; if closed, +// clear the association and fall through to path 1. +// 3. Tracked PR (merged) — skip; the merged state is final. +// // The transient Queued flag is also populated from the API response. func syncStackPRs(cfg *config.Config, s *stack.Stack) { client, err := cfg.GitHubClient() @@ -240,6 +251,14 @@ func syncStackPRs(cfg *config.Config, s *stack.Stack) { return } + // When the stack has a remote ID, the stack API is the source of truth. + if s.ID != "" { + if syncStackPRsFromRemote(client, s) { + return + } + } + + // No remote stack (or remote sync failed) — local discovery. for i := range s.Branches { b := &s.Branches[i] @@ -247,11 +266,92 @@ func syncStackPRs(cfg *config.Config, s *stack.Stack) { continue } - pr, err := client.FindAnyPRForBranch(b.Branch) + if b.PullRequest != nil && b.PullRequest.Number != 0 { + // Tracked PR — refresh its state. + pr, err := client.FindPRByNumber(b.PullRequest.Number) + if err != nil { + continue // API error — keep existing tracked PR + } + if pr == nil { + // PR not found — clear stale ref and fall through + // to the open-PR lookup below. + b.PullRequest = nil + b.Queued = false + } else { + b.PullRequest = &stack.PullRequestRef{ + Number: pr.Number, + ID: pr.ID, + URL: pr.URL, + Merged: pr.Merged, + } + b.Queued = pr.IsQueued() + + // If the PR was closed (not merged), remove the association + // so we fall through to the open-PR lookup below. + if pr.State == "CLOSED" { + b.PullRequest = nil + b.Queued = false + } else { + continue + } + } + } + + // No tracked PR (or just cleared) — only adopt OPEN PRs to avoid + // picking up stale merged/closed PRs from a previous use of this + // branch name. + pr, err := client.FindPRForBranch(b.Branch) + if err != nil || pr == nil { + continue + } + b.PullRequest = &stack.PullRequestRef{ + Number: pr.Number, + ID: pr.ID, + URL: pr.URL, + } + b.Queued = pr.IsQueued() + } +} + +// syncStackPRsFromRemote uses the stack API to sync PR state. The remote +// stack's PR list is the source of truth — PRs stay associated even if +// closed. Returns true if the sync succeeded, false if we should fall +// back to local discovery (e.g. stack not found remotely, API error). +func syncStackPRsFromRemote(client github.ClientOps, s *stack.Stack) bool { + stacks, err := client.ListStacks() + if err != nil { + return false + } + + // Find our stack in the remote list. + var remotePRNumbers []int + for _, rs := range stacks { + if strconv.Itoa(rs.ID) == s.ID { + remotePRNumbers = rs.PullRequests + break + } + } + if remotePRNumbers == nil { + return false + } + + // Fetch each remote PR's details and index by head branch name. + prByBranch := make(map[string]*github.PullRequest, len(remotePRNumbers)) + for _, num := range remotePRNumbers { + pr, err := client.FindPRByNumber(num) if err != nil || pr == nil { continue } + prByBranch[pr.HeadRefName] = pr + } + // Match remote PRs to local branches. + for i := range s.Branches { + b := &s.Branches[i] + pr, ok := prByBranch[b.Branch] + if !ok { + continue + } b.PullRequest = &stack.PullRequestRef{ Number: pr.Number, ID: pr.ID, @@ -260,6 +360,8 @@ func syncStackPRs(cfg *config.Config, s *stack.Stack) { } b.Queued = pr.IsQueued() } + + return true } // updateBaseSHAs refreshes the Base and Head SHAs for all active branches diff --git a/cmd/utils_test.go b/cmd/utils_test.go index 8090132..2f9c7b4 100644 --- a/cmd/utils_test.go +++ b/cmd/utils_test.go @@ -9,8 +9,10 @@ import ( "github.com/AlecAivazis/survey/v2/terminal" "github.com/github/gh-stack/internal/config" "github.com/github/gh-stack/internal/git" + "github.com/github/gh-stack/internal/github" "github.com/github/gh-stack/internal/stack" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestIsInterruptError_DirectMatch(t *testing.T) { @@ -244,6 +246,358 @@ func TestResolvePR_URLPrecedesNumber(t *testing.T) { assert.Equal(t, 99, br.PullRequest.Number) } +func TestSyncStackPRs_NoTrackedPR_OnlyAdoptsOpenPRs(t *testing.T) { + // A branch with no tracked PR should only adopt OPEN PRs, + // not stale merged/closed PRs from a previous branch name usage. + s := &stack.Stack{ + Trunk: stack.BranchRef{Branch: "main"}, + Branches: []stack.BranchRef{ + {Branch: "reused-branch"}, // no PullRequest + }, + } + + cfg, outR, errR := config.NewTestConfig() + cfg.GitHubClientOverride = &github.MockClient{ + // FindPRForBranch (OPEN only) returns nil — no open PR. + FindPRForBranchFn: func(branch string) (*github.PullRequest, error) { + return nil, nil + }, + } + + syncStackPRs(cfg, s) + collectOutput(cfg, outR, errR) + + // Branch should still have no PR tracked. + assert.Nil(t, s.Branches[0].PullRequest) +} + +func TestSyncStackPRs_NoTrackedPR_AdoptsOpenPR(t *testing.T) { + // A branch with no tracked PR should adopt an OPEN PR it discovers. + s := &stack.Stack{ + Trunk: stack.BranchRef{Branch: "main"}, + Branches: []stack.BranchRef{ + {Branch: "feature"}, // no PullRequest + }, + } + + cfg, outR, errR := config.NewTestConfig() + cfg.GitHubClientOverride = &github.MockClient{ + FindPRForBranchFn: func(branch string) (*github.PullRequest, error) { + return &github.PullRequest{ + Number: 99, + ID: "PR_99", + URL: "https://github.com/o/r/pull/99", + State: "OPEN", + }, nil + }, + } + + syncStackPRs(cfg, s) + collectOutput(cfg, outR, errR) + + require.NotNil(t, s.Branches[0].PullRequest) + assert.Equal(t, 99, s.Branches[0].PullRequest.Number) + assert.False(t, s.Branches[0].PullRequest.Merged) +} + +func TestSyncStackPRs_TrackedPR_DetectsMerge(t *testing.T) { + // A branch with a tracked PR should detect when that PR gets merged. + s := &stack.Stack{ + Trunk: stack.BranchRef{Branch: "main"}, + Branches: []stack.BranchRef{ + { + Branch: "feature", + PullRequest: &stack.PullRequestRef{ + Number: 42, + ID: "PR_42", + URL: "https://github.com/o/r/pull/42", + }, + }, + }, + } + + cfg, outR, errR := config.NewTestConfig() + cfg.GitHubClientOverride = &github.MockClient{ + FindPRByNumberFn: func(number int) (*github.PullRequest, error) { + return &github.PullRequest{ + Number: 42, + ID: "PR_42", + URL: "https://github.com/o/r/pull/42", + State: "MERGED", + Merged: true, + }, nil + }, + } + + syncStackPRs(cfg, s) + collectOutput(cfg, outR, errR) + + require.NotNil(t, s.Branches[0].PullRequest) + assert.Equal(t, 42, s.Branches[0].PullRequest.Number) + assert.True(t, s.Branches[0].PullRequest.Merged) +} + +func TestSyncStackPRs_MergedBranch_StaysMerged(t *testing.T) { + // A merged branch should stay merged — no API calls, no changes. + s := &stack.Stack{ + Trunk: stack.BranchRef{Branch: "main"}, + Branches: []stack.BranchRef{ + { + Branch: "merged-branch", + PullRequest: &stack.PullRequestRef{ + Number: 20, + ID: "PR_20", + URL: "https://github.com/o/r/pull/20", + Merged: true, + }, + }, + }, + } + + apiCalled := false + cfg, outR, errR := config.NewTestConfig() + cfg.GitHubClientOverride = &github.MockClient{ + FindPRForBranchFn: func(branch string) (*github.PullRequest, error) { + apiCalled = true + return nil, nil + }, + FindPRByNumberFn: func(number int) (*github.PullRequest, error) { + apiCalled = true + return nil, nil + }, + } + + syncStackPRs(cfg, s) + collectOutput(cfg, outR, errR) + + require.NotNil(t, s.Branches[0].PullRequest) + assert.Equal(t, 20, s.Branches[0].PullRequest.Number) + assert.True(t, s.Branches[0].PullRequest.Merged) + assert.False(t, apiCalled, "no API calls should be made for merged branches") +} + +func TestSyncStackPRs_ClosedPR_ReplacedByOpenPR(t *testing.T) { + // A tracked PR that was closed (not merged) should be replaced + // by a new OPEN PR if one exists. + s := &stack.Stack{ + Trunk: stack.BranchRef{Branch: "main"}, + Branches: []stack.BranchRef{ + { + Branch: "feature", + PullRequest: &stack.PullRequestRef{ + Number: 10, + ID: "PR_10", + URL: "https://github.com/o/r/pull/10", + }, + }, + }, + } + + cfg, outR, errR := config.NewTestConfig() + cfg.GitHubClientOverride = &github.MockClient{ + FindPRByNumberFn: func(number int) (*github.PullRequest, error) { + return &github.PullRequest{ + Number: 10, + State: "CLOSED", + Merged: false, + }, nil + }, + FindPRForBranchFn: func(branch string) (*github.PullRequest, error) { + return &github.PullRequest{ + Number: 15, + ID: "PR_15", + URL: "https://github.com/o/r/pull/15", + State: "OPEN", + }, nil + }, + } + + syncStackPRs(cfg, s) + collectOutput(cfg, outR, errR) + + require.NotNil(t, s.Branches[0].PullRequest) + assert.Equal(t, 15, s.Branches[0].PullRequest.Number) + assert.False(t, s.Branches[0].PullRequest.Merged) +} + +func TestSyncStackPRs_TrackedOpenPR_UpdatesQueued(t *testing.T) { + // A tracked OPEN PR that enters a merge queue should have Queued set. + s := &stack.Stack{ + Trunk: stack.BranchRef{Branch: "main"}, + Branches: []stack.BranchRef{ + { + Branch: "feature", + PullRequest: &stack.PullRequestRef{ + Number: 42, + ID: "PR_42", + URL: "https://github.com/o/r/pull/42", + }, + }, + }, + } + + cfg, outR, errR := config.NewTestConfig() + cfg.GitHubClientOverride = &github.MockClient{ + FindPRByNumberFn: func(number int) (*github.PullRequest, error) { + return &github.PullRequest{ + Number: 42, + State: "OPEN", + MergeQueueEntry: &github.MergeQueueEntry{ + ID: "MQ_1", + }, + }, nil + }, + } + + syncStackPRs(cfg, s) + collectOutput(cfg, outR, errR) + + assert.True(t, s.Branches[0].Queued) +} + +func TestSyncStackPRs_ClosedPR_NoReplacement_ClearsPR(t *testing.T) { + // A tracked PR that was closed with no replacement OPEN PR should + // have its PR ref cleared so it doesn't appear as an active PR. + s := &stack.Stack{ + Trunk: stack.BranchRef{Branch: "main"}, + Branches: []stack.BranchRef{ + { + Branch: "feature", + PullRequest: &stack.PullRequestRef{ + Number: 10, + ID: "PR_10", + URL: "https://github.com/o/r/pull/10", + }, + Queued: true, + }, + }, + } + + cfg, outR, errR := config.NewTestConfig() + cfg.GitHubClientOverride = &github.MockClient{ + FindPRByNumberFn: func(number int) (*github.PullRequest, error) { + return &github.PullRequest{ + Number: 10, + State: "CLOSED", + Merged: false, + }, nil + }, + FindPRForBranchFn: func(branch string) (*github.PullRequest, error) { + return nil, nil // no open replacement + }, + } + + syncStackPRs(cfg, s) + collectOutput(cfg, outR, errR) + + assert.Nil(t, s.Branches[0].PullRequest) + assert.False(t, s.Branches[0].Queued) +} + +func TestSyncStackPRs_RemoteStack_UsesStackAPI(t *testing.T) { + // When the stack has a remote ID, sync should use the stack API + // as source of truth, matching PRs to branches by head ref name. + s := &stack.Stack{ + ID: "100", + Trunk: stack.BranchRef{Branch: "main"}, + Branches: []stack.BranchRef{ + {Branch: "b1"}, + {Branch: "b2"}, + }, + } + + cfg, outR, errR := config.NewTestConfig() + cfg.GitHubClientOverride = &github.MockClient{ + ListStacksFn: func() ([]github.RemoteStack, error) { + return []github.RemoteStack{ + {ID: 100, PullRequests: []int{10, 11}}, + }, nil + }, + FindPRByNumberFn: func(number int) (*github.PullRequest, error) { + switch number { + case 10: + return &github.PullRequest{Number: 10, ID: "PR_10", URL: "https://github.com/o/r/pull/10", HeadRefName: "b1", State: "OPEN"}, nil + case 11: + return &github.PullRequest{Number: 11, ID: "PR_11", URL: "https://github.com/o/r/pull/11", HeadRefName: "b2", State: "MERGED", Merged: true}, nil + } + return nil, nil + }, + } + + syncStackPRs(cfg, s) + collectOutput(cfg, outR, errR) + + // b1 should be tracked with open PR + require.NotNil(t, s.Branches[0].PullRequest) + assert.Equal(t, 10, s.Branches[0].PullRequest.Number) + assert.False(t, s.Branches[0].PullRequest.Merged) + + // b2 should be tracked with merged PR (stack API keeps closed/merged PRs) + require.NotNil(t, s.Branches[1].PullRequest) + assert.Equal(t, 11, s.Branches[1].PullRequest.Number) + assert.True(t, s.Branches[1].PullRequest.Merged) +} + +func TestSyncStackPRs_RemoteStack_ClosedPRStaysAssociated(t *testing.T) { + // When using the stack API, a closed (not merged) PR should remain + // associated — the stack API is the source of truth, not PR state. + s := &stack.Stack{ + ID: "200", + Trunk: stack.BranchRef{Branch: "main"}, + Branches: []stack.BranchRef{ + {Branch: "feature", PullRequest: &stack.PullRequestRef{Number: 5}}, + }, + } + + cfg, outR, errR := config.NewTestConfig() + cfg.GitHubClientOverride = &github.MockClient{ + ListStacksFn: func() ([]github.RemoteStack, error) { + return []github.RemoteStack{ + {ID: 200, PullRequests: []int{5}}, + }, nil + }, + FindPRByNumberFn: func(number int) (*github.PullRequest, error) { + return &github.PullRequest{Number: 5, ID: "PR_5", URL: "https://github.com/o/r/pull/5", HeadRefName: "feature", State: "CLOSED"}, nil + }, + } + + syncStackPRs(cfg, s) + collectOutput(cfg, outR, errR) + + // PR should still be associated (not cleared), because the stack API says it's part of the stack. + require.NotNil(t, s.Branches[0].PullRequest) + assert.Equal(t, 5, s.Branches[0].PullRequest.Number) + assert.False(t, s.Branches[0].PullRequest.Merged) +} + +func TestSyncStackPRs_RemoteStack_FallsBackOnAPIError(t *testing.T) { + // If the stack API fails, fall back to local discovery. + s := &stack.Stack{ + ID: "300", + Trunk: stack.BranchRef{Branch: "main"}, + Branches: []stack.BranchRef{ + {Branch: "feature"}, + }, + } + + cfg, outR, errR := config.NewTestConfig() + cfg.GitHubClientOverride = &github.MockClient{ + ListStacksFn: func() ([]github.RemoteStack, error) { + return nil, fmt.Errorf("API error") + }, + FindPRForBranchFn: func(branch string) (*github.PullRequest, error) { + return &github.PullRequest{Number: 77, ID: "PR_77", URL: "https://github.com/o/r/pull/77", State: "OPEN"}, nil + }, + } + + syncStackPRs(cfg, s) + collectOutput(cfg, outR, errR) + + // Should have fallen back to local discovery and found the open PR. + require.NotNil(t, s.Branches[0].PullRequest) + assert.Equal(t, 77, s.Branches[0].PullRequest.Number) +} + func TestParsePRURL(t *testing.T) { tests := []struct { name string diff --git a/cmd/view_test.go b/cmd/view_test.go index 8d1c507..44bdc5e 100644 --- a/cmd/view_test.go +++ b/cmd/view_test.go @@ -311,18 +311,19 @@ func TestViewShort_QueuedStack(t *testing.T) { // Mock GitHub client to return b1 as queued (MergeQueueEntry set) cfg, outR, _ := config.NewTestConfig() cfg.GitHubClientOverride = &github.MockClient{ - FindAnyPRForBranchFn: func(branch string) (*github.PullRequest, error) { - switch branch { - case "b1": + FindPRByNumberFn: func(number int) (*github.PullRequest, error) { + switch number { + case 1: return &github.PullRequest{ Number: 1, ID: "PR_1", + State: "OPEN", MergeQueueEntry: &github.MergeQueueEntry{ID: "MQE_1"}, }, nil - case "b2": - return &github.PullRequest{Number: 2, ID: "PR_2"}, nil - case "b3": - return &github.PullRequest{Number: 3, ID: "PR_3"}, nil + case 2: + return &github.PullRequest{Number: 2, ID: "PR_2", State: "OPEN"}, nil + case 3: + return &github.PullRequest{Number: 3, ID: "PR_3", State: "OPEN"}, nil } return nil, nil }, @@ -372,16 +373,17 @@ func TestViewShort_MixedQueuedAndMerged(t *testing.T) { // b1 is merged (persisted), b2 is queued (from API) cfg, outR, _ := config.NewTestConfig() cfg.GitHubClientOverride = &github.MockClient{ - FindAnyPRForBranchFn: func(branch string) (*github.PullRequest, error) { - switch branch { - case "b2": + FindPRByNumberFn: func(number int) (*github.PullRequest, error) { + switch number { + case 2: return &github.PullRequest{ Number: 2, ID: "PR_2", + State: "OPEN", MergeQueueEntry: &github.MergeQueueEntry{ID: "MQE_2"}, }, nil - case "b3": - return &github.PullRequest{Number: 3, ID: "PR_3"}, nil + case 3: + return &github.PullRequest{Number: 3, ID: "PR_3", State: "OPEN"}, nil } return nil, nil }, diff --git a/internal/github/github.go b/internal/github/github.go index 4af97d1..cea2c8c 100644 --- a/internal/github/github.go +++ b/internal/github/github.go @@ -347,6 +347,9 @@ func (c *Client) FindPRByNumber(number int) (*PullRequest, error) { } n := query.Repository.PullRequest + if n.Number == 0 && n.ID == "" { + return nil, nil + } return &PullRequest{ ID: n.ID, Number: n.Number, diff --git a/internal/tui/stackview/data.go b/internal/tui/stackview/data.go index ffd21e8..cd51706 100644 --- a/internal/tui/stackview/data.go +++ b/internal/tui/stackview/data.go @@ -68,10 +68,16 @@ func LoadBranchNodes(cfg *config.Config, s *stack.Stack, currentBranch string) [ } } - // Fetch enriched PR details + // Fetch enriched PR details. + // Only adopt the result if it matches our tracked PR or is OPEN. + // This prevents showing stale merged/closed PR details when a + // branch name was reused from a previously merged PR. if clientErr == nil { if pr, err := client.FindPRDetailsForBranch(b.Branch); err == nil && pr != nil { - node.PR = pr + tracked := b.PullRequest != nil && b.PullRequest.Number == pr.Number + if tracked || pr.State == "OPEN" { + node.PR = pr + } } } diff --git a/internal/tui/stackview/data_test.go b/internal/tui/stackview/data_test.go index f0ff1e2..4faa563 100644 --- a/internal/tui/stackview/data_test.go +++ b/internal/tui/stackview/data_test.go @@ -100,3 +100,129 @@ func TestLoadBranchNodes_LinearBranchStillUsesMergeBase(t *testing.T) { assert.Len(t, nodes[0].FilesChanged, 1) assert.True(t, nodes[0].IsLinear) } + +func TestLoadBranchNodes_IgnoresStaleMergedPRDetails(t *testing.T) { + // When FindPRDetailsForBranch returns a merged PR that doesn't match + // the branch's tracked PR, it should be ignored (stale from branch reuse). + s := &stack.Stack{ + Trunk: stack.BranchRef{Branch: "main"}, + Branches: []stack.BranchRef{ + {Branch: "reused-branch"}, // no tracked PR + }, + } + + restore := git.SetOps(&git.MockOps{ + IsAncestorFn: func(a, b string) (bool, error) { return true, nil }, + MergeBaseFn: func(a, b string) (string, error) { return "abc", nil }, + LogRangeFn: func(a, b string) ([]git.CommitInfo, error) { return nil, nil }, + DiffStatFilesFn: func(a, b string) ([]git.FileDiffStat, error) { + return nil, nil + }, + }) + defer restore() + + cfg, outW, errW := config.NewTestConfig() + defer outW.Close() + defer errW.Close() + cfg.GitHubClientOverride = &ghapi.MockClient{ + FindPRDetailsForBranchFn: func(branch string) (*ghapi.PRDetails, error) { + return &ghapi.PRDetails{ + Number: 20, + Title: "Old merged PR", + State: "MERGED", + Merged: true, + }, nil + }, + } + + nodes := LoadBranchNodes(cfg, s, "other") + + require.Len(t, nodes, 1) + assert.Nil(t, nodes[0].PR, "stale merged PR should not be adopted") +} + +func TestLoadBranchNodes_ShowsTrackedMergedPRDetails(t *testing.T) { + // When FindPRDetailsForBranch returns a merged PR that matches the + // branch's tracked PR number, it should be shown (legitimately merged). + s := &stack.Stack{ + Trunk: stack.BranchRef{Branch: "main"}, + Branches: []stack.BranchRef{ + { + Branch: "merged-branch", + PullRequest: &stack.PullRequestRef{ + Number: 20, + Merged: true, + }, + }, + }, + } + + restore := git.SetOps(&git.MockOps{ + IsAncestorFn: func(a, b string) (bool, error) { return true, nil }, + MergeBaseFn: func(a, b string) (string, error) { return "abc", nil }, + LogRangeFn: func(a, b string) ([]git.CommitInfo, error) { return nil, nil }, + DiffStatFilesFn: func(a, b string) ([]git.FileDiffStat, error) { + return nil, nil + }, + }) + defer restore() + + cfg, outW, errW := config.NewTestConfig() + defer outW.Close() + defer errW.Close() + cfg.GitHubClientOverride = &ghapi.MockClient{ + FindPRDetailsForBranchFn: func(branch string) (*ghapi.PRDetails, error) { + return &ghapi.PRDetails{ + Number: 20, + Title: "Legitimately merged PR", + State: "MERGED", + Merged: true, + }, nil + }, + } + + nodes := LoadBranchNodes(cfg, s, "other") + + require.Len(t, nodes, 1) + require.NotNil(t, nodes[0].PR, "tracked merged PR should be shown") + assert.Equal(t, 20, nodes[0].PR.Number) +} + +func TestLoadBranchNodes_ShowsOpenPRDetails(t *testing.T) { + // An OPEN PR should always be shown, even without a tracked PR. + s := &stack.Stack{ + Trunk: stack.BranchRef{Branch: "main"}, + Branches: []stack.BranchRef{ + {Branch: "feature"}, // no tracked PR + }, + } + + restore := git.SetOps(&git.MockOps{ + IsAncestorFn: func(a, b string) (bool, error) { return true, nil }, + MergeBaseFn: func(a, b string) (string, error) { return "abc", nil }, + LogRangeFn: func(a, b string) ([]git.CommitInfo, error) { return nil, nil }, + DiffStatFilesFn: func(a, b string) ([]git.FileDiffStat, error) { + return nil, nil + }, + }) + defer restore() + + cfg, outW, errW := config.NewTestConfig() + defer outW.Close() + defer errW.Close() + cfg.GitHubClientOverride = &ghapi.MockClient{ + FindPRDetailsForBranchFn: func(branch string) (*ghapi.PRDetails, error) { + return &ghapi.PRDetails{ + Number: 50, + Title: "Active PR", + State: "OPEN", + }, nil + }, + } + + nodes := LoadBranchNodes(cfg, s, "other") + + require.Len(t, nodes, 1) + require.NotNil(t, nodes[0].PR, "OPEN PR should be shown") + assert.Equal(t, 50, nodes[0].PR.Number) +}