diff --git a/cmd/cmd_test.go b/cmd/cmd_test.go index 3494851..f42111c 100644 --- a/cmd/cmd_test.go +++ b/cmd/cmd_test.go @@ -1,6 +1,7 @@ package cmd import ( + "context" "encoding/json" "fmt" "os" @@ -10,67 +11,41 @@ import ( "strings" "testing" + "github.com/nvandessel/frond/internal/driver" "github.com/nvandessel/frond/internal/state" "github.com/spf13/pflag" ) -// setupTestEnv creates a temp git repo with an initial commit, a fake gh -// script, and chdir into the repo. It restores state on cleanup. -func setupTestEnv(t *testing.T) string { +// setupTestEnv creates a temp directory, overrides GitCommonDir and +// injects a mock driver. No real git or gh commands are needed. +func setupTestEnv(t *testing.T) (*driver.Mock, string) { t.Helper() dir := t.TempDir() - - gitEnv := []string{ - "GIT_AUTHOR_NAME=Test User", - "GIT_AUTHOR_EMAIL=test@example.com", - "GIT_COMMITTER_NAME=Test User", - "GIT_COMMITTER_EMAIL=test@example.com", - "GIT_CONFIG_NOSYSTEM=1", - "HOME=" + dir, - } - - // Run git commands in the temp dir. - gitCmd := func(args ...string) { - t.Helper() - cmd := exec.Command("git", args...) - cmd.Dir = dir - cmd.Env = append(os.Environ(), gitEnv...) - out, err := cmd.CombinedOutput() - if err != nil { - t.Fatalf("setup git %s: %s\n%s", strings.Join(args, " "), err, out) - } + gitDir := filepath.Join(dir, ".git") + if err := os.MkdirAll(gitDir, 0o755); err != nil { + t.Fatal(err) } - gitCmd("init", "-b", "main") - gitCmd("commit", "--allow-empty", "-m", "init") + orig := state.GitCommonDir + state.GitCommonDir = func(_ context.Context) (string, error) { return gitDir, nil } + t.Cleanup(func() { state.GitCommonDir = orig }) - // Set env vars for subprocesses. - for _, e := range gitEnv { - parts := strings.SplitN(e, "=", 2) - t.Setenv(parts[0], parts[1]) + mock := driver.NewMock() + mock.PushFn = func(_ context.Context, opts driver.PushOpts) (*driver.PushResult, error) { + return &driver.PushResult{PRNumber: 42, Created: opts.ExistingPR == nil}, nil } - - // chdir to the repo. - origDir, err := os.Getwd() - if err != nil { - t.Fatal(err) - } - if err := os.Chdir(dir); err != nil { - t.Fatal(err) + mock.PRStateFn = func(_ context.Context, _ int) (string, error) { + return "OPEN", nil } - t.Cleanup(func() { os.Chdir(origDir) }) - // Install a fake gh script (platform-appropriate). - ghDir := t.TempDir() - installFakeGH(t, ghDir) - t.Setenv("PATH", ghDir+string(os.PathListSeparator)+os.Getenv("PATH")) + driverOverride = mock + t.Cleanup(func() { driverOverride = nil }) - // Reset global state and cobra flags between tests. - jsonOut = false resetCobraFlags() + jsonOut = false - return dir + return mock, dir } // moduleRoot caches the repo root path, found before any test does os.Chdir. @@ -122,15 +97,17 @@ func TestMain(m *testing.M) { os.Exit(code) } -// installFakeGH copies the pre-built fakegh binary into the given directory as "gh". -func installFakeGH(t *testing.T, dir string) { +// withFakeGH installs the pre-built fakegh binary on PATH for tests that +// need the gh comment API (e.g., stack comment tests). Call after setupTestEnv. +func withFakeGH(t *testing.T) { t.Helper() + ghDir := t.TempDir() binName := "gh" if runtime.GOOS == "windows" { binName = "gh.exe" } - dst := filepath.Join(dir, binName) + dst := filepath.Join(ghDir, binName) // Hard-link (fast) or copy the pre-built binary. if err := os.Link(fakeGHBin, dst); err != nil { @@ -144,6 +121,7 @@ func installFakeGH(t *testing.T, dir string) { } } + t.Setenv("PATH", ghDir+string(os.PathListSeparator)+os.Getenv("PATH")) t.Setenv("FAKEGH_FAIL", "") t.Setenv("FAKEGH_FAIL_API", "") t.Setenv("FAKEGH_PR_COUNTER", "") @@ -190,22 +168,19 @@ func runTier(t *testing.T, args ...string) error { } func TestNewCreatesAndTracks(t *testing.T) { - dir := setupTestEnv(t) + mock, dir := setupTestEnv(t) err := runTier(t, "new", "feature-x") if err != nil { t.Fatalf("frond new: %v", err) } - // Verify git branch was created and checked out. - branchCmd := exec.Command("git", "rev-parse", "--abbrev-ref", "HEAD") - branchCmd.Dir = dir - out, err := branchCmd.Output() - if err != nil { - t.Fatalf("git rev-parse: %v", err) + // Verify mock branch was created and checked out. + if mock.CurrentBranchName != "feature-x" { + t.Errorf("current branch = %q, want %q", mock.CurrentBranchName, "feature-x") } - if got := strings.TrimSpace(string(out)); got != "feature-x" { - t.Errorf("current branch = %q, want %q", got, "feature-x") + if !mock.Branches["feature-x"] { + t.Error("branch 'feature-x' not created in mock") } // Verify frond.json has the branch. @@ -223,17 +198,15 @@ func TestNewCreatesAndTracks(t *testing.T) { } func TestNewWithOnFlag(t *testing.T) { - dir := setupTestEnv(t) + _, dir := setupTestEnv(t) // Create a first branch. - err := runTier(t, "new", "step-1") - if err != nil { + if err := runTier(t, "new", "step-1"); err != nil { t.Fatalf("frond new step-1: %v", err) } // Create a stacked branch on top. - err = runTier(t, "new", "step-2", "--on", "step-1") - if err != nil { + if err := runTier(t, "new", "step-2", "--on", "step-1"); err != nil { t.Fatalf("frond new step-2: %v", err) } @@ -261,21 +234,10 @@ func TestNewDuplicateBranchFails(t *testing.T) { } func TestTrackExistingBranch(t *testing.T) { - dir := setupTestEnv(t) + mock, dir := setupTestEnv(t) - // Create a git branch manually (not via frond). - gitCmd := exec.Command("git", "checkout", "-b", "existing-branch", "main") - gitCmd.Dir = dir - if out, err := gitCmd.CombinedOutput(); err != nil { - t.Fatalf("git checkout -b: %s\n%s", err, out) - } - - // Switch back to main. - gitCmd = exec.Command("git", "checkout", "main") - gitCmd.Dir = dir - if out, err := gitCmd.CombinedOutput(); err != nil { - t.Fatalf("git checkout main: %s\n%s", err, out) - } + // Pre-create a branch in the mock (simulates existing git branch). + mock.Branches["existing-branch"] = true err := runTier(t, "track", "existing-branch", "--on", "main") if err != nil { @@ -309,7 +271,7 @@ func TestTrackAlreadyTrackedFails(t *testing.T) { } func TestUntrackRemovesBranch(t *testing.T) { - dir := setupTestEnv(t) + _, dir := setupTestEnv(t) // Create two stacked branches. if err := runTier(t, "new", "parent-branch"); err != nil { @@ -402,7 +364,7 @@ func TestStatusNoStateFails(t *testing.T) { } func TestNewCycleDetection(t *testing.T) { - setupTestEnv(t) + mock, _ := setupTestEnv(t) // Create branch A. if err := runTier(t, "new", "branch-a"); err != nil { @@ -410,43 +372,15 @@ func TestNewCycleDetection(t *testing.T) { } // Create branch B that depends on A. + mock.CurrentBranchName = "main" if err := runTier(t, "new", "branch-b", "--on", "main", "--after", "branch-a"); err != nil { t.Fatalf("frond new branch-b: %v", err) } - // Try to create C with --after=branch-b AND on branch-a, but also adding - // a circular dep. Actually, a direct cycle: create C --after=branch-b, - // then try to create D --after=C --after=... that forms a cycle. - // Simplest cycle: A --after B, B --after A. - // B already depends on A. Try creating C --after=branch-b where - // branch-b has after=[branch-a], and C has after=[branch-b]. - // That's not a cycle, just a chain. - - // For a real cycle: create C that depends on branch-b, - // then try to make branch-a depend on C (but we can't modify after post-creation). - // Instead: create C --on main --after branch-b, then D --on main --after C,branch-a - // This creates: A -> B -> C -> D, D -> A which is a cycle. - - // Simplest approach: branch-b after=[branch-a]. Create branch-c --after=branch-b. - // Then create branch-d --after=branch-c,branch-a is NOT a cycle. - // We need: create branch-c --after=branch-b, then branch-a --after=branch-c - // But branch-a already exists. - - // Use track to add a branch with a cyclic dep. - // Create branch-c in git manually. - gitCmd := exec.Command("git", "checkout", "-b", "branch-c", "main") - if out, err := gitCmd.CombinedOutput(); err != nil { - t.Fatalf("git checkout: %s\n%s", err, out) - } + // Pre-create branch-c in mock so track can find it. + mock.Branches["branch-c"] = true - // Try to track branch-c --on main --after=branch-b. - // Then try to create branch-a2 that dep on branch-c. No... - // Let me just test: create branch-c with --after=branch-a, - // where branch-a would need --after=branch-c (cycle). - // But since we can't modify existing branches' after lists, - // test the simpler case: track branch-c --after branch-a,branch-c → self-cycle. - - // Actually the simplest: self-dependency. + // Try self-dependency — should fail. err := runTier(t, "track", "branch-c", "--on", "main", "--after", "branch-c") if err == nil { t.Fatal("expected cycle detection error") @@ -457,14 +391,14 @@ func TestNewCycleDetection(t *testing.T) { } func TestNewInheritsParentFromCurrentBranch(t *testing.T) { - dir := setupTestEnv(t) + _, dir := setupTestEnv(t) // Create first branch. if err := runTier(t, "new", "base-feature"); err != nil { t.Fatalf("frond new base-feature: %v", err) } - // We're now on base-feature. Create another without --on. + // We're now on base-feature (mock auto-checks out). Create another without --on. // It should inherit base-feature as parent. if err := runTier(t, "new", "sub-feature"); err != nil { t.Fatalf("frond new sub-feature: %v", err) @@ -478,40 +412,16 @@ func TestNewInheritsParentFromCurrentBranch(t *testing.T) { } func TestPushCreatesNewPR(t *testing.T) { - dir := setupTestEnv(t) + mock, dir := setupTestEnv(t) - // Create a tracked branch with a commit. + // Create a tracked branch. if err := runTier(t, "new", "pr-branch"); err != nil { t.Fatalf("frond new: %v", err) } - // Add a commit so push has something. - gitCmd := exec.Command("git", "commit", "--allow-empty", "-m", "feature work") - gitCmd.Dir = dir - if out, err := gitCmd.CombinedOutput(); err != nil { - t.Fatalf("git commit: %s\n%s", err, out) - } - - // Push needs a remote. Create a bare remote. - remoteDir := t.TempDir() - bareInit := exec.Command("git", "init", "--bare") - bareInit.Dir = remoteDir - if out, err := bareInit.CombinedOutput(); err != nil { - t.Fatalf("git init --bare: %s\n%s", err, out) - } - - // Add the bare repo as "origin". - addRemote := exec.Command("git", "remote", "add", "origin", remoteDir) - addRemote.Dir = dir - if out, err := addRemote.CombinedOutput(); err != nil { - t.Fatalf("git remote add: %s\n%s", err, out) - } - - // Push main first so origin has a "main" branch. - pushMain := exec.Command("git", "push", "origin", "main") - pushMain.Dir = dir - if out, err := pushMain.CombinedOutput(); err != nil { - t.Fatalf("git push main: %s\n%s", err, out) + // Ensure we're on the branch. + if mock.CurrentBranchName != "pr-branch" { + t.Fatalf("expected current branch pr-branch, got %s", mock.CurrentBranchName) } err := runTier(t, "push") @@ -553,32 +463,13 @@ func TestRemoveFromSlice(t *testing.T) { } func TestSyncNothingToDo(t *testing.T) { - dir := setupTestEnv(t) + setupTestEnv(t) // Create a tracked branch. if err := runTier(t, "new", "sync-branch"); err != nil { t.Fatalf("frond new: %v", err) } - // Set up a remote so fetch works. - remoteDir := t.TempDir() - bareInit := exec.Command("git", "init", "--bare") - bareInit.Dir = remoteDir - if out, err := bareInit.CombinedOutput(); err != nil { - t.Fatalf("git init --bare: %s\n%s", err, out) - } - addRemote := exec.Command("git", "remote", "add", "origin", remoteDir) - addRemote.Dir = dir - if out, err := addRemote.CombinedOutput(); err != nil { - t.Fatalf("git remote add: %s\n%s", err, out) - } - // Push main so origin has it. - pushMain := exec.Command("git", "push", "origin", "main") - pushMain.Dir = dir - if out, err := pushMain.CombinedOutput(); err != nil { - t.Fatalf("git push main: %s\n%s", err, out) - } - // Sync should succeed with "already up to date". err := runTier(t, "sync") if err != nil { @@ -606,7 +497,7 @@ func TestHumanizeTitle(t *testing.T) { } func TestPushUntrackedBranchFails(t *testing.T) { - dir := setupTestEnv(t) + mock, _ := setupTestEnv(t) // Initialize state by creating one branch. if err := runTier(t, "new", "tracked-one"); err != nil { @@ -614,11 +505,8 @@ func TestPushUntrackedBranchFails(t *testing.T) { } // Switch to an untracked branch. - gitCmd := exec.Command("git", "checkout", "-b", "untracked-branch") - gitCmd.Dir = dir - if out, err := gitCmd.CombinedOutput(); err != nil { - t.Fatalf("git checkout: %s\n%s", err, out) - } + mock.Branches["untracked-branch"] = true + mock.CurrentBranchName = "untracked-branch" err := runTier(t, "push") if err == nil { @@ -684,28 +572,20 @@ func TestNewWithJSONOutput(t *testing.T) { } func TestNewWithAfterDeps(t *testing.T) { - dir := setupTestEnv(t) + mock, dir := setupTestEnv(t) // Create two branches. if err := runTier(t, "new", "dep-a"); err != nil { t.Fatalf("frond new dep-a: %v", err) } // Go back to main so next new defaults to main. - gitCmd := exec.Command("git", "checkout", "main") - gitCmd.Dir = dir - if out, err := gitCmd.CombinedOutput(); err != nil { - t.Fatalf("git checkout: %s\n%s", err, out) - } + mock.CurrentBranchName = "main" if err := runTier(t, "new", "dep-b"); err != nil { t.Fatalf("frond new dep-b: %v", err) } // Go back to main. - gitCmd = exec.Command("git", "checkout", "main") - gitCmd.Dir = dir - if out, err := gitCmd.CombinedOutput(); err != nil { - t.Fatalf("git checkout: %s\n%s", err, out) - } + mock.CurrentBranchName = "main" // Create a branch with --after deps. if err := runTier(t, "new", "dep-c", "--on", "main", "--after", "dep-a,dep-b"); err != nil { @@ -722,7 +602,6 @@ func TestNewWithAfterDeps(t *testing.T) { func TestNewInvalidBranchName(t *testing.T) { setupTestEnv(t) - // Branch name with ".." is invalid and gets past cobra flag parsing. err := runTier(t, "new", "a..b") if err == nil { t.Fatal("expected error for branch name with '..'") @@ -757,19 +636,10 @@ func TestNewAfterDepNotTracked(t *testing.T) { } func TestTrackWithJSONOutput(t *testing.T) { - dir := setupTestEnv(t) + mock, _ := setupTestEnv(t) - // Create a branch in git manually. - gitCmd := exec.Command("git", "checkout", "-b", "json-track", "main") - gitCmd.Dir = dir - if out, err := gitCmd.CombinedOutput(); err != nil { - t.Fatalf("git checkout: %s\n%s", err, out) - } - gitCmd = exec.Command("git", "checkout", "main") - gitCmd.Dir = dir - if out, err := gitCmd.CombinedOutput(); err != nil { - t.Fatalf("git checkout: %s\n%s", err, out) - } + // Pre-create branch in mock. + mock.Branches["json-track"] = true err := runTier(t, "track", "json-track", "--on", "main", "--json") if err != nil { @@ -815,7 +685,7 @@ func TestUntrackWithJSONOutput(t *testing.T) { } func TestUntrackCurrentBranch(t *testing.T) { - dir := setupTestEnv(t) + _, dir := setupTestEnv(t) // Create and stay on the branch. if err := runTier(t, "new", "current-br"); err != nil { @@ -835,7 +705,7 @@ func TestUntrackCurrentBranch(t *testing.T) { } func TestUntrackWithDepsAndChildren(t *testing.T) { - dir := setupTestEnv(t) + mock, dir := setupTestEnv(t) // Create parent -> child chain with deps. if err := runTier(t, "new", "mid-branch"); err != nil { @@ -845,11 +715,7 @@ func TestUntrackWithDepsAndChildren(t *testing.T) { t.Fatalf("frond new child-a: %v", err) } // Go back to main, create another that depends on mid-branch. - gitCmd := exec.Command("git", "checkout", "main") - gitCmd.Dir = dir - if out, err := gitCmd.CombinedOutput(); err != nil { - t.Fatalf("git checkout: %s\n%s", err, out) - } + mock.CurrentBranchName = "main" if err := runTier(t, "new", "dep-on-mid", "--on", "main", "--after", "mid-branch"); err != nil { t.Fatalf("frond new dep-on-mid: %v", err) } @@ -905,86 +771,39 @@ func TestCompletionInvalidShell(t *testing.T) { } func TestPushExistingPRUpdates(t *testing.T) { - dir := setupTestEnv(t) + _, dir := setupTestEnv(t) - // Create a tracked branch with a commit. + // Create a tracked branch. if err := runTier(t, "new", "update-pr-branch"); err != nil { t.Fatalf("frond new: %v", err) } - gitCmd := exec.Command("git", "commit", "--allow-empty", "-m", "work") - gitCmd.Dir = dir - if out, err := gitCmd.CombinedOutput(); err != nil { - t.Fatalf("git commit: %s\n%s", err, out) - } - - // Set up remote. - remoteDir := t.TempDir() - bareInit := exec.Command("git", "init", "--bare") - bareInit.Dir = remoteDir - if out, err := bareInit.CombinedOutput(); err != nil { - t.Fatalf("git init --bare: %s\n%s", err, out) - } - addRemote := exec.Command("git", "remote", "add", "origin", remoteDir) - addRemote.Dir = dir - if out, err := addRemote.CombinedOutput(); err != nil { - t.Fatalf("git remote add: %s\n%s", err, out) - } - pushMain := exec.Command("git", "push", "origin", "main") - pushMain.Dir = dir - if out, err := pushMain.CombinedOutput(); err != nil { - t.Fatalf("git push main: %s\n%s", err, out) - } - // First push creates a PR. if err := runTier(t, "push"); err != nil { t.Fatalf("first push: %v", err) } - // Add another commit. - gitCmd = exec.Command("git", "commit", "--allow-empty", "-m", "more work") - gitCmd.Dir = dir - if out, err := gitCmd.CombinedOutput(); err != nil { - t.Fatalf("git commit: %s\n%s", err, out) - } - // Second push should update the existing PR (not create new). err := runTier(t, "push") if err != nil { t.Fatalf("second push (update): %v", err) } + + // PR number should still be 42. + s := readState(t, dir) + b := s.Branches["update-pr-branch"] + if b.PR == nil || *b.PR != 42 { + t.Errorf("PR = %v, want 42", b.PR) + } } func TestPushWithTitleAndDraft(t *testing.T) { - dir := setupTestEnv(t) + setupTestEnv(t) if err := runTier(t, "new", "draft-branch"); err != nil { t.Fatalf("frond new: %v", err) } - gitCmd := exec.Command("git", "commit", "--allow-empty", "-m", "work") - gitCmd.Dir = dir - if out, err := gitCmd.CombinedOutput(); err != nil { - t.Fatalf("git commit: %s\n%s", err, out) - } - - remoteDir := t.TempDir() - bareInit := exec.Command("git", "init", "--bare") - bareInit.Dir = remoteDir - if out, err := bareInit.CombinedOutput(); err != nil { - t.Fatalf("git init --bare: %s\n%s", err, out) - } - addRemote := exec.Command("git", "remote", "add", "origin", remoteDir) - addRemote.Dir = dir - if out, err := addRemote.CombinedOutput(); err != nil { - t.Fatalf("git remote add: %s\n%s", err, out) - } - pushMain := exec.Command("git", "push", "origin", "main") - pushMain.Dir = dir - if out, err := pushMain.CombinedOutput(); err != nil { - t.Fatalf("git push main: %s\n%s", err, out) - } - err := runTier(t, "push", "-t", "My Custom Title", "--draft") if err != nil { t.Fatalf("frond push with title and draft: %v", err) @@ -992,35 +811,12 @@ func TestPushWithTitleAndDraft(t *testing.T) { } func TestPushWithJSONOutput(t *testing.T) { - dir := setupTestEnv(t) + setupTestEnv(t) if err := runTier(t, "new", "json-push"); err != nil { t.Fatalf("frond new: %v", err) } - gitCmd := exec.Command("git", "commit", "--allow-empty", "-m", "work") - gitCmd.Dir = dir - if out, err := gitCmd.CombinedOutput(); err != nil { - t.Fatalf("git commit: %s\n%s", err, out) - } - - remoteDir := t.TempDir() - bareInit := exec.Command("git", "init", "--bare") - bareInit.Dir = remoteDir - if out, err := bareInit.CombinedOutput(); err != nil { - t.Fatalf("git init --bare: %s\n%s", err, out) - } - addRemote := exec.Command("git", "remote", "add", "origin", remoteDir) - addRemote.Dir = dir - if out, err := addRemote.CombinedOutput(); err != nil { - t.Fatalf("git remote add: %s\n%s", err, out) - } - pushMain := exec.Command("git", "push", "origin", "main") - pushMain.Dir = dir - if out, err := pushMain.CombinedOutput(); err != nil { - t.Fatalf("git push main: %s\n%s", err, out) - } - err := runTier(t, "push", "--json") if err != nil { t.Fatalf("frond push --json: %v", err) @@ -1028,7 +824,7 @@ func TestPushWithJSONOutput(t *testing.T) { } func TestSyncNoBranches(t *testing.T) { - dir := setupTestEnv(t) + setupTestEnv(t) // Create a branch and immediately untrack it so state exists but has no branches. if err := runTier(t, "new", "temp-branch"); err != nil { @@ -1038,24 +834,6 @@ func TestSyncNoBranches(t *testing.T) { t.Fatalf("frond untrack: %v", err) } - // Set up remote. - remoteDir := t.TempDir() - bareInit := exec.Command("git", "init", "--bare") - bareInit.Dir = remoteDir - if out, err := bareInit.CombinedOutput(); err != nil { - t.Fatalf("git init --bare: %s\n%s", err, out) - } - addRemote := exec.Command("git", "remote", "add", "origin", remoteDir) - addRemote.Dir = dir - if out, err := addRemote.CombinedOutput(); err != nil { - t.Fatalf("git remote add: %s\n%s", err, out) - } - pushMain := exec.Command("git", "push", "origin", "main") - pushMain.Dir = dir - if out, err := pushMain.CombinedOutput(); err != nil { - t.Fatalf("git push main: %s\n%s", err, out) - } - // Sync with no branches should say "nothing to sync". err := runTier(t, "sync") if err != nil { @@ -1064,7 +842,7 @@ func TestSyncNoBranches(t *testing.T) { } func TestSyncNoBranchesJSON(t *testing.T) { - dir := setupTestEnv(t) + setupTestEnv(t) if err := runTier(t, "new", "temp-branch"); err != nil { t.Fatalf("frond new: %v", err) @@ -1073,23 +851,6 @@ func TestSyncNoBranchesJSON(t *testing.T) { t.Fatalf("frond untrack: %v", err) } - remoteDir := t.TempDir() - bareInit := exec.Command("git", "init", "--bare") - bareInit.Dir = remoteDir - if out, err := bareInit.CombinedOutput(); err != nil { - t.Fatalf("git init --bare: %s\n%s", err, out) - } - addRemote := exec.Command("git", "remote", "add", "origin", remoteDir) - addRemote.Dir = dir - if out, err := addRemote.CombinedOutput(); err != nil { - t.Fatalf("git remote add: %s\n%s", err, out) - } - pushMain := exec.Command("git", "push", "origin", "main") - pushMain.Dir = dir - if out, err := pushMain.CombinedOutput(); err != nil { - t.Fatalf("git push main: %s\n%s", err, out) - } - err := runTier(t, "sync", "--json") if err != nil { t.Fatalf("frond sync --json (no branches): %v", err) @@ -1097,51 +858,14 @@ func TestSyncNoBranchesJSON(t *testing.T) { } func TestSyncRebasesTrackedBranch(t *testing.T) { - dir := setupTestEnv(t) + setupTestEnv(t) // Create tracked branch. if err := runTier(t, "new", "rebase-me"); err != nil { t.Fatalf("frond new: %v", err) } - // Add a commit on the feature branch. - gitCmd := exec.Command("git", "commit", "--allow-empty", "-m", "feature work") - gitCmd.Dir = dir - if out, err := gitCmd.CombinedOutput(); err != nil { - t.Fatalf("git commit: %s\n%s", err, out) - } - - // Go back to main and add a commit. - gitCmd = exec.Command("git", "checkout", "main") - gitCmd.Dir = dir - if out, err := gitCmd.CombinedOutput(); err != nil { - t.Fatalf("git checkout: %s\n%s", err, out) - } - gitCmd = exec.Command("git", "commit", "--allow-empty", "-m", "main advance") - gitCmd.Dir = dir - if out, err := gitCmd.CombinedOutput(); err != nil { - t.Fatalf("git commit: %s\n%s", err, out) - } - - // Set up remote. - remoteDir := t.TempDir() - bareInit := exec.Command("git", "init", "--bare") - bareInit.Dir = remoteDir - if out, err := bareInit.CombinedOutput(); err != nil { - t.Fatalf("git init --bare: %s\n%s", err, out) - } - addRemote := exec.Command("git", "remote", "add", "origin", remoteDir) - addRemote.Dir = dir - if out, err := addRemote.CombinedOutput(); err != nil { - t.Fatalf("git remote add: %s\n%s", err, out) - } - pushMain := exec.Command("git", "push", "origin", "main") - pushMain.Dir = dir - if out, err := pushMain.CombinedOutput(); err != nil { - t.Fatalf("git push main: %s\n%s", err, out) - } - - // Sync should rebase rebase-me onto main. + // Sync should rebase rebase-me onto main (mock rebase is no-op). err := runTier(t, "sync") if err != nil { t.Fatalf("frond sync: %v", err) @@ -1149,29 +873,12 @@ func TestSyncRebasesTrackedBranch(t *testing.T) { } func TestSyncWithJSONOutput(t *testing.T) { - dir := setupTestEnv(t) + setupTestEnv(t) if err := runTier(t, "new", "sync-json"); err != nil { t.Fatalf("frond new: %v", err) } - remoteDir := t.TempDir() - bareInit := exec.Command("git", "init", "--bare") - bareInit.Dir = remoteDir - if out, err := bareInit.CombinedOutput(); err != nil { - t.Fatalf("git init --bare: %s\n%s", err, out) - } - addRemote := exec.Command("git", "remote", "add", "origin", remoteDir) - addRemote.Dir = dir - if out, err := addRemote.CombinedOutput(); err != nil { - t.Fatalf("git remote add: %s\n%s", err, out) - } - pushMain := exec.Command("git", "push", "origin", "main") - pushMain.Dir = dir - if out, err := pushMain.CombinedOutput(); err != nil { - t.Fatalf("git push main: %s\n%s", err, out) - } - err := runTier(t, "sync", "--json") if err != nil { t.Fatalf("frond sync --json: %v", err) @@ -1179,7 +886,7 @@ func TestSyncWithJSONOutput(t *testing.T) { } func TestStatusWithPRStates(t *testing.T) { - dir := setupTestEnv(t) + _, dir := setupTestEnv(t) // Create a tracked branch and manually set a PR number. if err := runTier(t, "new", "pr-status"); err != nil { @@ -1200,7 +907,7 @@ func TestStatusWithPRStates(t *testing.T) { t.Fatal(err) } - // Status with --fetch should exercise fetchPRStates and outputHuman with prStates. + // Status with --fetch should exercise fetchPRStates. err = runTier(t, "status", "--fetch") if err != nil { t.Fatalf("frond status --fetch: %v", err) @@ -1208,7 +915,7 @@ func TestStatusWithPRStates(t *testing.T) { } func TestStatusFetchJSON(t *testing.T) { - dir := setupTestEnv(t) + _, dir := setupTestEnv(t) if err := runTier(t, "new", "pr-json-status"); err != nil { t.Fatalf("frond new: %v", err) @@ -1236,47 +943,19 @@ func TestStatusFetchJSON(t *testing.T) { } func TestPushWithUnmetDeps(t *testing.T) { - dir := setupTestEnv(t) + mock, _ := setupTestEnv(t) // Create dep and dependent branches. if err := runTier(t, "new", "dep-branch"); err != nil { t.Fatalf("frond new dep-branch: %v", err) } - gitCmd := exec.Command("git", "checkout", "main") - gitCmd.Dir = dir - if out, err := gitCmd.CombinedOutput(); err != nil { - t.Fatalf("git checkout: %s\n%s", err, out) - } + mock.CurrentBranchName = "main" if err := runTier(t, "new", "with-deps", "--on", "main", "--after", "dep-branch"); err != nil { t.Fatalf("frond new with-deps: %v", err) } - gitCmd = exec.Command("git", "commit", "--allow-empty", "-m", "work") - gitCmd.Dir = dir - if out, err := gitCmd.CombinedOutput(); err != nil { - t.Fatalf("git commit: %s\n%s", err, out) - } - - // Set up remote. - remoteDir := t.TempDir() - bareInit := exec.Command("git", "init", "--bare") - bareInit.Dir = remoteDir - if out, err := bareInit.CombinedOutput(); err != nil { - t.Fatalf("git init --bare: %s\n%s", err, out) - } - addRemote := exec.Command("git", "remote", "add", "origin", remoteDir) - addRemote.Dir = dir - if out, err := addRemote.CombinedOutput(); err != nil { - t.Fatalf("git remote add: %s\n%s", err, out) - } - pushMain := exec.Command("git", "push", "origin", "main") - pushMain.Dir = dir - if out, err := pushMain.CombinedOutput(); err != nil { - t.Fatalf("git push main: %s\n%s", err, out) - } - // Push should succeed but warn about unmet deps. err := runTier(t, "push") if err != nil { @@ -1285,39 +964,17 @@ func TestPushWithUnmetDeps(t *testing.T) { } func TestSyncBlockedBranch(t *testing.T) { - dir := setupTestEnv(t) + mock, _ := setupTestEnv(t) // Create two branches: blocker and blocked. if err := runTier(t, "new", "blocker"); err != nil { t.Fatalf("frond new blocker: %v", err) } - gitCmd := exec.Command("git", "checkout", "main") - gitCmd.Dir = dir - if out, err := gitCmd.CombinedOutput(); err != nil { - t.Fatalf("git checkout: %s\n%s", err, out) - } + mock.CurrentBranchName = "main" if err := runTier(t, "new", "blocked-br", "--on", "main", "--after", "blocker"); err != nil { t.Fatalf("frond new blocked-br: %v", err) } - // Set up remote. - remoteDir := t.TempDir() - bareInit := exec.Command("git", "init", "--bare") - bareInit.Dir = remoteDir - if out, err := bareInit.CombinedOutput(); err != nil { - t.Fatalf("git init --bare: %s\n%s", err, out) - } - addRemote := exec.Command("git", "remote", "add", "origin", remoteDir) - addRemote.Dir = dir - if out, err := addRemote.CombinedOutput(); err != nil { - t.Fatalf("git remote add: %s\n%s", err, out) - } - pushMain := exec.Command("git", "push", "origin", "main") - pushMain.Dir = dir - if out, err := pushMain.CombinedOutput(); err != nil { - t.Fatalf("git push main: %s\n%s", err, out) - } - // Sync should see blocked-br as blocked. err := runTier(t, "sync") if err != nil { @@ -1326,40 +983,18 @@ func TestSyncBlockedBranch(t *testing.T) { } func TestPushSkipsStackCommentForSinglePR(t *testing.T) { - dir := setupTestEnv(t) + mock, dir := setupTestEnv(t) + withFakeGH(t) + mock.StackComments = true recordFile := filepath.Join(dir, "gh_calls.log") t.Setenv("FAKEGH_RECORD", recordFile) - // Create a single tracked branch with a commit. + // Create a single tracked branch. if err := runTier(t, "new", "solo-branch"); err != nil { t.Fatalf("frond new: %v", err) } - gitCmd := exec.Command("git", "commit", "--allow-empty", "-m", "feature work") - gitCmd.Dir = dir - if out, err := gitCmd.CombinedOutput(); err != nil { - t.Fatalf("git commit: %s\n%s", err, out) - } - - // Set up a remote. - remoteDir := t.TempDir() - bareInit := exec.Command("git", "init", "--bare") - bareInit.Dir = remoteDir - if out, err := bareInit.CombinedOutput(); err != nil { - t.Fatalf("git init --bare: %s\n%s", err, out) - } - addRemote := exec.Command("git", "remote", "add", "origin", remoteDir) - addRemote.Dir = dir - if out, err := addRemote.CombinedOutput(); err != nil { - t.Fatalf("git remote add: %s\n%s", err, out) - } - pushMain := exec.Command("git", "push", "origin", "main") - pushMain.Dir = dir - if out, err := pushMain.CombinedOutput(); err != nil { - t.Fatalf("git push main: %s\n%s", err, out) - } - err := runTier(t, "push") if err != nil { t.Fatalf("frond push: %v", err) @@ -1374,53 +1009,29 @@ func TestPushSkipsStackCommentForSinglePR(t *testing.T) { } } -// setupRemote creates a bare remote and adds it as "origin", pushing main. -func setupRemote(t *testing.T, dir string) { - t.Helper() - remoteDir := t.TempDir() - bareInit := exec.Command("git", "init", "--bare") - bareInit.Dir = remoteDir - if out, err := bareInit.CombinedOutput(); err != nil { - t.Fatalf("git init --bare: %s\n%s", err, out) - } - addRemote := exec.Command("git", "remote", "add", "origin", remoteDir) - addRemote.Dir = dir - if out, err := addRemote.CombinedOutput(); err != nil { - t.Fatalf("git remote add: %s\n%s", err, out) - } - pushMain := exec.Command("git", "push", "origin", "main") - pushMain.Dir = dir - if out, err := pushMain.CombinedOutput(); err != nil { - t.Fatalf("git push main: %s\n%s", err, out) - } -} - -// setupPRCounter enables incrementing PR numbers in fakegh and returns -// the path to the counter file. -func setupPRCounter(t *testing.T, dir string) { - t.Helper() - counterFile := filepath.Join(dir, "pr_counter") - os.WriteFile(counterFile, []byte("42\n"), 0o644) - t.Setenv("FAKEGH_PR_COUNTER", counterFile) -} - func TestPushCreatesStackComment(t *testing.T) { - dir := setupTestEnv(t) + mock, dir := setupTestEnv(t) + withFakeGH(t) + mock.StackComments = true recordFile := filepath.Join(dir, "gh_calls.log") t.Setenv("FAKEGH_RECORD", recordFile) - setupPRCounter(t, dir) - setupRemote(t, dir) + + // Use incrementing PR numbers so each push gets a unique PR. + prCounter := 42 + mock.PushFn = func(_ context.Context, opts driver.PushOpts) (*driver.PushResult, error) { + if opts.ExistingPR != nil { + return &driver.PushResult{PRNumber: *opts.ExistingPR, Created: false}, nil + } + n := prCounter + prCounter++ + return &driver.PushResult{PRNumber: n, Created: true}, nil + } // Create two tracked branches so the stack has >= 2 PRs. if err := runTier(t, "new", "branch-a"); err != nil { t.Fatalf("frond new branch-a: %v", err) } - gitCmd := exec.Command("git", "commit", "--allow-empty", "-m", "work on a") - gitCmd.Dir = dir - if out, err := gitCmd.CombinedOutput(); err != nil { - t.Fatalf("git commit: %s\n%s", err, out) - } // Push branch-a to create PR #42 (single PR, no comments yet). if err := runTier(t, "push"); err != nil { @@ -1431,11 +1042,6 @@ func TestPushCreatesStackComment(t *testing.T) { if err := runTier(t, "new", "branch-b", "--on", "branch-a"); err != nil { t.Fatalf("frond new branch-b: %v", err) } - gitCmd = exec.Command("git", "commit", "--allow-empty", "-m", "work on b") - gitCmd.Dir = dir - if out, err := gitCmd.CombinedOutput(); err != nil { - t.Fatalf("git commit: %s\n%s", err, out) - } // Clear the record file so we only see calls from this push. os.Remove(recordFile) @@ -1476,23 +1082,29 @@ func TestPushCreatesStackComment(t *testing.T) { } func TestPushUpdatesStackComment(t *testing.T) { - dir := setupTestEnv(t) + mock, dir := setupTestEnv(t) + withFakeGH(t) + mock.StackComments = true recordFile := filepath.Join(dir, "gh_calls.log") t.Setenv("FAKEGH_RECORD", recordFile) t.Setenv("FAKEGH_EXISTING_COMMENT", "1") - setupPRCounter(t, dir) - setupRemote(t, dir) + + // Use incrementing PR numbers. + prCounter := 42 + mock.PushFn = func(_ context.Context, opts driver.PushOpts) (*driver.PushResult, error) { + if opts.ExistingPR != nil { + return &driver.PushResult{PRNumber: *opts.ExistingPR, Created: false}, nil + } + n := prCounter + prCounter++ + return &driver.PushResult{PRNumber: n, Created: true}, nil + } // Create two tracked branches so the stack has >= 2 PRs. if err := runTier(t, "new", "update-branch-a"); err != nil { t.Fatalf("frond new: %v", err) } - gitCmd := exec.Command("git", "commit", "--allow-empty", "-m", "work on a") - gitCmd.Dir = dir - if out, err := gitCmd.CombinedOutput(); err != nil { - t.Fatalf("git commit: %s\n%s", err, out) - } // Push branch-a to create its PR. if err := runTier(t, "push"); err != nil { @@ -1503,11 +1115,6 @@ func TestPushUpdatesStackComment(t *testing.T) { if err := runTier(t, "new", "update-branch-b", "--on", "update-branch-a"); err != nil { t.Fatalf("frond new: %v", err) } - gitCmd = exec.Command("git", "commit", "--allow-empty", "-m", "work on b") - gitCmd.Dir = dir - if out, err := gitCmd.CombinedOutput(); err != nil { - t.Fatalf("git commit: %s\n%s", err, out) - } // Clear the record file so we only see calls from this push. os.Remove(recordFile) @@ -1533,22 +1140,28 @@ func TestPushUpdatesStackComment(t *testing.T) { } func TestPushStackCommentErrorNonFatal(t *testing.T) { - dir := setupTestEnv(t) + mock, dir := setupTestEnv(t) + withFakeGH(t) + mock.StackComments = true recordFile := filepath.Join(dir, "gh_calls.log") t.Setenv("FAKEGH_RECORD", recordFile) - setupPRCounter(t, dir) - setupRemote(t, dir) + + // Use incrementing PR numbers. + prCounter := 42 + mock.PushFn = func(_ context.Context, opts driver.PushOpts) (*driver.PushResult, error) { + if opts.ExistingPR != nil { + return &driver.PushResult{PRNumber: *opts.ExistingPR, Created: false}, nil + } + n := prCounter + prCounter++ + return &driver.PushResult{PRNumber: n, Created: true}, nil + } // Create two branches with PRs so stack comments are attempted. if err := runTier(t, "new", "err-branch-a"); err != nil { t.Fatalf("frond new: %v", err) } - gitCmd := exec.Command("git", "commit", "--allow-empty", "-m", "work on a") - gitCmd.Dir = dir - if out, err := gitCmd.CombinedOutput(); err != nil { - t.Fatalf("git commit: %s\n%s", err, out) - } if err := runTier(t, "push"); err != nil { t.Fatalf("frond push err-branch-a: %v", err) } @@ -1556,11 +1169,6 @@ func TestPushStackCommentErrorNonFatal(t *testing.T) { if err := runTier(t, "new", "err-branch-b", "--on", "err-branch-a"); err != nil { t.Fatalf("frond new: %v", err) } - gitCmd = exec.Command("git", "commit", "--allow-empty", "-m", "work on b") - gitCmd.Dir = dir - if out, err := gitCmd.CombinedOutput(); err != nil { - t.Fatalf("git commit: %s\n%s", err, out) - } // Make only API calls fail — pr view/edit still work. t.Setenv("FAKEGH_FAIL_API", "1") @@ -1573,22 +1181,28 @@ func TestPushStackCommentErrorNonFatal(t *testing.T) { } func TestSyncUpdatesMergedComments(t *testing.T) { - dir := setupTestEnv(t) + mock, dir := setupTestEnv(t) + withFakeGH(t) + mock.StackComments = true recordFile := filepath.Join(dir, "gh_calls.log") t.Setenv("FAKEGH_RECORD", recordFile) - setupPRCounter(t, dir) - setupRemote(t, dir) + + // Use incrementing PR numbers. + prCounter := 42 + mock.PushFn = func(_ context.Context, opts driver.PushOpts) (*driver.PushResult, error) { + if opts.ExistingPR != nil { + return &driver.PushResult{PRNumber: *opts.ExistingPR, Created: false}, nil + } + n := prCounter + prCounter++ + return &driver.PushResult{PRNumber: n, Created: true}, nil + } // Create two branches with PRs. if err := runTier(t, "new", "merge-branch-a"); err != nil { t.Fatalf("frond new: %v", err) } - gitCmd := exec.Command("git", "commit", "--allow-empty", "-m", "work on a") - gitCmd.Dir = dir - if out, err := gitCmd.CombinedOutput(); err != nil { - t.Fatalf("git commit: %s\n%s", err, out) - } if err := runTier(t, "push"); err != nil { t.Fatalf("frond push: %v", err) } @@ -1596,17 +1210,14 @@ func TestSyncUpdatesMergedComments(t *testing.T) { if err := runTier(t, "new", "merge-branch-b", "--on", "merge-branch-a"); err != nil { t.Fatalf("frond new: %v", err) } - gitCmd = exec.Command("git", "commit", "--allow-empty", "-m", "work on b") - gitCmd.Dir = dir - if out, err := gitCmd.CombinedOutput(); err != nil { - t.Fatalf("git commit: %s\n%s", err, out) - } if err := runTier(t, "push"); err != nil { t.Fatalf("frond push: %v", err) } - // Make fakegh report all PRs as MERGED. - t.Setenv("FAKEGH_PR_STATE", "MERGED") + // Make mock report all PRs as MERGED. + mock.PRStateFn = func(_ context.Context, _ int) (string, error) { + return "MERGED", nil + } // Clear record to isolate sync calls. os.Remove(recordFile) @@ -1648,6 +1259,52 @@ func readGHCalls(t *testing.T, recordFile string) []string { return lines } +func TestPushSkipsStackCommentsWhenDriverUnsupported(t *testing.T) { + mock, dir := setupTestEnv(t) + withFakeGH(t) + // StackComments defaults to false, simulating a Graphite-like driver. + + recordFile := filepath.Join(dir, "gh_calls.log") + t.Setenv("FAKEGH_RECORD", recordFile) + + // Use incrementing PR numbers. + prCounter := 42 + mock.PushFn = func(_ context.Context, opts driver.PushOpts) (*driver.PushResult, error) { + if opts.ExistingPR != nil { + return &driver.PushResult{PRNumber: *opts.ExistingPR, Created: false}, nil + } + n := prCounter + prCounter++ + return &driver.PushResult{PRNumber: n, Created: true}, nil + } + + // Create two branches with PRs — enough for stack comments to trigger. + if err := runTier(t, "new", "no-comment-a"); err != nil { + t.Fatalf("frond new: %v", err) + } + if err := runTier(t, "push"); err != nil { + t.Fatalf("frond push: %v", err) + } + + if err := runTier(t, "new", "no-comment-b", "--on", "no-comment-a"); err != nil { + t.Fatalf("frond new: %v", err) + } + + os.Remove(recordFile) + + if err := runTier(t, "push"); err != nil { + t.Fatalf("frond push: %v", err) + } + + // With StackComments=false, no comment API calls should be made. + calls := readGHCalls(t, recordFile) + for _, call := range calls { + if strings.Contains(call, "api") && strings.Contains(call, "comments") { + t.Errorf("expected no comment API calls with StackComments=false, got: %s", call) + } + } +} + func TestNewEmptySyncResult(t *testing.T) { r := newEmptySyncResult() if r.Merged == nil || r.Rebased == nil || r.Unblocked == nil || r.Conflicts == nil { @@ -1657,3 +1314,191 @@ func TestNewEmptySyncResult(t *testing.T) { t.Error("newEmptySyncResult should initialize all maps") } } + +func TestInitDefault(t *testing.T) { + _, dir := setupTestEnv(t) + + err := runTier(t, "init") + if err != nil { + t.Fatalf("frond init: %v", err) + } + + s := readState(t, dir) + if s.Driver != "" { + t.Errorf("driver = %q, want empty (native)", s.Driver) + } + if s.Trunk != "main" { + t.Errorf("trunk = %q, want main", s.Trunk) + } +} + +func TestInitJSON(t *testing.T) { + setupTestEnv(t) + + err := runTier(t, "init", "--json") + if err != nil { + t.Fatalf("frond init --json: %v", err) + } +} + +func TestInitUnknownDriver(t *testing.T) { + setupTestEnv(t) + + err := runTier(t, "init", "--driver", "bogus") + if err == nil { + t.Fatal("expected error for unknown driver") + } + if !strings.Contains(err.Error(), "unknown driver") { + t.Errorf("error = %q, want containing 'unknown driver'", err.Error()) + } +} + +func TestInitPreservesExistingState(t *testing.T) { + _, dir := setupTestEnv(t) + + // Create some state first. + if err := runTier(t, "new", "existing-branch"); err != nil { + t.Fatalf("frond new: %v", err) + } + + // Init should not blow away existing branches. + if err := runTier(t, "init"); err != nil { + t.Fatalf("frond init: %v", err) + } + + s := readState(t, dir) + if _, ok := s.Branches["existing-branch"]; !ok { + t.Error("init should preserve existing branches") + } +} + +func TestSyncMergedPR(t *testing.T) { + mock, dir := setupTestEnv(t) + + // Create parent and child branches. + if err := runTier(t, "new", "merged-branch"); err != nil { + t.Fatalf("frond new merged-branch: %v", err) + } + if err := runTier(t, "new", "child-of-merged", "--on", "merged-branch"); err != nil { + t.Fatalf("frond new child-of-merged: %v", err) + } + + // Manually assign PR numbers to state. + s := readState(t, dir) + pr1 := 10 + b := s.Branches["merged-branch"] + b.PR = &pr1 + s.Branches["merged-branch"] = b + pr2 := 20 + c := s.Branches["child-of-merged"] + c.PR = &pr2 + s.Branches["child-of-merged"] = c + data, err := json.Marshal(s) + if err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(dir, ".git", "frond.json"), data, 0o644); err != nil { + t.Fatal(err) + } + + // Mock PRState to return MERGED for PR #10. + mock.PRStateFn = func(_ context.Context, prNumber int) (string, error) { + if prNumber == 10 { + return "MERGED", nil + } + return "OPEN", nil + } + + // Track retarget calls. + var retargetCalls []int + mock.RetargetPRFn = func(_ context.Context, prNumber int, _ string) error { + retargetCalls = append(retargetCalls, prNumber) + return nil + } + + err = runTier(t, "sync") + if err != nil { + t.Fatalf("frond sync: %v", err) + } + + // merged-branch should be removed, child reparented to main. + s = readState(t, dir) + if _, ok := s.Branches["merged-branch"]; ok { + t.Error("merged-branch should be removed from state") + } + child := s.Branches["child-of-merged"] + if child.Parent != "main" { + t.Errorf("child parent = %q, want main", child.Parent) + } + + // Child PR should have been retargeted. + found := false + for _, n := range retargetCalls { + if n == 20 { + found = true + } + } + if !found { + t.Error("expected RetargetPR called for child PR #20") + } +} + +func TestSyncRebaseConflict(t *testing.T) { + mock, _ := setupTestEnv(t) + + if err := runTier(t, "new", "conflict-branch"); err != nil { + t.Fatalf("frond new: %v", err) + } + + // Mock rebase to return a conflict error. + mock.RebaseFn = func(_ context.Context, _, branch string) error { + return &driver.RebaseConflictError{Branch: branch, Detail: "CONFLICT in file.go"} + } + + err := runTier(t, "sync") + + // Should return ExitError with code 2. + if err == nil { + t.Fatal("expected error from sync with conflict") + } + exitErr, ok := err.(*ExitError) + if !ok { + t.Fatalf("expected *ExitError, got %T: %v", err, err) + } + if exitErr.Code != 2 { + t.Errorf("exit code = %d, want 2", exitErr.Code) + } +} + +func TestResolveDriverFromState(t *testing.T) { + // Empty driver resolves to native. + st := &state.State{Driver: ""} + drv, err := resolveDriver(st) + if err != nil { + t.Fatalf("resolveDriver empty: %v", err) + } + if drv.Name() != "native" { + t.Errorf("Name() = %q, want native", drv.Name()) + } + + // Unknown driver errors. + st = &state.State{Driver: "bogus"} + _, err = resolveDriver(st) + if err == nil { + t.Fatal("expected error for unknown driver in state") + } + + // driverOverride takes precedence. + mock := driver.NewMock() + driverOverride = mock + defer func() { driverOverride = nil }() + + st = &state.State{Driver: "bogus"} // would fail without override + drv, err = resolveDriver(st) + if err != nil { + t.Fatalf("resolveDriver with override: %v", err) + } + if drv.Name() != "mock" { + t.Errorf("Name() = %q, want mock", drv.Name()) + } +} diff --git a/cmd/helpers.go b/cmd/helpers.go index 20fa40b..c6a0ecd 100644 --- a/cmd/helpers.go +++ b/cmd/helpers.go @@ -6,9 +6,22 @@ import ( "unicode" "github.com/nvandessel/frond/internal/dag" + "github.com/nvandessel/frond/internal/driver" "github.com/nvandessel/frond/internal/state" ) +// driverOverride is nil in production; tests set it to inject a mock driver. +var driverOverride driver.Driver + +// resolveDriver returns the active driver. If driverOverride is set (tests), +// it is returned directly. Otherwise the driver is resolved from state. +func resolveDriver(st *state.State) (driver.Driver, error) { + if driverOverride != nil { + return driverOverride, nil + } + return driver.Resolve(st.Driver) +} + // validateBranchName checks that a branch name is safe to use with git commands. func validateBranchName(name string) error { if name == "" { diff --git a/cmd/init.go b/cmd/init.go new file mode 100644 index 0000000..ffd3395 --- /dev/null +++ b/cmd/init.go @@ -0,0 +1,73 @@ +package cmd + +import ( + "fmt" + + "github.com/nvandessel/frond/internal/driver" + "github.com/nvandessel/frond/internal/state" + "github.com/spf13/cobra" +) + +var initCmd = &cobra.Command{ + Use: "init", + Short: "Initialize frond state with an optional driver", + Example: ` # Initialize with the default native driver + frond init + + # Initialize with the Graphite driver + frond init --driver graphite`, + RunE: runInit, +} + +func init() { + initCmd.Flags().String("driver", "", "Driver to use: native (default), graphite") + rootCmd.AddCommand(initCmd) +} + +func runInit(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + driverName, _ := cmd.Flags().GetString("driver") + + // Validate the driver is known and its CLI is available. + drv, err := driver.Resolve(driverName) + if err != nil { + return err + } + + // Lock state. + unlock, err := state.Lock(ctx) + if err != nil { + return fmt.Errorf("acquiring lock: %w", err) + } + defer unlock() + + // ReadOrInit creates state if needed. + s, err := state.ReadOrInit(ctx) + if err != nil { + return fmt.Errorf("reading state: %w", err) + } + + // Set the driver if specified (or clear to native default). + if driverName == "native" { + driverName = "" + } + s.Driver = driverName + + if err := state.Write(ctx, s); err != nil { + return fmt.Errorf("writing state: %w", err) + } + + if jsonOut { + return printJSON(initResult{ + Driver: drv.Name(), + Trunk: s.Trunk, + }) + } + fmt.Printf("Initialized frond (driver: %s, trunk: %s)\n", drv.Name(), s.Trunk) + return nil +} + +type initResult struct { + Driver string `json:"driver"` + Trunk string `json:"trunk"` +} diff --git a/cmd/new.go b/cmd/new.go index c1a8814..47fdbce 100644 --- a/cmd/new.go +++ b/cmd/new.go @@ -4,7 +4,6 @@ import ( "fmt" "strings" - "github.com/nvandessel/frond/internal/git" "github.com/nvandessel/frond/internal/state" "github.com/spf13/cobra" ) @@ -51,8 +50,14 @@ func runNew(cmd *cobra.Command, args []string) error { return fmt.Errorf("reading state: %w", err) } + // 3. Resolve driver + drv, err := resolveDriver(s) + if err != nil { + return err + } + // Check if branch already exists in git - exists, err := git.BranchExists(ctx, name) + exists, err := drv.BranchExists(ctx, name) if err != nil { return fmt.Errorf("checking branch existence: %w", err) } @@ -60,13 +65,13 @@ func runNew(cmd *cobra.Command, args []string) error { return fmt.Errorf("branch '%s' already exists. Use 'frond track' to add it", name) } - // 3. Resolve parent: --on flag -> current branch if tracked -> trunk + // 4. Resolve parent: --on flag -> current branch if tracked -> trunk onFlag, _ := cmd.Flags().GetString("on") parent := s.Trunk if onFlag != "" { parent = onFlag } else { - current, err := git.CurrentBranch(ctx) + current, err := drv.CurrentBranch(ctx) if err == nil { if _, tracked := s.Branches[current]; tracked { parent = current @@ -74,15 +79,15 @@ func runNew(cmd *cobra.Command, args []string) error { } } - // 4. Parse --after + // 5. Parse --after afterFlag, _ := cmd.Flags().GetString("after") var after []string if afterFlag != "" { after = strings.Split(afterFlag, ",") } - // 5. Validate parent branch exists in git - parentExists, err := git.BranchExists(ctx, parent) + // 6. Validate parent branch exists in git + parentExists, err := drv.BranchExists(ctx, parent) if err != nil { return fmt.Errorf("checking parent branch: %w", err) } @@ -90,17 +95,17 @@ func runNew(cmd *cobra.Command, args []string) error { return fmt.Errorf("parent branch '%s' does not exist", parent) } - // 6. Validate --after deps and check for cycles + // 7. Validate --after deps and check for cycles if err := validateAfterDeps(s.Branches, name, after); err != nil { return err } - // 7. git.CreateBranch (also checks it out) - if err := git.CreateBranch(ctx, name, parent); err != nil { + // 8. Create branch (also checks it out) + if err := drv.CreateBranch(ctx, name, parent); err != nil { return fmt.Errorf("creating branch: %w", err) } - // 7. Write branch to state.Branches + // 9. Write branch to state.Branches if after == nil { after = []string{} } @@ -109,12 +114,12 @@ func runNew(cmd *cobra.Command, args []string) error { After: after, } - // 8. Write state + // 10. Write state if err := state.Write(ctx, s); err != nil { return fmt.Errorf("writing state: %w", err) } - // 9. Output + // 11. Output if jsonOut { return printJSON(newResult{ Name: name, diff --git a/cmd/push.go b/cmd/push.go index ad61f04..5cc616e 100644 --- a/cmd/push.go +++ b/cmd/push.go @@ -6,8 +6,7 @@ import ( "strings" "unicode" - "github.com/nvandessel/frond/internal/gh" - "github.com/nvandessel/frond/internal/git" + "github.com/nvandessel/frond/internal/driver" "github.com/nvandessel/frond/internal/state" "github.com/spf13/cobra" ) @@ -51,88 +50,73 @@ func humanizeTitle(branch string) string { func runPush(cmd *cobra.Command, args []string) error { ctx := cmd.Context() - // 1. Check gh is available. - if err := gh.Available(); err != nil { - return fmt.Errorf("gh CLI is required. Install: https://cli.github.com") - } - - // 2. Get current branch. - branch, err := git.CurrentBranch(ctx) - if err != nil { - return fmt.Errorf("getting current branch: %w", err) - } - - // 3. Lock state, defer unlock. + // 1. Lock state, defer unlock. unlock, err := state.Lock(ctx) if err != nil { return fmt.Errorf("acquiring lock: %w", err) } defer unlock() - // 4. Read state (not ReadOrInit). + // 2. Read state (not ReadOrInit). st, err := state.Read(ctx) if err != nil { return fmt.Errorf("reading state: %w", err) } + // 3. Resolve driver. + drv, err := resolveDriver(st) + if err != nil { + return err + } + + // 4. Get current branch. + branch, err := drv.CurrentBranch(ctx) + if err != nil { + return fmt.Errorf("getting current branch: %w", err) + } + // 5. Current branch must be tracked. br, ok := st.Branches[branch] if !ok { return fmt.Errorf("current branch '%s' is not tracked", branch) } - // 6. Push to origin. - if err := git.Push(ctx, branch); err != nil { - return fmt.Errorf("pushing to origin: %w", err) + // 6. Build push opts. + title, _ := cmd.Flags().GetString("title") + if title == "" { + title = humanizeTitle(branch) + } + body, _ := cmd.Flags().GetString("body") + draft, _ := cmd.Flags().GetBool("draft") + + opts := driver.PushOpts{ + Branch: branch, + Base: br.Parent, + Title: title, + Body: body, + Draft: draft, + ExistingPR: br.PR, } - created := false - var prNumber int - - // 7. If no PR exists, create one. - if br.PR == nil { - title, _ := cmd.Flags().GetString("title") - if title == "" { - title = humanizeTitle(branch) - } - body, _ := cmd.Flags().GetString("body") - draft, _ := cmd.Flags().GetBool("draft") - - prNumber, err = gh.PRCreate(ctx, gh.PRCreateOpts{ - Base: br.Parent, - Head: branch, - Title: title, - Body: body, - Draft: draft, - }) - if err != nil { - return fmt.Errorf("creating PR: %w", err) - } + // 7. Push (creates or updates PR). + result, err := drv.Push(ctx, opts) + if err != nil { + return fmt.Errorf("pushing: %w", err) + } - br.PR = &prNumber + // 8. Write PR number to state if created. + if result.Created { + br.PR = &result.PRNumber st.Branches[branch] = br if err := state.Write(ctx, st); err != nil { return fmt.Errorf("writing state: %w", err) } - created = true - } else { - // 8. PR exists — check if base needs retargeting. - prNumber = *br.PR - - info, err := gh.PRView(ctx, prNumber) - if err != nil { - return fmt.Errorf("viewing PR #%d: %w", prNumber, err) - } - - if info.BaseRefName != br.Parent { - if err := gh.PREdit(ctx, prNumber, br.Parent); err != nil { - return fmt.Errorf("retargeting PR #%d: %w", prNumber, err) - } - } } - // 9. Update stack comments on all PRs. - updateStackComments(ctx, st) + // 9. Update stack comments on all PRs (skip for drivers that manage their own). + if drv.SupportsStackComments() { + updateStackComments(ctx, st) + } // 10. Check for unmet --after deps: warn if any are still tracked. if len(br.After) > 0 { @@ -151,15 +135,15 @@ func runPush(cmd *cobra.Command, args []string) error { if jsonOut { return printJSON(pushResult{ Branch: branch, - PR: prNumber, - Created: created, + PR: result.PRNumber, + Created: result.Created, }) } action := "updated" - if created { + if result.Created { action = "created" } - fmt.Printf("Pushed %s. PR #%d [%s]\n", branch, prNumber, action) + fmt.Printf("Pushed %s. PR #%d [%s]\n", branch, result.PRNumber, action) return nil } diff --git a/cmd/status.go b/cmd/status.go index 142cf32..a4006aa 100644 --- a/cmd/status.go +++ b/cmd/status.go @@ -8,7 +8,7 @@ import ( "slices" "github.com/nvandessel/frond/internal/dag" - "github.com/nvandessel/frond/internal/gh" + "github.com/nvandessel/frond/internal/driver" "github.com/nvandessel/frond/internal/state" "github.com/spf13/cobra" ) @@ -67,10 +67,14 @@ func runStatus(cmd *cobra.Command, args []string) error { readinessMap[ri.Name] = ri } - // 5. If --fetch, get live PR states from GitHub. + // 5. If --fetch, get live PR states. prStates := make(map[string]string) if fetchFlag { - prStates = fetchPRStates(ctx, prNumbers) + drv, err := resolveDriver(s) + if err != nil { + return err + } + prStates = fetchPRStates(ctx, drv, prNumbers) } // 6. Output. @@ -80,20 +84,20 @@ func runStatus(cmd *cobra.Command, args []string) error { return outputHuman(s.Trunk, branches, prNumbers, readinessMap, prStates) } -// fetchPRStates calls gh.PRView for each branch that has a PR number. +// fetchPRStates calls drv.PRState for each branch that has a PR number. // On individual failures it warns to stderr and continues. -func fetchPRStates(ctx context.Context, prNumbers map[string]*int) map[string]string { +func fetchPRStates(ctx context.Context, drv driver.Driver, prNumbers map[string]*int) map[string]string { states := make(map[string]string) for name, pr := range prNumbers { if pr == nil { continue } - info, err := gh.PRView(ctx, *pr) + prState, err := drv.PRState(ctx, *pr) if err != nil { fmt.Fprintf(os.Stderr, "warning: failed to fetch PR #%d for %s: %v\n", *pr, name, err) continue } - states[name] = info.State + states[name] = prState } return states } diff --git a/cmd/sync.go b/cmd/sync.go index 9f2363c..bccbc0f 100644 --- a/cmd/sync.go +++ b/cmd/sync.go @@ -7,8 +7,7 @@ import ( "strings" "github.com/nvandessel/frond/internal/dag" - "github.com/nvandessel/frond/internal/gh" - "github.com/nvandessel/frond/internal/git" + "github.com/nvandessel/frond/internal/driver" "github.com/nvandessel/frond/internal/state" "github.com/spf13/cobra" ) @@ -69,13 +68,19 @@ func runSync(cmd *cobra.Command, args []string) error { return nil } + // Step 2b: Resolve driver. + drv, err := resolveDriver(st) + if err != nil { + return err + } + // Step 3: Fetch from origin. - if err := git.Fetch(ctx); err != nil { + if err := drv.Fetch(ctx); err != nil { return fmt.Errorf("fetching: %w", err) } // Save current branch before any operations so we can restore it. - originalBranch, err := git.CurrentBranch(ctx) + originalBranch, err := drv.CurrentBranch(ctx) if err != nil { return fmt.Errorf("getting current branch: %w", err) } @@ -90,12 +95,12 @@ func runSync(cmd *cobra.Command, args []string) error { if b.PR == nil { continue } - info, err := gh.PRView(ctx, *b.PR) + prState, err := drv.PRState(ctx, *b.PR) if err != nil { fmt.Fprintf(os.Stderr, "warning: could not check PR #%d for %s: %v\n", *b.PR, name, err) continue } - if info.State == gh.PRStateMerged { + if prState == driver.PRStateMerged { mergedBranches = append(mergedBranches, name) mergedData[name] = b } @@ -125,7 +130,7 @@ func runSync(cmd *cobra.Command, args []string) error { // 5b: Update child PRs to point to new parent. if childBranch.PR != nil { - if err := gh.PREdit(ctx, *childBranch.PR, mergedParent); err != nil { + if err := drv.RetargetPR(ctx, *childBranch.PR, mergedParent); err != nil { fmt.Fprintf(os.Stderr, "warning: could not retarget PR #%d for %s: %v\n", *childBranch.PR, childName, err) } } @@ -147,8 +152,9 @@ func runSync(cmd *cobra.Command, args []string) error { return fmt.Errorf("writing state: %w", err) } - // Step 5e: Update stack comments when merges changed the tree structure. - if len(mergedBranches) > 0 { + // Step 5e: Update stack comments when merges changed the tree structure + // (skip for drivers that manage their own stack visualization). + if len(mergedBranches) > 0 && drv.SupportsStackComments() { updateMergedComments(ctx, st, mergedData) updateStackComments(ctx, st) } @@ -181,8 +187,8 @@ func runSync(cmd *cobra.Command, args []string) error { ri := readinessMap[name] if ri.Ready { parent := st.Branches[name].Parent - if err := git.Rebase(ctx, parent, name); err != nil { - var conflictErr *git.RebaseConflictError + if err := drv.Rebase(ctx, parent, name); err != nil { + var conflictErr *driver.RebaseConflictError if errors.As(err, &conflictErr) { conflictBranch = name result.Conflicts = append(result.Conflicts, name) @@ -221,7 +227,7 @@ func runSync(cmd *cobra.Command, args []string) error { // Restore original branch after rebasing. if len(result.Rebased) > 0 || conflictBranch != "" { - if err := git.Checkout(ctx, originalBranch); err != nil { + if err := drv.Checkout(ctx, originalBranch); err != nil { fmt.Fprintf(os.Stderr, "warning: could not restore branch %s: %v\n", originalBranch, err) } } diff --git a/cmd/track.go b/cmd/track.go index 918fb85..433b8d1 100644 --- a/cmd/track.go +++ b/cmd/track.go @@ -4,7 +4,6 @@ import ( "fmt" "strings" - "github.com/nvandessel/frond/internal/git" "github.com/nvandessel/frond/internal/state" "github.com/spf13/cobra" ) @@ -49,8 +48,14 @@ func runTrack(cmd *cobra.Command, args []string) error { return fmt.Errorf("reading state: %w", err) } - // 3. Validate branch exists locally - exists, err := git.BranchExists(ctx, name) + // 3. Resolve driver + drv, err := resolveDriver(s) + if err != nil { + return err + } + + // 4. Validate branch exists locally + exists, err := drv.BranchExists(ctx, name) if err != nil { return fmt.Errorf("checking branch existence: %w", err) } @@ -63,12 +68,12 @@ func runTrack(cmd *cobra.Command, args []string) error { return fmt.Errorf("branch '%s' is already tracked", name) } - // 4. Validate --on branch exists (trunk or tracked) + // 5. Validate --on branch exists (trunk or tracked) onFlag, _ := cmd.Flags().GetString("on") if onFlag != s.Trunk { if _, tracked := s.Branches[onFlag]; !tracked { // Also check if branch exists in git at all - onExists, err := git.BranchExists(ctx, onFlag) + onExists, err := drv.BranchExists(ctx, onFlag) if err != nil { return fmt.Errorf("checking parent branch: %w", err) } @@ -80,19 +85,19 @@ func runTrack(cmd *cobra.Command, args []string) error { } parent := onFlag - // 5. Parse --after + // 6. Parse --after afterFlag, _ := cmd.Flags().GetString("after") var after []string if afterFlag != "" { after = strings.Split(afterFlag, ",") } - // 6. Validate --after deps and check for cycles + // 7. Validate --after deps and check for cycles if err := validateAfterDeps(s.Branches, name, after); err != nil { return err } - // 7. Add to state.Branches (no checkout, no git branch creation) + // 8. Add to state.Branches (no checkout, no git branch creation) if after == nil { after = []string{} } @@ -101,12 +106,12 @@ func runTrack(cmd *cobra.Command, args []string) error { After: after, } - // 8. Write state + // 9. Write state if err := state.Write(ctx, s); err != nil { return fmt.Errorf("writing state: %w", err) } - // 9. Output + // 10. Output if jsonOut { return printJSON(trackResult{ Name: name, diff --git a/cmd/untrack.go b/cmd/untrack.go index 96e11b9..235aa83 100644 --- a/cmd/untrack.go +++ b/cmd/untrack.go @@ -3,7 +3,6 @@ package cmd import ( "fmt" - "github.com/nvandessel/frond/internal/git" "github.com/nvandessel/frond/internal/state" "github.com/spf13/cobra" ) @@ -40,19 +39,25 @@ func runUntrack(cmd *cobra.Command, args []string) error { return fmt.Errorf("reading state: %w", err) } - // 3. Resolve branch: arg or current branch + // 3. Resolve driver (needed for CurrentBranch if no arg) + drv, err := resolveDriver(s) + if err != nil { + return err + } + + // 4. Resolve branch: arg or current branch var name string if len(args) > 0 { name = args[0] } else { - current, err := git.CurrentBranch(ctx) + current, err := drv.CurrentBranch(ctx) if err != nil { return fmt.Errorf("getting current branch: %w", err) } name = current } - // 4. Must be tracked + // 5. Must be tracked branch, tracked := s.Branches[name] if !tracked { return fmt.Errorf("branch '%s' is not tracked", name) @@ -60,11 +65,11 @@ func runUntrack(cmd *cobra.Command, args []string) error { removedParent := branch.Parent - // 5. Remove from state.Branches + // 6. Remove from state.Branches delete(s.Branches, name) - // 6. Remove from ALL other branches' after lists - // 7. Reparent children: any branch whose parent was this branch -> set parent to this branch's parent + // 7. Remove from ALL other branches' after lists + // 8. Reparent children: any branch whose parent was this branch -> set parent to this branch's parent var reparented []string var unblocked []string @@ -93,12 +98,12 @@ func runUntrack(cmd *cobra.Command, args []string) error { s.Branches[bName] = b } - // 8. Write state + // 9. Write state if err := state.Write(ctx, s); err != nil { return fmt.Errorf("writing state: %w", err) } - // 9. Output + // 10. Output if jsonOut { if reparented == nil { reparented = []string{} diff --git a/internal/driver/driver.go b/internal/driver/driver.go new file mode 100644 index 0000000..0f6f4e4 --- /dev/null +++ b/internal/driver/driver.go @@ -0,0 +1,82 @@ +// Package driver defines the interface for branch/PR/git operations. +// Frond delegates all external CLI interactions through a Driver so that +// different stacking tools (native git+gh, Graphite, etc.) can be used +// interchangeably while frond manages the DAG layer. +package driver + +import ( + "context" + "fmt" +) + +// Driver abstracts branch creation, pushing, rebasing, and PR management. +type Driver interface { + Name() string + + // Git queries + CurrentBranch(ctx context.Context) (string, error) + BranchExists(ctx context.Context, name string) (bool, error) + Checkout(ctx context.Context, name string) error + + // Branch mutation + CreateBranch(ctx context.Context, name, parent string) error + + // Remote + PR + Fetch(ctx context.Context) error + Push(ctx context.Context, opts PushOpts) (*PushResult, error) + Rebase(ctx context.Context, onto, branch string) error + PRState(ctx context.Context, prNumber int) (string, error) + RetargetPR(ctx context.Context, prNumber int, newBase string) error + + // SupportsStackComments reports whether frond should post/update + // stack comments on PRs. Drivers like Graphite manage their own + // stack visualization, so frond skips comment management for them. + SupportsStackComments() bool +} + +// PushOpts configures a push + PR create/update operation. +type PushOpts struct { + Branch string // branch to push + Base string // desired PR base branch + Title string + Body string + Draft bool + // ExistingPR is nil for new PRs; non-nil to push + retarget an existing PR. + ExistingPR *int +} + +// PushResult is returned after a successful push. +type PushResult struct { + PRNumber int + Created bool +} + +// RebaseConflictError is returned when a rebase fails due to merge conflicts. +type RebaseConflictError struct { + Branch string + Detail string +} + +func (e *RebaseConflictError) Error() string { + return fmt.Sprintf("rebase conflict on branch %s: %s", e.Branch, e.Detail) +} + +// PR state constants returned by PRState. +const ( + PRStateOpen = "OPEN" + PRStateClosed = "CLOSED" + PRStateMerged = "MERGED" +) + +// Resolve returns the Driver for the given driver name. +// An empty name resolves to the native (git+gh) driver. +func Resolve(name string) (Driver, error) { + switch name { + case "", "native": + return NewNative() + case "graphite": + return NewGraphite() + default: + return nil, fmt.Errorf("unknown driver %q (supported: native, graphite)", name) + } +} diff --git a/internal/driver/driver_test.go b/internal/driver/driver_test.go new file mode 100644 index 0000000..f4e11c3 --- /dev/null +++ b/internal/driver/driver_test.go @@ -0,0 +1,161 @@ +package driver + +import ( + "context" + "os/exec" + "testing" +) + +func TestResolveNative(t *testing.T) { + _, err := exec.LookPath("gh") + if err != nil { + // gh not installed — Resolve should fail with a descriptive error. + _, resolveErr := Resolve("") + if resolveErr == nil { + t.Fatal("Resolve('') should fail when gh is not installed") + } + t.Logf("gh not installed, Resolve error: %v (expected)", resolveErr) + return + } + + drv, err := Resolve("") + if err != nil { + t.Fatalf("Resolve empty: %v", err) + } + if drv.Name() != "native" { + t.Errorf("Name() = %q, want %q", drv.Name(), "native") + } + + drv, err = Resolve("native") + if err != nil { + t.Fatalf("Resolve native: %v", err) + } + if drv.Name() != "native" { + t.Errorf("Name() = %q, want %q", drv.Name(), "native") + } +} + +func TestResolveGraphite(t *testing.T) { + _, err := exec.LookPath("gt") + if err != nil { + // gt not installed — Resolve should fail with a descriptive error. + _, resolveErr := Resolve("graphite") + if resolveErr == nil { + t.Fatal("Resolve(graphite) should fail when gt is not installed") + } + t.Logf("gt not installed, Resolve error: %v (expected)", resolveErr) + return + } + + // gt is installed — Resolve should succeed. + drv, err := Resolve("graphite") + if err != nil { + t.Fatalf("Resolve graphite: %v", err) + } + if drv.Name() != "graphite" { + t.Errorf("Name() = %q, want %q", drv.Name(), "graphite") + } +} + +func TestResolveUnknown(t *testing.T) { + _, err := Resolve("bogus") + if err == nil { + t.Fatal("expected error for unknown driver") + } +} + +func TestMockBasicFlow(t *testing.T) { + ctx := context.Background() + m := NewMock() + + // Initial state. + br, _ := m.CurrentBranch(ctx) + if br != "main" { + t.Errorf("initial branch = %q, want main", br) + } + + // Create branch. + if err := m.CreateBranch(ctx, "feature", "main"); err != nil { + t.Fatalf("CreateBranch: %v", err) + } + br, _ = m.CurrentBranch(ctx) + if br != "feature" { + t.Errorf("after create, branch = %q, want feature", br) + } + exists, _ := m.BranchExists(ctx, "feature") + if !exists { + t.Error("feature should exist") + } + + // Checkout. + if err := m.Checkout(ctx, "main"); err != nil { + t.Fatalf("Checkout: %v", err) + } + br, _ = m.CurrentBranch(ctx) + if br != "main" { + t.Errorf("after checkout, branch = %q, want main", br) + } + + // Push (default). + result, err := m.Push(ctx, PushOpts{Branch: "feature", Base: "main"}) + if err != nil { + t.Fatalf("Push: %v", err) + } + if !result.Created { + t.Error("expected Created=true for new PR") + } + + // Push with ExistingPR. + pr := 42 + result, err = m.Push(ctx, PushOpts{Branch: "feature", Base: "main", ExistingPR: &pr}) + if err != nil { + t.Fatalf("Push existing: %v", err) + } + if result.Created { + t.Error("expected Created=false for existing PR") + } + + // Fetch, Rebase, PRState, RetargetPR — defaults are no-ops. + if err := m.Fetch(ctx); err != nil { + t.Fatalf("Fetch: %v", err) + } + if err := m.Rebase(ctx, "main", "feature"); err != nil { + t.Fatalf("Rebase: %v", err) + } + state, err := m.PRState(ctx, 42) + if err != nil { + t.Fatalf("PRState: %v", err) + } + if state != "OPEN" { + t.Errorf("PRState = %q, want OPEN", state) + } + if err := m.RetargetPR(ctx, 42, "main"); err != nil { + t.Fatalf("RetargetPR: %v", err) + } +} + +func TestMockOverrides(t *testing.T) { + ctx := context.Background() + m := NewMock() + + fetchCalled := false + m.FetchFn = func(_ context.Context) error { + fetchCalled = true + return nil + } + + if err := m.Fetch(ctx); err != nil { + t.Fatal(err) + } + if !fetchCalled { + t.Error("FetchFn not called") + } +} + +func TestRebaseConflictError(t *testing.T) { + e := &RebaseConflictError{Branch: "feat", Detail: "CONFLICT in file.go"} + got := e.Error() + if got != "rebase conflict on branch feat: CONFLICT in file.go" { + t.Errorf("Error() = %q", got) + } +} diff --git a/internal/driver/graphite.go b/internal/driver/graphite.go new file mode 100644 index 0000000..0955e06 --- /dev/null +++ b/internal/driver/graphite.go @@ -0,0 +1,125 @@ +package driver + +import ( + "context" + "fmt" + "os/exec" + "regexp" + "strconv" + "strings" +) + +// submitLineRe matches gt submit output lines: ": (created|updated)" +var submitLineRe = regexp.MustCompile(`^(\S+):\s+(https://\S+)\s+\((created|updated)\)$`) + +// Graphite delegates stacking operations to the Graphite CLI (gt). +// It embeds Native and overrides CreateBranch, Push, and Rebase. +type Graphite struct { + Native +} + +// NewGraphite validates that gt is installed and returns a Graphite driver. +func NewGraphite() (*Graphite, error) { + if _, err := exec.LookPath("gt"); err != nil { + return nil, fmt.Errorf("graphite CLI (gt) not found. Install: https://graphite.dev/docs/installing-the-cli") + } + return &Graphite{}, nil +} + +func (g *Graphite) Name() string { return "graphite" } +func (g *Graphite) SupportsStackComments() bool { return false } + +func (g *Graphite) CreateBranch(ctx context.Context, name, parent string) error { + // Checkout parent first, then use gt create. + if err := g.Checkout(ctx, parent); err != nil { + return fmt.Errorf("checking out parent %s: %w", parent, err) + } + out, err := runGT(ctx, "create", name) + if err != nil { + return fmt.Errorf("gt create %s: %s: %w", name, out, err) + } + return nil +} + +func (g *Graphite) Push(ctx context.Context, opts PushOpts) (*PushResult, error) { + args := []string{"submit", "--no-interactive", "--no-edit"} + if opts.Draft { + args = append(args, "--draft") + } + if opts.Title != "" && opts.ExistingPR == nil { + args = append(args, "--title", opts.Title) + } + if opts.Body != "" && opts.ExistingPR == nil { + args = append(args, "--description", opts.Body) + } + + out, err := runGT(ctx, args...) + if err != nil { + return nil, fmt.Errorf("gt submit: %s: %w", out, err) + } + + // For existing PRs, return the existing number. + if opts.ExistingPR != nil { + return &PushResult{PRNumber: *opts.ExistingPR, Created: false}, nil + } + + // Parse PR number and created/updated status from gt submit output. + prNum, created, err := parseSubmitResult(out, opts.Branch) + if err != nil { + return nil, fmt.Errorf("parsing PR number from gt submit output: %w", err) + } + return &PushResult{PRNumber: prNum, Created: created}, nil +} + +func (g *Graphite) Rebase(ctx context.Context, _, _ string) error { + // gt restack handles the entire stack; called per-branch in topo loop + // but is idempotent so repeated calls are safe. + out, err := runGT(ctx, "restack") + if err != nil { + if strings.Contains(out, "CONFLICT") || strings.Contains(out, "could not apply") { + return &RebaseConflictError{ + Branch: "stack", + Detail: out, + } + } + return fmt.Errorf("gt restack: %s: %w", out, err) + } + return nil +} + +// runGT executes a gt command and returns combined stdout/stderr. +func runGT(ctx context.Context, args ...string) (string, error) { + cmd := exec.CommandContext(ctx, "gt", args...) + var out strings.Builder + cmd.Stdout = &out + cmd.Stderr = &out + err := cmd.Run() + return strings.TrimSpace(out.String()), err +} + +// parseSubmitResult extracts the PR number and created/updated status for +// branch from gt submit output. +// gt submit prints one line per branch: ": (created|updated)" +func parseSubmitResult(output, branch string) (prNumber int, created bool, err error) { + for _, line := range strings.Split(output, "\n") { + line = strings.TrimSpace(line) + matches := submitLineRe.FindStringSubmatch(line) + if matches == nil { + continue + } + if matches[1] != branch { + continue + } + url := matches[2] + idx := strings.LastIndex(url, "/") + if idx == -1 || idx == len(url)-1 { + return 0, false, fmt.Errorf("malformed PR URL %q: no trailing number", url) + } + num, parseErr := strconv.Atoi(url[idx+1:]) + if parseErr != nil { + return 0, false, fmt.Errorf("malformed PR URL %q: %w", url, parseErr) + } + return num, matches[3] == "created", nil + } + return 0, false, fmt.Errorf("branch %q not found in gt submit output:\n%s", branch, output) +} diff --git a/internal/driver/graphite_test.go b/internal/driver/graphite_test.go new file mode 100644 index 0000000..d8da5d3 --- /dev/null +++ b/internal/driver/graphite_test.go @@ -0,0 +1,372 @@ +package driver + +import ( + "context" + "errors" + "fmt" + "os" + "os/exec" + "path/filepath" + "strings" + "testing" +) + +// fakeBinDir holds the directory containing the built fakegt binary. +// Set by TestMain before any tests run. +var fakeBinDir string + +func TestMain(m *testing.M) { + // Build fakegt and prepend its directory to PATH so that + // exec.LookPath("gt") and runGT find our test double. + dir, err := os.MkdirTemp("", "fakegt-*") + if err != nil { + fmt.Fprintf(os.Stderr, "creating temp dir: %v\n", err) + os.Exit(1) + } + defer os.RemoveAll(dir) + + gtBin := filepath.Join(dir, "gt") + cmd := exec.Command("go", "build", "-o", gtBin, "./testdata/fakegt") + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + fmt.Fprintf(os.Stderr, "building fakegt: %v\n", err) + os.Exit(1) + } + + fakeBinDir = dir + os.Setenv("PATH", dir+":"+os.Getenv("PATH")) + + os.Exit(m.Run()) +} + +// initGitRepo creates a temp git repo with an initial commit, chdir's into it, +// and restores the original directory on cleanup. This is needed because the +// git package operates on the current working directory. +func initGitRepo(t *testing.T) (dir string, ctx context.Context) { + t.Helper() + dir = t.TempDir() + ctx = context.Background() + + gitEnv := []string{ + "GIT_AUTHOR_NAME=Test User", + "GIT_AUTHOR_EMAIL=test@example.com", + "GIT_COMMITTER_NAME=Test User", + "GIT_COMMITTER_EMAIL=test@example.com", + "GIT_CONFIG_NOSYSTEM=1", + "HOME=" + dir, + } + + gitCmd := func(args ...string) { + t.Helper() + cmd := exec.Command("git", args...) + cmd.Dir = dir + cmd.Env = append(os.Environ(), gitEnv...) + out, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("setup git %s: %s\n%s", strings.Join(args, " "), err, out) + } + } + + gitCmd("init", "-b", "main") + gitCmd("commit", "--allow-empty", "-m", "init") + + origDir, err := os.Getwd() + if err != nil { + t.Fatal(err) + } + if err := os.Chdir(dir); err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = os.Chdir(origDir) }) + + for _, e := range gitEnv { + parts := strings.SplitN(e, "=", 2) + t.Setenv(parts[0], parts[1]) + } + + return dir, ctx +} + +// --- Unit tests for parseSubmitResult --- + +func TestParseSubmitResult(t *testing.T) { + tests := []struct { + name string + output string + branch string + wantPR int + wantCreated bool + wantErr string + }{ + { + name: "created on .com domain", + output: "my-feature: https://app.graphite.com/github/pr/owner/repo/42 (created)", + branch: "my-feature", + wantPR: 42, + wantCreated: true, + }, + { + name: "updated on .dev domain", + output: "my-feature: https://app.graphite.dev/github/pr/owner/repo/99 (updated)", + branch: "my-feature", + wantPR: 99, + wantCreated: false, + }, + { + name: "multi-branch stack matches correct branch", + output: `pp--06-14-part_1: https://app.graphite.com/github/pr/withgraphite/repo/100 (created) +pp--06-14-part_2: https://app.graphite.com/github/pr/withgraphite/repo/101 (created) +pp--06-14-part_3: https://app.graphite.com/github/pr/withgraphite/repo/102 (created)`, + branch: "pp--06-14-part_2", + wantPR: 101, + wantCreated: true, + }, + { + name: "branch not found", + output: "other-branch: https://app.graphite.com/github/pr/owner/repo/42 (created)", + branch: "my-feature", + wantErr: `branch "my-feature" not found in gt submit output`, + }, + { + name: "malformed URL no trailing number", + output: "my-feature: https://app.graphite.com/github/pr/owner/repo/ (created)", + branch: "my-feature", + wantErr: "malformed PR URL", + }, + { + name: "empty output", + output: "", + branch: "my-feature", + wantErr: `branch "my-feature" not found in gt submit output`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotPR, gotCreated, err := parseSubmitResult(tt.output, tt.branch) + if tt.wantErr != "" { + if err == nil { + t.Fatalf("expected error containing %q, got nil", tt.wantErr) + } + if !strings.Contains(err.Error(), tt.wantErr) { + t.Errorf("error = %q, want containing %q", err.Error(), tt.wantErr) + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if gotPR != tt.wantPR { + t.Errorf("prNumber = %d, want %d", gotPR, tt.wantPR) + } + if gotCreated != tt.wantCreated { + t.Errorf("created = %v, want %v", gotCreated, tt.wantCreated) + } + }) + } +} + +// --- Integration tests using fakegt + temp git repos --- + +func TestNewGraphiteWithFakeGT(t *testing.T) { + // fakegt is on PATH via TestMain, so NewGraphite should succeed. + g, err := NewGraphite() + if err != nil { + t.Fatalf("NewGraphite() with fakegt on PATH: %v", err) + } + if g.Name() != "graphite" { + t.Errorf("Name() = %q, want %q", g.Name(), "graphite") + } +} + +func TestGraphiteCreateBranch(t *testing.T) { + _, ctx := initGitRepo(t) + + g := &Graphite{} + + // CreateBranch checks out the parent (real git) then calls gt create (fakegt). + if err := g.CreateBranch(ctx, "my-feature", "main"); err != nil { + t.Fatalf("CreateBranch: %v", err) + } +} + +func TestGraphitePush(t *testing.T) { + tests := []struct { + name string + submitOut string + branch string + opts PushOpts + wantPR int + wantCreated bool + wantErr string + }{ + { + name: "new PR created", + submitOut: "feat-a: https://app.graphite.com/github/pr/owner/repo/77 (created)", + branch: "feat-a", + opts: PushOpts{ + Branch: "feat-a", + Base: "main", + Title: "Add feature A", + }, + wantPR: 77, + wantCreated: true, + }, + { + name: "existing PR updated by gt", + submitOut: "feat-b: https://app.graphite.com/github/pr/owner/repo/88 (updated)", + branch: "feat-b", + opts: PushOpts{ + Branch: "feat-b", + Base: "main", + }, + wantPR: 88, + wantCreated: false, + }, + { + name: "existing PR via ExistingPR field", + submitOut: "feat-c: https://app.graphite.com/github/pr/owner/repo/99 (updated)", + branch: "feat-c", + opts: PushOpts{ + Branch: "feat-c", + Base: "main", + ExistingPR: intPtr(55), + }, + wantPR: 55, + wantCreated: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Setenv("FAKEGT_SUBMIT_OUTPUT", tt.submitOut) + ctx := context.Background() + g := &Graphite{} + + result, err := g.Push(ctx, tt.opts) + if tt.wantErr != "" { + if err == nil { + t.Fatalf("expected error containing %q, got nil", tt.wantErr) + } + if !strings.Contains(err.Error(), tt.wantErr) { + t.Errorf("error = %q, want containing %q", err.Error(), tt.wantErr) + } + return + } + if err != nil { + t.Fatalf("Push: %v", err) + } + if result.PRNumber != tt.wantPR { + t.Errorf("PRNumber = %d, want %d", result.PRNumber, tt.wantPR) + } + if result.Created != tt.wantCreated { + t.Errorf("Created = %v, want %v", result.Created, tt.wantCreated) + } + }) + } +} + +func TestGraphitePushPassesBodyAndTitle(t *testing.T) { + // Verify that --title and --description flags are passed to gt submit. + recordFile := filepath.Join(t.TempDir(), "record.txt") + t.Setenv("FAKEGT_RECORD", recordFile) + t.Setenv("FAKEGT_SUBMIT_OUTPUT", "my-feat: https://app.graphite.com/github/pr/o/r/1 (created)") + + ctx := context.Background() + g := &Graphite{} + + _, err := g.Push(ctx, PushOpts{ + Branch: "my-feat", + Base: "main", + Title: "My title", + Body: "My description", + }) + if err != nil { + t.Fatalf("Push: %v", err) + } + + recorded, err := os.ReadFile(recordFile) + if err != nil { + t.Fatalf("reading record file: %v", err) + } + args := string(recorded) + + if !strings.Contains(args, "--title My title") { + t.Errorf("expected --title in args, got: %s", args) + } + if !strings.Contains(args, "--description My description") { + t.Errorf("expected --description in args, got: %s", args) + } +} + +func TestGraphitePushDraft(t *testing.T) { + recordFile := filepath.Join(t.TempDir(), "record.txt") + t.Setenv("FAKEGT_RECORD", recordFile) + t.Setenv("FAKEGT_SUBMIT_OUTPUT", "my-feat: https://app.graphite.com/github/pr/o/r/1 (created)") + + ctx := context.Background() + g := &Graphite{} + + _, err := g.Push(ctx, PushOpts{ + Branch: "my-feat", + Base: "main", + Title: "Draft PR", + Draft: true, + }) + if err != nil { + t.Fatalf("Push: %v", err) + } + + recorded, err := os.ReadFile(recordFile) + if err != nil { + t.Fatalf("reading record file: %v", err) + } + if !strings.Contains(string(recorded), "--draft") { + t.Errorf("expected --draft in args, got: %s", recorded) + } +} + +func TestGraphitePushFailure(t *testing.T) { + t.Setenv("FAKEGT_FAIL", "1") + ctx := context.Background() + g := &Graphite{} + + _, err := g.Push(ctx, PushOpts{Branch: "feat", Base: "main"}) + if err == nil { + t.Fatal("expected error when gt submit fails") + } + if !strings.Contains(err.Error(), "gt submit") { + t.Errorf("error = %q, want containing 'gt submit'", err.Error()) + } +} + +func TestGraphiteRebase(t *testing.T) { + ctx := context.Background() + g := &Graphite{} + + // Normal restack succeeds. + if err := g.Rebase(ctx, "main", "feature"); err != nil { + t.Fatalf("Rebase: %v", err) + } +} + +func TestGraphiteRebaseConflict(t *testing.T) { + t.Setenv("FAKEGT_CONFLICT", "1") + ctx := context.Background() + g := &Graphite{} + + err := g.Rebase(ctx, "main", "feature") + if err == nil { + t.Fatal("expected error on conflict") + } + + var conflictErr *RebaseConflictError + if !errors.As(err, &conflictErr) { + t.Fatalf("expected RebaseConflictError, got %T: %v", err, err) + } + if !strings.Contains(conflictErr.Detail, "CONFLICT") { + t.Errorf("Detail = %q, want containing 'CONFLICT'", conflictErr.Detail) + } +} + +func intPtr(n int) *int { return &n } diff --git a/internal/driver/mock.go b/internal/driver/mock.go new file mode 100644 index 0000000..97f5719 --- /dev/null +++ b/internal/driver/mock.go @@ -0,0 +1,90 @@ +package driver + +import ( + "context" + "fmt" +) + +// Mock is a stateful in-memory driver for testing. +// It tracks branches and current branch so multi-step tests work without git. +type Mock struct { + Branches map[string]bool + CurrentBranchName string + StackComments bool // whether SupportsStackComments() returns true + + // Override hooks — nil means use default behavior. + FetchFn func(ctx context.Context) error + PushFn func(ctx context.Context, opts PushOpts) (*PushResult, error) + RebaseFn func(ctx context.Context, onto, branch string) error + PRStateFn func(ctx context.Context, prNumber int) (string, error) + RetargetPRFn func(ctx context.Context, prNumber int, newBase string) error +} + +// NewMock returns a Mock with "main" as the only branch and current branch. +func NewMock() *Mock { + return &Mock{ + Branches: map[string]bool{"main": true}, + CurrentBranchName: "main", + } +} + +func (m *Mock) Name() string { return "mock" } + +func (m *Mock) CurrentBranch(_ context.Context) (string, error) { + return m.CurrentBranchName, nil +} + +func (m *Mock) BranchExists(_ context.Context, name string) (bool, error) { + return m.Branches[name], nil +} + +func (m *Mock) CreateBranch(_ context.Context, name, _ string) error { + m.Branches[name] = true + m.CurrentBranchName = name + return nil +} + +func (m *Mock) Checkout(_ context.Context, name string) error { + if !m.Branches[name] { + return fmt.Errorf("branch %q does not exist", name) + } + m.CurrentBranchName = name + return nil +} + +func (m *Mock) Fetch(ctx context.Context) error { + if m.FetchFn != nil { + return m.FetchFn(ctx) + } + return nil +} + +func (m *Mock) Push(ctx context.Context, opts PushOpts) (*PushResult, error) { + if m.PushFn != nil { + return m.PushFn(ctx, opts) + } + return &PushResult{PRNumber: 1, Created: opts.ExistingPR == nil}, nil +} + +func (m *Mock) Rebase(ctx context.Context, onto, branch string) error { + if m.RebaseFn != nil { + return m.RebaseFn(ctx, onto, branch) + } + return nil +} + +func (m *Mock) PRState(ctx context.Context, prNumber int) (string, error) { + if m.PRStateFn != nil { + return m.PRStateFn(ctx, prNumber) + } + return "OPEN", nil +} + +func (m *Mock) RetargetPR(ctx context.Context, prNumber int, newBase string) error { + if m.RetargetPRFn != nil { + return m.RetargetPRFn(ctx, prNumber, newBase) + } + return nil +} + +func (m *Mock) SupportsStackComments() bool { return m.StackComments } diff --git a/internal/driver/native.go b/internal/driver/native.go new file mode 100644 index 0000000..ea35f93 --- /dev/null +++ b/internal/driver/native.go @@ -0,0 +1,102 @@ +package driver + +import ( + "context" + "errors" + "fmt" + + "github.com/nvandessel/frond/internal/gh" + "github.com/nvandessel/frond/internal/git" +) + +// Native is the default driver using git + gh CLIs directly. +type Native struct{} + +// NewNative validates that gh is installed and returns a Native driver. +func NewNative() (*Native, error) { + if err := gh.Available(); err != nil { + return nil, err + } + return &Native{}, nil +} + +func (n *Native) Name() string { return "native" } + +func (n *Native) CurrentBranch(ctx context.Context) (string, error) { + return git.CurrentBranch(ctx) +} + +func (n *Native) BranchExists(ctx context.Context, name string) (bool, error) { + return git.BranchExists(ctx, name) +} + +func (n *Native) Checkout(ctx context.Context, name string) error { + return git.Checkout(ctx, name) +} + +func (n *Native) CreateBranch(ctx context.Context, name, parent string) error { + return git.CreateBranch(ctx, name, parent) +} + +func (n *Native) Fetch(ctx context.Context) error { + return git.Fetch(ctx) +} + +func (n *Native) Push(ctx context.Context, opts PushOpts) (*PushResult, error) { + // Push the branch to origin. + if err := git.Push(ctx, opts.Branch); err != nil { + return nil, fmt.Errorf("pushing %s: %w", opts.Branch, err) + } + + if opts.ExistingPR != nil { + // Existing PR — check if base needs retargeting. + info, err := gh.PRView(ctx, *opts.ExistingPR) + if err != nil { + return nil, fmt.Errorf("viewing PR #%d: %w", *opts.ExistingPR, err) + } + if info.BaseRefName != opts.Base { + if err := gh.PREdit(ctx, *opts.ExistingPR, opts.Base); err != nil { + return nil, fmt.Errorf("retargeting PR #%d: %w", *opts.ExistingPR, err) + } + } + return &PushResult{PRNumber: *opts.ExistingPR, Created: false}, nil + } + + // New PR — create it. + prNum, err := gh.PRCreate(ctx, gh.PRCreateOpts{ + Base: opts.Base, + Head: opts.Branch, + Title: opts.Title, + Body: opts.Body, + Draft: opts.Draft, + }) + if err != nil { + return nil, fmt.Errorf("creating PR: %w", err) + } + return &PushResult{PRNumber: prNum, Created: true}, nil +} + +func (n *Native) Rebase(ctx context.Context, onto, branch string) error { + err := git.Rebase(ctx, onto, branch) + if err != nil { + var conflictErr *git.RebaseConflictError + if errors.As(err, &conflictErr) { + return &RebaseConflictError{ + Branch: conflictErr.Branch, + Detail: conflictErr.Stderr, + } + } + return err + } + return nil +} + +func (n *Native) PRState(ctx context.Context, prNumber int) (string, error) { + return gh.PRState(ctx, prNumber) +} + +func (n *Native) RetargetPR(ctx context.Context, prNumber int, newBase string) error { + return gh.PREdit(ctx, prNumber, newBase) +} + +func (n *Native) SupportsStackComments() bool { return true } diff --git a/internal/driver/testdata/fakegt/main.go b/internal/driver/testdata/fakegt/main.go new file mode 100644 index 0000000..fbeb927 --- /dev/null +++ b/internal/driver/testdata/fakegt/main.go @@ -0,0 +1,63 @@ +// Command fakegt is a test double for the Graphite CLI (gt). +// Behavior is controlled via environment variables: +// +// - FAKEGT_FAIL: if set, exit 1 with error message +// - FAKEGT_CONFLICT: if set, exit 1 with CONFLICT output (for restack) +// - FAKEGT_SUBMIT_OUTPUT: custom stdout for "submit" command +// - FAKEGT_RECORD: if set to a file path, append each invocation's args +package main + +import ( + "fmt" + "os" + "strings" +) + +func main() { + args := os.Args[1:] + + // Record invocations for test assertions. + if recordFile := os.Getenv("FAKEGT_RECORD"); recordFile != "" { + f, err := os.OpenFile(recordFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644) + if err == nil { + fmt.Fprintln(f, strings.Join(args, " ")) + f.Close() + } + } + + // Unconditional failure mode. + if os.Getenv("FAKEGT_FAIL") != "" { + fmt.Fprintln(os.Stderr, "fatal: something went wrong") + os.Exit(1) + } + + if len(args) == 0 { + os.Exit(0) + } + + switch args[0] { + case "create": + // gt create — no output on success. + case "submit": + // Conflict mode for submit. + if os.Getenv("FAKEGT_CONFLICT") != "" { + fmt.Fprintln(os.Stderr, "CONFLICT (content): Merge conflict in file.go") + os.Exit(1) + } + // Custom output or default. + if out := os.Getenv("FAKEGT_SUBMIT_OUTPUT"); out != "" { + fmt.Println(out) + } else { + fmt.Println("default-branch: https://app.graphite.com/github/pr/owner/repo/1 (created)") + } + case "restack": + if os.Getenv("FAKEGT_CONFLICT") != "" { + fmt.Println("CONFLICT (content): Merge conflict in file.go") + fmt.Fprintln(os.Stderr, "could not apply abc1234... commit message") + os.Exit(1) + } + fmt.Println("Restacked") + default: + // Unknown commands succeed silently. + } +} diff --git a/internal/state/state.go b/internal/state/state.go index 4ed3af2..a5de908 100644 --- a/internal/state/state.go +++ b/internal/state/state.go @@ -26,6 +26,7 @@ type Branch struct { type State struct { Version int `json:"version"` Trunk string `json:"trunk"` + Driver string `json:"driver,omitempty"` Branches map[string]Branch `json:"branches"` } @@ -41,8 +42,8 @@ const ( stateVersion = 1 ) -// gitCommonDir is a package-level variable so tests can override it. -var gitCommonDir = func(ctx context.Context) (string, error) { +// GitCommonDir is a package-level variable so tests can override it. +var GitCommonDir = func(ctx context.Context) (string, error) { dir, err := git.CommonDir(ctx) if err != nil { return "", err @@ -56,7 +57,7 @@ var gitCommonDir = func(ctx context.Context) (string, error) { // Path returns the absolute path to frond.json. func Path(ctx context.Context) (string, error) { - dir, err := gitCommonDir(ctx) + dir, err := GitCommonDir(ctx) if err != nil { return "", err } @@ -136,7 +137,7 @@ func Write(ctx context.Context, s *State) error { // if err != nil { ... } // defer unlock() func Lock(ctx context.Context) (unlock func(), err error) { - dir, err := gitCommonDir(ctx) + dir, err := GitCommonDir(ctx) if err != nil { return noop, err } diff --git a/internal/state/state_test.go b/internal/state/state_test.go index 35c4016..3c71a13 100644 --- a/internal/state/state_test.go +++ b/internal/state/state_test.go @@ -15,8 +15,8 @@ import ( ) // setupGitRepo creates a minimal git repo in a temp dir and overrides -// gitCommonDir to point there. It returns the git-common-dir path and a -// cleanup function that restores the original gitCommonDir. +// GitCommonDir to point there. It returns the git-common-dir path and a +// cleanup function that restores the original GitCommonDir. func setupGitRepo(t *testing.T) (dir string) { t.Helper() @@ -37,13 +37,13 @@ func setupGitRepo(t *testing.T) (dir string) { gitDir := filepath.Join(dir, ".git") - // Override the package-level gitCommonDir so all functions in this + // Override the package-level GitCommonDir so all functions in this // package resolve paths inside our temp repo. - orig := gitCommonDir - gitCommonDir = func(_ context.Context) (string, error) { + orig := GitCommonDir + GitCommonDir = func(_ context.Context) (string, error) { return gitDir, nil } - t.Cleanup(func() { gitCommonDir = orig }) + t.Cleanup(func() { GitCommonDir = orig }) return dir } @@ -285,12 +285,12 @@ func TestReadOrInitMasterBranch(t *testing.T) { run(t, dir, "git", "branch", "-M", "master") // Override detectTrunk's git commands to run inside our temp dir. - origGitCommonDir := gitCommonDir + origGitCommonDir := GitCommonDir gitDir := filepath.Join(dir, ".git") - gitCommonDir = func(_ context.Context) (string, error) { + GitCommonDir = func(_ context.Context) (string, error) { return gitDir, nil } - t.Cleanup(func() { gitCommonDir = origGitCommonDir }) + t.Cleanup(func() { GitCommonDir = origGitCommonDir }) // We need detectTrunk to run git commands in the right repo. // Override the PATH-relative git commands by changing to the dir. @@ -372,11 +372,11 @@ func TestWriteCreatesParentDirs(t *testing.T) { tmpDir := t.TempDir() nestedDir := filepath.Join(tmpDir, "deeply", "nested", "gitdir") - orig := gitCommonDir - gitCommonDir = func(_ context.Context) (string, error) { + orig := GitCommonDir + GitCommonDir = func(_ context.Context) (string, error) { return nestedDir, nil } - t.Cleanup(func() { gitCommonDir = orig }) + t.Cleanup(func() { GitCommonDir = orig }) ctx := context.Background() s := &State{ @@ -407,11 +407,11 @@ func TestWriteReadOnlyDir(t *testing.T) { t.Fatal(err) } - orig := gitCommonDir - gitCommonDir = func(_ context.Context) (string, error) { + orig := GitCommonDir + GitCommonDir = func(_ context.Context) (string, error) { return roDir, nil } - t.Cleanup(func() { gitCommonDir = orig }) + t.Cleanup(func() { GitCommonDir = orig }) // Make it read-only AFTER creating the dir. os.Chmod(roDir, 0o555) @@ -452,18 +452,18 @@ func TestLockDoubleLockFails(t *testing.T) { } func TestPathError(t *testing.T) { - // Override gitCommonDir to return an error. - orig := gitCommonDir - gitCommonDir = func(_ context.Context) (string, error) { + // Override GitCommonDir to return an error. + orig := GitCommonDir + GitCommonDir = func(_ context.Context) (string, error) { return "", fmt.Errorf("git not found") } - t.Cleanup(func() { gitCommonDir = orig }) + t.Cleanup(func() { GitCommonDir = orig }) ctx := context.Background() _, err := Path(ctx) if err == nil { - t.Fatal("Path() should fail when gitCommonDir fails") + t.Fatal("Path() should fail when GitCommonDir fails") } _, err = Read(ctx) @@ -478,7 +478,7 @@ func TestPathError(t *testing.T) { _, err = Lock(ctx) if err == nil { - t.Fatal("Lock() should fail when gitCommonDir fails") + t.Fatal("Lock() should fail when GitCommonDir fails") } } @@ -509,3 +509,41 @@ func TestReadOrInitExistingState(t *testing.T) { t.Error("ReadOrInit() re-initialized instead of reading existing state") } } + +func TestDriverField(t *testing.T) { + setupGitRepo(t) + ctx := context.Background() + + // Write state with driver field. + s := &State{ + Version: 1, + Trunk: "main", + Driver: "graphite", + Branches: map[string]Branch{}, + } + if err := Write(ctx, s); err != nil { + t.Fatalf("Write: %v", err) + } + + got, err := Read(ctx) + if err != nil { + t.Fatalf("Read: %v", err) + } + if got.Driver != "graphite" { + t.Errorf("Driver = %q, want %q", got.Driver, "graphite") + } + + // Empty driver should omit from JSON. + s.Driver = "" + if err := Write(ctx, s); err != nil { + t.Fatalf("Write: %v", err) + } + + got, err = Read(ctx) + if err != nil { + t.Fatalf("Read: %v", err) + } + if got.Driver != "" { + t.Errorf("Driver = %q, want empty", got.Driver) + } +}