diff --git a/cmd/wire/cache_cmd.go b/cmd/wire/cache_cmd.go index e5ceda4..5be1729 100644 --- a/cmd/wire/cache_cmd.go +++ b/cmd/wire/cache_cmd.go @@ -19,46 +19,151 @@ import ( "flag" "fmt" "log" + "os" + "path/filepath" + "sort" + "strings" - "github.com/goforj/wire/internal/wire" "github.com/google/subcommands" + + "github.com/goforj/wire/internal/cachepaths" +) + +const ( + loaderArtifactDirEnv = cachepaths.LoaderArtifactDirEnv + outputCacheDirEnv = cachepaths.OutputCacheDirEnv ) type cacheCmd struct { clear bool } -// Name returns the subcommand name. +type cacheTarget struct { + name string + path string +} + func (*cacheCmd) Name() string { return "cache" } -// Synopsis returns a short summary of the subcommand. func (*cacheCmd) Synopsis() string { return "inspect or clear the wire cache" } -// Usage returns the help text for the subcommand. func (*cacheCmd) Usage() string { - return `cache [-clear] + return `cache +cache clear +cache -clear - By default, prints the cache directory. With -clear, removes all cache files. + By default, prints the cache directory. With -clear or clear, removes all + Wire-managed cache files. ` } -// SetFlags registers flags for the subcommand. func (cmd *cacheCmd) SetFlags(f *flag.FlagSet) { - f.BoolVar(&cmd.clear, "clear", false, "remove all cached data") + f.BoolVar(&cmd.clear, "clear", false, "clear Wire caches") } -// Execute runs the subcommand. func (cmd *cacheCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus { - if cmd.clear { - if err := wire.ClearCache(); err != nil { - log.Printf("failed to clear cache: %v\n", err) - return subcommands.ExitFailure + _ = ctx + clearRequested := cmd.clear + switch extra := f.Args(); len(extra) { + case 0: + if !clearRequested { + root, err := wireCacheRoot(os.Environ()) + if err != nil { + log.Println(err) + return subcommands.ExitFailure + } + fmt.Fprintln(os.Stdout, root) + return subcommands.ExitSuccess + } + case 1: + if extra[0] == "clear" { + clearRequested = true + break } - log.Printf("cleared cache at %s\n", wire.CacheDir()) + log.Printf("unknown cache action %q", extra[0]) + log.Println(strings.TrimSpace(cmd.Usage())) + return subcommands.ExitFailure + default: + log.Println(strings.TrimSpace(cmd.Usage())) + return subcommands.ExitFailure + } + if !clearRequested { + log.Println(strings.TrimSpace(cmd.Usage())) + return subcommands.ExitFailure + } + cleared, err := clearWireCaches(os.Environ()) + if err != nil { + log.Printf("failed to clear cache: %v\n", err) + return subcommands.ExitFailure + } + root, err := wireCacheRoot(os.Environ()) + if err != nil { + log.Println(err) + return subcommands.ExitFailure + } + if len(cleared) == 0 { + log.Printf("cleared cache at %s\n", root) return subcommands.ExitSuccess } - fmt.Println(wire.CacheDir()) + log.Printf("cleared cache at %s\n", root) return subcommands.ExitSuccess } + +func wireCacheRoot(env []string) (string, error) { + root, err := cachepaths.Root(env) + if err != nil { + return "", fmt.Errorf("resolve user cache dir: %w", err) + } + return root, nil +} + +func clearWireCaches(env []string) ([]string, error) { + base, err := wireCacheRoot(env) + if err != nil { + return nil, err + } + targets := wireCacheTargets(env, filepath.Dir(base)) + cleared := make([]string, 0, len(targets)) + for _, target := range targets { + info, err := os.Stat(target.path) + if os.IsNotExist(err) { + continue + } + if err != nil { + return cleared, fmt.Errorf("stat %s cache: %w", target.name, err) + } + if !info.IsDir() { + if err := os.Remove(target.path); err != nil { + return cleared, fmt.Errorf("remove %s cache: %w", target.name, err) + } + } else if err := os.RemoveAll(target.path); err != nil { + return cleared, fmt.Errorf("remove %s cache: %w", target.name, err) + } + cleared = append(cleared, target.name) + } + return cleared, nil +} + +func wireCacheTargets(env []string, userCacheDir string) []cacheTarget { + baseWire := cachepaths.EnvValueDefault(env, cachepaths.BaseDirEnv, filepath.Join(userCacheDir, "wire")) + targets := []cacheTarget{ + {name: "loader-artifacts", path: cachepaths.EnvValueDefault(env, loaderArtifactDirEnv, filepath.Join(baseWire, "loader-artifacts"))}, + {name: "discovery-cache", path: cachepaths.EnvValueDefault(env, cachepaths.DiscoveryCacheDirEnv, filepath.Join(baseWire, "discovery-cache"))}, + {name: "output-cache", path: cachepaths.EnvValueDefault(env, outputCacheDirEnv, filepath.Join(baseWire, "output-cache"))}, + } + seen := make(map[string]bool, len(targets)) + deduped := make([]cacheTarget, 0, len(targets)) + for _, target := range targets { + cleaned := filepath.Clean(target.path) + if seen[cleaned] { + continue + } + seen[cleaned] = true + target.path = cleaned + deduped = append(deduped, target) + } + sort.Slice(deduped, func(i, j int) bool { return deduped[i].name < deduped[j].name }) + return deduped +} diff --git a/cmd/wire/cache_cmd_test.go b/cmd/wire/cache_cmd_test.go new file mode 100644 index 0000000..1d13acb --- /dev/null +++ b/cmd/wire/cache_cmd_test.go @@ -0,0 +1,112 @@ +package main + +import ( + "os" + "path/filepath" + "testing" + + "github.com/goforj/wire/internal/cachepaths" +) + +func TestWireCacheTargetsDefault(t *testing.T) { + base := filepath.Join(t.TempDir(), "cache") + got := wireCacheTargets(nil, base) + want := map[string]string{ + "discovery-cache": filepath.Join(base, "wire", "discovery-cache"), + "loader-artifacts": filepath.Join(base, "wire", "loader-artifacts"), + "output-cache": filepath.Join(base, "wire", "output-cache"), + } + if len(got) != len(want) { + t.Fatalf("targets len = %d, want %d", len(got), len(want)) + } + for _, target := range got { + if target.path != want[target.name] { + t.Fatalf("%s path = %q, want %q", target.name, target.path, want[target.name]) + } + } +} + +func TestWireCacheRoot(t *testing.T) { + base := filepath.Join(t.TempDir(), "cache") + old := cachepaths.UserCacheDir + cachepaths.UserCacheDir = func() (string, error) { return base, nil } + defer func() { cachepaths.UserCacheDir = old }() + + got, err := wireCacheRoot(nil) + if err != nil { + t.Fatalf("wireCacheRoot() error = %v", err) + } + want := filepath.Join(base, "wire") + if got != want { + t.Fatalf("wireCacheRoot() = %q, want %q", got, want) + } +} + +func TestWireCacheTargetsRespectOverrides(t *testing.T) { + base := filepath.Join(t.TempDir(), "cache") + env := []string{ + loaderArtifactDirEnv + "=" + filepath.Join(base, "loader"), + cachepaths.DiscoveryCacheDirEnv + "=" + filepath.Join(base, "discovery"), + outputCacheDirEnv + "=" + filepath.Join(base, "output"), + } + got := wireCacheTargets(env, base) + want := map[string]string{ + "discovery-cache": filepath.Join(base, "discovery"), + "loader-artifacts": filepath.Join(base, "loader"), + "output-cache": filepath.Join(base, "output"), + } + for _, target := range got { + if target.path != want[target.name] { + t.Fatalf("%s path = %q, want %q", target.name, target.path, want[target.name]) + } + } +} + +func TestWireCacheTargetsRespectBaseDirOverride(t *testing.T) { + base := filepath.Join(t.TempDir(), "cache") + root := filepath.Join(base, "wire-root") + env := []string{cachepaths.BaseDirEnv + "=" + root} + got := wireCacheTargets(env, base) + want := map[string]string{ + "discovery-cache": filepath.Join(root, "discovery-cache"), + "loader-artifacts": filepath.Join(root, "loader-artifacts"), + "output-cache": filepath.Join(root, "output-cache"), + } + for _, target := range got { + if target.path != want[target.name] { + t.Fatalf("%s path = %q, want %q", target.name, target.path, want[target.name]) + } + } +} + +func TestClearWireCachesRemovesTargets(t *testing.T) { + base := filepath.Join(t.TempDir(), "cache") + env := []string{ + loaderArtifactDirEnv + "=" + filepath.Join(base, "loader"), + outputCacheDirEnv + "=" + filepath.Join(base, "output"), + } + for _, target := range wireCacheTargets(env, base) { + if err := os.MkdirAll(target.path, 0o755); err != nil { + t.Fatalf("MkdirAll(%q): %v", target.path, err) + } + if err := os.WriteFile(filepath.Join(target.path, "marker"), []byte("x"), 0o644); err != nil { + t.Fatalf("WriteFile(%q): %v", target.path, err) + } + } + old := cachepaths.UserCacheDir + cachepaths.UserCacheDir = func() (string, error) { return base, nil } + defer func() { cachepaths.UserCacheDir = old }() + + cleared, err := clearWireCaches(env) + if err != nil { + t.Fatalf("clearWireCaches() error = %v", err) + } + if len(cleared) != 3 { + t.Fatalf("cleared len = %d, want 3", len(cleared)) + } + for _, target := range wireCacheTargets(env, base) { + if _, err := os.Stat(target.path); !os.IsNotExist(err) { + t.Fatalf("%s still exists after clear, stat err = %v", target.path, err) + } + } +} diff --git a/cmd/wire/gen_cmd.go b/cmd/wire/gen_cmd.go index 1532dd4..246caa5 100644 --- a/cmd/wire/gen_cmd.go +++ b/cmd/wire/gen_cmd.go @@ -19,9 +19,7 @@ import ( "flag" "log" "os" - "time" - "github.com/goforj/wire/internal/wire" "github.com/google/subcommands" ) @@ -66,7 +64,6 @@ func (cmd *genCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interfa return subcommands.ExitFailure } defer stop() - totalStart := time.Now() ctx = withTiming(ctx, cmd.profile.timings) wd, err := os.Getwd() @@ -83,42 +80,8 @@ func (cmd *genCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interfa opts.PrefixOutputFile = cmd.prefixFileName opts.Tags = cmd.tags - genStart := time.Now() - outs, errs := wire.Generate(ctx, wd, os.Environ(), packages(f), opts) - logTiming(cmd.profile.timings, "wire.Generate", genStart) - if len(errs) > 0 { - logErrors(errs) - log.Println("generate failed") + if !runGenerateCommand(ctx, wd, os.Environ(), packages(f), opts, cmd.profile.timings) { return subcommands.ExitFailure } - if len(outs) == 0 { - logTiming(cmd.profile.timings, "total", totalStart) - return subcommands.ExitSuccess - } - success := true - writeStart := time.Now() - for _, out := range outs { - if len(out.Errs) > 0 { - logErrors(out.Errs) - log.Printf("%s: generate failed\n", out.PkgPath) - success = false - } - if len(out.Content) == 0 { - // No Wire output. Maybe errors, maybe no Wire directives. - continue - } - if err := out.Commit(); err == nil { - log.Printf("%s: wrote %s (%s)\n", out.PkgPath, out.OutputPath, formatDuration(time.Since(totalStart))) - } else { - log.Printf("%s: failed to write %s: %v\n", out.PkgPath, out.OutputPath, err) - success = false - } - } - if !success { - log.Println("at least one generate failure") - return subcommands.ExitFailure - } - logTiming(cmd.profile.timings, "writes", writeStart) - logTiming(cmd.profile.timings, "total", totalStart) return subcommands.ExitSuccess } diff --git a/cmd/wire/generate_runner.go b/cmd/wire/generate_runner.go new file mode 100644 index 0000000..4dc7b20 --- /dev/null +++ b/cmd/wire/generate_runner.go @@ -0,0 +1,59 @@ +package main + +import ( + "context" + "fmt" + "log" + "time" + + "github.com/goforj/wire/internal/wire" +) + +func runGenerateCommand(ctx context.Context, wd string, env []string, patterns []string, opts *wire.GenerateOptions, timings bool) bool { + totalStart := time.Now() + genStart := time.Now() + outs, errs := wire.Generate(ctx, wd, env, patterns, opts) + logTiming(timings, "wire.Generate", genStart) + if len(errs) > 0 { + logErrors(errs) + log.Println("generate failed") + return false + } + if len(outs) == 0 { + logTiming(timings, "total", totalStart) + return true + } + success := true + writeStart := time.Now() + for _, out := range outs { + if len(out.Errs) > 0 { + logErrors(out.Errs) + log.Printf("%s: generate failed\n", out.PkgPath) + success = false + } + if len(out.Content) == 0 { + continue + } + if wrote, err := out.CommitWithStatus(); err == nil { + if wrote { + logSuccessf("%s: wrote %s (%s)", out.PkgPath, out.OutputPath, formatDuration(time.Since(totalStart))) + } else { + logSuccessf("%s: unchanged %s (%s)", out.PkgPath, out.OutputPath, formatDuration(time.Since(totalStart))) + } + } else { + log.Printf("%s: failed to write %s: %v\n", out.PkgPath, out.OutputPath, err) + success = false + } + } + if !success { + log.Println("at least one generate failure") + return false + } + logTiming(timings, "writes", writeStart) + logTiming(timings, "total", totalStart) + return true +} + +func formatDuration(d time.Duration) string { + return fmt.Sprintf("%.2fms", float64(d)/float64(time.Millisecond)) +} diff --git a/cmd/wire/logging.go b/cmd/wire/logging.go new file mode 100644 index 0000000..7479f13 --- /dev/null +++ b/cmd/wire/logging.go @@ -0,0 +1,133 @@ +package main + +import ( + "fmt" + "io" + "os" + "strings" +) + +const ( + ansiRed = "\033[1;31m" + ansiGreen = "\033[1;32m" + ansiReset = "\033[0m" + successSig = "✓ " + errorSig = "x " + maxLoggedErrorLines = 5 +) + +func logErrors(errs []error) { + for _, err := range errs { + msg := truncateLoggedError(formatLoggedError(err)) + if strings.Contains(msg, "\n") { + logMultilineError("\n " + strings.ReplaceAll(msg, "\n", "\n ")) + continue + } + logMultilineError(msg) + } +} + +func formatLoggedError(err error) string { + if err == nil { + return "" + } + msg := err.Error() + if strings.HasPrefix(msg, "inject ") { + return "solve failed\n" + msg + } + if idx := strings.Index(msg, ": inject "); idx >= 0 { + return "solve failed\n" + msg + } + return msg +} + +func truncateLoggedError(msg string) string { + if msg == "" { + return "" + } + lines := strings.Split(msg, "\n") + if len(lines) <= maxLoggedErrorLines { + return msg + } + omitted := len(lines) - maxLoggedErrorLines + lines = append(lines[:maxLoggedErrorLines], fmt.Sprintf("... (%d additional lines omitted)", omitted)) + return strings.Join(lines, "\n") +} + +func logMultilineError(msg string) { + writeErrorLog(os.Stderr, msg) +} + +func logSuccessf(format string, args ...interface{}) { + writeStatusLog(os.Stderr, fmt.Sprintf(format, args...)) +} + +func shouldColorStderr() bool { + return shouldColorOutput(stderrIsTTY(), os.Getenv("TERM")) +} + +func shouldColorOutput(isTTY bool, term string) bool { + if os.Getenv("NO_COLOR") != "" || os.Getenv("CLICOLOR") == "0" { + return false + } + if forceColorEnabled() { + return true + } + if term == "" || term == "dumb" { + return false + } + return isTTY +} + +func forceColorEnabled() bool { + return os.Getenv("FORCE_COLOR") != "" || os.Getenv("CLICOLOR_FORCE") != "" +} + +func stderrIsTTY() bool { + info, err := os.Stderr.Stat() + if err != nil { + return false + } + return (info.Mode() & os.ModeCharDevice) != 0 +} + +func writeErrorLog(w io.Writer, msg string) { + line := errorSig + "wire: " + msg + if !strings.HasSuffix(line, "\n") { + line += "\n" + } + if shouldColorStderr() { + _, _ = io.WriteString(w, colorizeLines(line)) + return + } + _, _ = io.WriteString(w, line) +} + +func writeStatusLog(w io.Writer, msg string) { + line := successSig + "wire: " + msg + if !strings.HasSuffix(line, "\n") { + line += "\n" + } + if shouldColorStderr() { + _, _ = io.WriteString(w, ansiGreen+line+ansiReset) + return + } + _, _ = io.WriteString(w, line) +} + +func colorizeLines(s string) string { + if s == "" { + return "" + } + parts := strings.SplitAfter(s, "\n") + var b strings.Builder + for _, part := range parts { + if part == "" { + continue + } + b.WriteString(ansiRed) + b.WriteString(part) + b.WriteString(ansiReset) + } + return b.String() +} diff --git a/cmd/wire/main.go b/cmd/wire/main.go index 2f90783..ada16d2 100644 --- a/cmd/wire/main.go +++ b/cmd/wire/main.go @@ -39,8 +39,8 @@ func main() { subcommands.Register(subcommands.CommandsCommand(), "") subcommands.Register(subcommands.FlagsCommand(), "") subcommands.Register(subcommands.HelpCommand(), "") - subcommands.Register(&checkCmd{}, "") subcommands.Register(&cacheCmd{}, "") + subcommands.Register(&checkCmd{}, "") subcommands.Register(&diffCmd{}, "") subcommands.Register(&genCmd{}, "") subcommands.Register(&watchCmd{}, "") @@ -60,8 +60,8 @@ func main() { "commands": true, // builtin "help": true, // builtin "flags": true, // builtin - "check": true, "cache": true, + "check": true, "diff": true, "gen": true, "serve": true, @@ -178,6 +178,10 @@ func withTiming(ctx context.Context, enabled bool) context.Context { return ctx } return wire.WithTiming(ctx, func(label string, dur time.Duration) { + if dur == 0 && strings.Contains(label, "=") { + log.Printf("timing: %s", label) + return + } log.Printf("timing: %s=%s", label, dur) }) } @@ -198,8 +202,3 @@ func newGenerateOptions(headerFile string) (*wire.GenerateOptions, error) { } // logErrors logs each error with consistent formatting. -func logErrors(errs []error) { - for _, err := range errs { - log.Println(strings.Replace(err.Error(), "\n", "\n\t", -1)) - } -} diff --git a/cmd/wire/main_test.go b/cmd/wire/main_test.go new file mode 100644 index 0000000..7fe4720 --- /dev/null +++ b/cmd/wire/main_test.go @@ -0,0 +1,139 @@ +package main + +import ( + "bytes" + "fmt" + "strings" + "testing" +) + +func TestFormatLoggedErrorAddsSolveHeader(t *testing.T) { + err := testError("inject InitializeApplication: no provider found for *example.Foo") + got := formatLoggedError(err) + want := "solve failed\ninject InitializeApplication: no provider found for *example.Foo" + if got != want { + t.Fatalf("formatLoggedError() = %q, want %q", got, want) + } +} + +func TestFormatLoggedErrorAddsSolveHeaderWithPositionPrefix(t *testing.T) { + err := testError("/tmp/wire.go:12:1: inject InitializeApplication: no provider found for *example.Foo") + got := formatLoggedError(err) + want := "solve failed\n/tmp/wire.go:12:1: inject InitializeApplication: no provider found for *example.Foo" + if got != want { + t.Fatalf("formatLoggedError() = %q, want %q", got, want) + } +} + +func TestFormatLoggedErrorLeavesNonSolveErrorsUnchanged(t *testing.T) { + err := testError("type-check failed for example.com/app/app") + got := formatLoggedError(err) + if got != err.Error() { + t.Fatalf("formatLoggedError() = %q, want %q", got, err.Error()) + } +} + +func TestTruncateLoggedErrorSummarizesLargeBlocks(t *testing.T) { + lines := make([]string, 0, maxLoggedErrorLines+3) + for i := 0; i < maxLoggedErrorLines+3; i++ { + lines = append(lines, fmt.Sprintf("line %d", i+1)) + } + got := truncateLoggedError(strings.Join(lines, "\n")) + wantLines := append(append([]string(nil), lines[:maxLoggedErrorLines]...), "... (3 additional lines omitted)") + want := strings.Join(wantLines, "\n") + if got != want { + t.Fatalf("truncateLoggedError() = %q, want %q", got, want) + } +} + +func TestShouldColorOutputForceColorOverridesTTYRequirement(t *testing.T) { + t.Setenv("FORCE_COLOR", "1") + t.Setenv("NO_COLOR", "") + t.Setenv("CLICOLOR", "") + t.Setenv("CLICOLOR_FORCE", "") + + if !shouldColorOutput(false, "xterm-256color") { + t.Fatal("shouldColorOutput() = false, want true when FORCE_COLOR is set") + } +} + +func TestShouldColorOutputNoColorWins(t *testing.T) { + t.Setenv("FORCE_COLOR", "1") + t.Setenv("NO_COLOR", "1") + t.Setenv("CLICOLOR", "") + t.Setenv("CLICOLOR_FORCE", "") + + if shouldColorOutput(true, "xterm-256color") { + t.Fatal("shouldColorOutput() = true, want false when NO_COLOR is set") + } +} + +func TestShouldColorOutputTTYFallback(t *testing.T) { + t.Setenv("FORCE_COLOR", "") + t.Setenv("NO_COLOR", "") + t.Setenv("CLICOLOR", "") + t.Setenv("CLICOLOR_FORCE", "") + + if !shouldColorOutput(true, "xterm-256color") { + t.Fatal("shouldColorOutput() = false, want true for tty stderr") + } + if shouldColorOutput(false, "xterm-256color") { + t.Fatal("shouldColorOutput() = true, want false for non-tty stderr without force color") + } +} + +func TestWriteErrorLogFormatsWirePrefix(t *testing.T) { + var buf bytes.Buffer + writeErrorLog(&buf, "type-check failed for example.com/app/app") + got := buf.String() + want := errorSig + "wire: type-check failed for example.com/app/app\n" + if got != want { + t.Fatalf("writeErrorLog() = %q, want %q", got, want) + } +} + +func TestWriteErrorLogColorsWholeBlockWhenForced(t *testing.T) { + t.Setenv("FORCE_COLOR", "1") + t.Setenv("NO_COLOR", "") + t.Setenv("CLICOLOR", "") + t.Setenv("CLICOLOR_FORCE", "") + + var buf bytes.Buffer + writeErrorLog(&buf, "type-check failed for example.com/app/app") + got := buf.String() + want := ansiRed + errorSig + "wire: type-check failed for example.com/app/app\n" + ansiReset + if got != want { + t.Fatalf("writeErrorLog() = %q, want %q", got, want) + } +} + +func TestWriteErrorLogColorsEachMultilineLineWhenForced(t *testing.T) { + t.Setenv("FORCE_COLOR", "1") + t.Setenv("NO_COLOR", "") + t.Setenv("CLICOLOR", "") + t.Setenv("CLICOLOR_FORCE", "") + + var buf bytes.Buffer + writeErrorLog(&buf, "\n first line\n second line") + got := buf.String() + want := ansiRed + errorSig + "wire: \n" + ansiReset + + ansiRed + " first line\n" + ansiReset + + ansiRed + " second line\n" + ansiReset + if got != want { + t.Fatalf("writeErrorLog() = %q, want %q", got, want) + } +} + +func TestWriteStatusLogFormatsSuccessPrefix(t *testing.T) { + var buf bytes.Buffer + writeStatusLog(&buf, "example.com/app: wrote /tmp/wire_gen.go (12ms)") + got := buf.String() + want := successSig + "wire: example.com/app: wrote /tmp/wire_gen.go (12ms)\n" + if got != want { + t.Fatalf("writeStatusLog() = %q, want %q", got, want) + } +} + +type testError string + +func (e testError) Error() string { return string(e) } diff --git a/cmd/wire/watch_cmd.go b/cmd/wire/watch_cmd.go index 779625f..45b6bc4 100644 --- a/cmd/wire/watch_cmd.go +++ b/cmd/wire/watch_cmd.go @@ -27,7 +27,6 @@ import ( "time" "github.com/fsnotify/fsnotify" - "github.com/goforj/wire/internal/wire" "github.com/google/subcommands" ) @@ -102,43 +101,7 @@ func (cmd *watchCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...inter env := os.Environ() runGenerate := func() { - totalStart := time.Now() - genStart := time.Now() - outs, errs := wire.Generate(ctx, wd, env, packages(f), opts) - logTiming(cmd.profile.timings, "wire.Generate", genStart) - if len(errs) > 0 { - logErrors(errs) - log.Println("generate failed") - return - } - if len(outs) == 0 { - logTiming(cmd.profile.timings, "total", totalStart) - return - } - success := true - writeStart := time.Now() - for _, out := range outs { - if len(out.Errs) > 0 { - logErrors(out.Errs) - log.Printf("%s: generate failed\n", out.PkgPath) - success = false - } - if len(out.Content) == 0 { - continue - } - if err := out.Commit(); err == nil { - log.Printf("%s: wrote %s (%s)\n", out.PkgPath, out.OutputPath, formatDuration(time.Since(totalStart))) - } else { - log.Printf("%s: failed to write %s: %v\n", out.PkgPath, out.OutputPath, err) - success = false - } - } - if !success { - log.Println("at least one generate failure") - return - } - logTiming(cmd.profile.timings, "writes", writeStart) - logTiming(cmd.profile.timings, "total", totalStart) + _ = runGenerateCommand(ctx, wd, env, packages(f), opts, cmd.profile.timings) } root, err := moduleRoot(wd, env) @@ -328,11 +291,6 @@ func moduleRoot(wd string, env []string) (string, error) { return filepath.Dir(path), nil } -// formatDuration renders a short millisecond duration for log output. -func formatDuration(d time.Duration) string { - return fmt.Sprintf("%.2fms", float64(d)/float64(time.Millisecond)) -} - // watchWithFSNotify runs the watcher using native filesystem notifications. func watchWithFSNotify(root string, onChange func()) error { watcher, err := fsnotify.NewWatcher() diff --git a/go.mod b/go.mod index e800555..5db8855 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/goforj/wire go 1.19 require ( + github.com/fsnotify/fsnotify v1.7.0 github.com/google/go-cmp v0.6.0 github.com/google/subcommands v1.2.0 github.com/pmezard/go-difflib v1.0.0 @@ -10,7 +11,6 @@ require ( ) require ( - github.com/fsnotify/fsnotify v1.7.0 // indirect golang.org/x/mod v0.20.0 // indirect golang.org/x/sync v0.8.0 // indirect golang.org/x/sys v0.23.0 // indirect diff --git a/internal/README.md b/internal/README.md new file mode 100644 index 0000000..1134d90 --- /dev/null +++ b/internal/README.md @@ -0,0 +1,213 @@ +# Internal Package Guide + +This directory holds the implementation behind the `wire` CLI and library. +If you are new to the codebase, the two packages that matter most are: + +- `internal/wire`: parses injector/provider-set source and generates `wire_gen.go` +- `internal/loader`: loads package graphs, supports the custom loader, and manages loader-side caches + +Everything else is either small support code or repo maintenance helpers. + +## High-Level Flow + +For `wire gen`, the rough flow is: + +1. `cmd/wire` parses flags and calls `wire.Generate(...)` +2. `internal/wire` tries the output cache fast path +3. `internal/loader` loads the root graph and typed package graph +4. `internal/wire/parse.go` walks source/type information and builds Wire's internal model +5. `internal/wire/wire.go` generates formatted Go output +6. `cmd/wire` writes `wire_gen.go` + +If you are debugging behavior, start in: + +- [`wire/wire.go`](./wire/wire.go) +- [`wire/parse.go`](./wire/parse.go) +- [`loader/loader.go`](./loader/loader.go) +- [`loader/custom.go`](./loader/custom.go) + +## Package Overview + +### `internal/wire` + +This is the core implementation of Wire's analysis and code generation. + +Important responsibilities: + +- loading packages through the loader abstraction +- parsing `wire.NewSet`, `wire.Bind`, `wire.Struct`, `wire.FieldsOf`, and injector functions +- building the internal provider graph model +- validating dependency graphs +- rendering generated Go source +- managing the output cache + +Important files: + +- [`wire/wire.go`](./wire/wire.go) + Main generation entry point. Defines `Generate`, `GenerateResult`, and the final generation loop. +- [`wire/parse.go`](./wire/parse.go) + The biggest conceptual file in the repo. Defines the internal model (`ProviderSet`, `Provider`, `Value`, `Field`, etc.) and parses source/type information into that model. +- [`wire/analyze.go`](./wire/analyze.go) + Performs graph validation and dependency analysis once provider sets are parsed. +- [`wire/output_cache.go`](./wire/output_cache.go) + Output cache read/write and key construction. +- [`wire/load_debug.go`](./wire/load_debug.go) + Debug/timing summaries for loaded package graphs. +- [`wire/timing.go`](./wire/timing.go) + Timing/debug plumbing for the `wire` package. +- [`wire/loader_timing_bridge.go`](./wire/loader_timing_bridge.go) + Connects loader timing output into wire timing output. +- [`wire/errors.go`](./wire/errors.go) + Error collection and formatting helpers. +- [`wire/copyast.go`](./wire/copyast.go) + AST copying helpers used during generation. +- [`wire/loader_validation.go`](./wire/loader_validation.go) + Loader-mode-related validation helpers. + +Useful mental model: + +- `parse.go` turns package syntax/types into a Wire-specific graph model +- `analyze.go` validates that graph +- `wire.go` emits Go code from that graph + +### `internal/loader` + +This package abstracts package loading and is the main performance-sensitive layer. + +Important responsibilities: + +- choosing between the custom loader and `go/packages` fallback +- root graph loading vs typed package loading +- discovery via `go list` +- discovery cache +- loader artifact cache +- touched-package validation +- loader timings and fallback reasons + +Important files: + +- [`loader/loader.go`](./loader/loader.go) + Public loader API and shared request/result structs. This is the best entry point for understanding loader responsibilities. +- [`loader/custom.go`](./loader/custom.go) + The custom loader implementation. This is the most performance-critical file in the repo. +- [`loader/fallback.go`](./loader/fallback.go) + `go/packages` fallback implementation and fallback reason handling. +- [`loader/discovery.go`](./loader/discovery.go) + Runs `go list`, decodes package metadata, and populates the discovery cache. +- [`loader/discovery_cache.go`](./loader/discovery_cache.go) + Discovery cache storage and invalidation. +- [`loader/artifact_cache.go`](./loader/artifact_cache.go) + Loader artifact cache keying and read/write helpers. +- [`loader/mode.go`](./loader/mode.go) + Loader mode selection helpers. +- [`loader/timing.go`](./loader/timing.go) + Timing/debug plumbing for the loader package. + +Useful mental model: + +- `discovery.go` answers "what packages/files/imports exist?" +- `custom.go` answers "how do we build the package graph and type info efficiently?" +- `artifact_cache.go` is the typed-package reuse layer +- `fallback.go` is the correctness backstop + +### `internal/cachepaths` + +Small shared helper package for Wire-managed cache directories. + +Responsibilities: + +- resolve the shared cache root +- resolve specific cache directories +- support shared and specific env var overrides + +Important file: + +- [`cachepaths/cachepaths.go`](./cachepaths/cachepaths.go) + +This package exists to keep cache path policy centralized instead of duplicated across loader, output cache, and the `wire cache` command. + +## Internal Data Structures + +These are the main structs worth learning first. + +### In `internal/wire` + +- `ProviderSet` + The central Wire model. Represents a set built from `wire.NewSet(...)` or `wire.Build(...)`. +- `Provider` + A constructor source, either a function or a named struct provider. +- `IfaceBinding` + A `wire.Bind(...)` relationship. +- `Value` + A `wire.Value(...)` source. +- `Field` + A `wire.FieldsOf(...)` source. +- `InjectorArgs` / `InjectorArg` + Injector function argument modeling. + +If you understand `ProviderSet`, most of `parse.go` becomes easier to follow. + +### In `internal/loader` + +- `RootLoadRequest` / `RootLoadResult` + Used for lightweight root graph loading, primarily for cache lookup and root discovery. +- `PackageLoadRequest` / `PackageLoadResult` + Used for full typed package loading. +- `LazyLoadRequest` / `LazyLoadResult` + Used for package-targeted typed loading. +- `TouchedValidationRequest` / `TouchedValidationResult` + Used to validate whether a touched local package can stay on the custom path. +- `DiscoverySnapshot` + Captures `go list` metadata so later phases can reuse discovery without rerunning it. +- `packageMeta` + Internal metadata from `go list`. This is the custom loader's raw package description. +- `customTypedGraphLoader` + Main stateful custom loader for building typed package graphs. + +## Empty or Transitional Directories + +- `internal/cachestore` +- `internal/semanticcache` + +These directories currently exist but do not contain active implementation files. +Treat them as inactive unless they are populated later. + +## Repo Maintenance Files + +There are also a few internal scripts/data files that support repo maintenance: + +- [`alldeps`](./alldeps) + Dependency allowlist/check input. +- [`runtests.sh`](./runtests.sh) + Repo test runner used in CI/dev workflows. +- [`listdeps.sh`](./listdeps.sh) + Dependency listing helper. +- [`check_api_change.sh`](./check_api_change.sh) + API change check helper. + +## Suggested Reading Order + +If you are trying to understand the codebase quickly: + +1. [`wire/wire.go`](./wire/wire.go) +2. [`loader/loader.go`](./loader/loader.go) +3. [`loader/custom.go`](./loader/custom.go) +4. [`wire/parse.go`](./wire/parse.go) +5. [`wire/analyze.go`](./wire/analyze.go) +6. [`wire/output_cache.go`](./wire/output_cache.go) +7. [`loader/discovery.go`](./loader/discovery.go) +8. [`loader/artifact_cache.go`](./loader/artifact_cache.go) + +## Practical Notes For New Readers + +- If a behavior issue involves parsing provider sets, start in `internal/wire/parse.go`. +- If a behavior issue involves performance, cache hits, or package loading, start in `internal/loader/custom.go`. +- If a behavior issue involves "why did we skip generation?" or "why did generation return instantly?", check `internal/wire/output_cache.go`. +- If a behavior issue only appears in one environment or workspace, inspect cache path resolution in `internal/cachepaths/cachepaths.go` and discovery behavior in `internal/loader/discovery.go`. + +This repo has two real centers of complexity: + +- the Wire semantic model in `internal/wire` +- the package loading/caching machinery in `internal/loader` + +Most other internal code exists to support those two areas. diff --git a/internal/cachepaths/cachepaths.go b/internal/cachepaths/cachepaths.go new file mode 100644 index 0000000..f3e3adb --- /dev/null +++ b/internal/cachepaths/cachepaths.go @@ -0,0 +1,55 @@ +package cachepaths + +import ( + "os" + "path/filepath" + "strings" +) + +const ( + BaseDirEnv = "WIRE_CACHE_DIR" + LoaderArtifactDirEnv = "WIRE_LOADER_ARTIFACT_DIR" + DiscoveryCacheDirEnv = "WIRE_DISCOVERY_CACHE_DIR" + OutputCacheDirEnv = "WIRE_OUTPUT_CACHE_DIR" +) + +var UserCacheDir = os.UserCacheDir + +func Root(env []string) (string, error) { + if dir := envValue(env, BaseDirEnv); dir != "" { + return filepath.Clean(dir), nil + } + base, err := UserCacheDir() + if err != nil { + return "", err + } + return filepath.Join(base, "wire"), nil +} + +func Dir(env []string, specificEnv, name string) (string, error) { + if dir := envValue(env, specificEnv); dir != "" { + return filepath.Clean(dir), nil + } + root, err := Root(env) + if err != nil { + return "", err + } + return filepath.Join(root, name), nil +} + +func EnvValueDefault(env []string, key, fallback string) string { + if value := envValue(env, key); value != "" { + return value + } + return fallback +} + +func envValue(env []string, key string) string { + for i := len(env) - 1; i >= 0; i-- { + name, value, ok := strings.Cut(env[i], "=") + if ok && name == key && value != "" { + return value + } + } + return "" +} diff --git a/internal/loader/artifact_cache.go b/internal/loader/artifact_cache.go new file mode 100644 index 0000000..a6dfdb1 --- /dev/null +++ b/internal/loader/artifact_cache.go @@ -0,0 +1,162 @@ +// Copyright 2026 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package loader + +import ( + "bytes" + "crypto/sha256" + "encoding/hex" + "go/token" + "go/types" + "io" + "os" + "path/filepath" + "runtime" + + "golang.org/x/tools/go/gcexportdata" + + "github.com/goforj/wire/internal/cachepaths" +) + +const ( + loaderArtifactEnv = "WIRE_LOADER_ARTIFACTS" + loaderArtifactDirEnv = cachepaths.LoaderArtifactDirEnv +) + +func loaderArtifactEnabled(env []string) bool { + return envValue(env, loaderArtifactEnv) != "0" +} + +func loaderArtifactDir(env []string) (string, error) { + return cachepaths.Dir(env, loaderArtifactDirEnv, "loader-artifacts") +} + +func loaderArtifactPath(env []string, meta *packageMeta, isLocal bool) (string, error) { + dir, err := loaderArtifactDir(env) + if err != nil { + return "", err + } + key, err := loaderArtifactKey(meta, isLocal) + if err != nil { + return "", err + } + return filepath.Join(dir, key+".bin"), nil +} + +func loaderArtifactKey(meta *packageMeta, isLocal bool) (string, error) { + sum := sha256.New() + sum.Write([]byte("wire-loader-artifact-v4\n")) + sum.Write([]byte(runtime.Version())) + sum.Write([]byte{'\n'}) + sum.Write([]byte(meta.ImportPath)) + sum.Write([]byte{'\n'}) + sum.Write([]byte(meta.Name)) + sum.Write([]byte{'\n'}) + if !isLocal { + sum.Write([]byte(meta.Export)) + sum.Write([]byte{'\n'}) + if meta.Export != "" { + h, err := hashFileContent(meta.Export) + if err != nil { + return "", err + } + sum.Write([]byte(h)) + sum.Write([]byte{'\n'}) + } else { + if err := hashMetaFiles(sum, metaFiles(meta)); err != nil { + return "", err + } + } + if meta.Error != nil { + sum.Write([]byte(meta.Error.Err)) + sum.Write([]byte{'\n'}) + } + return hex.EncodeToString(sum.Sum(nil)), nil + } + if err := hashMetaFiles(sum, metaFiles(meta)); err != nil { + return "", err + } + return hex.EncodeToString(sum.Sum(nil)), nil +} + +// hashFileContent returns the hex-encoded SHA-256 of the file content. +func hashFileContent(path string) (string, error) { + data, err := os.ReadFile(path) + if err != nil { + return "", err + } + h := sha256.Sum256(data) + return hex.EncodeToString(h[:]), nil +} + +// hashMetaFiles writes content-based hashes for each file into sum. +func hashMetaFiles(sum io.Writer, names []string) error { + for _, name := range names { + sum.Write([]byte(name)) + sum.Write([]byte{'\n'}) + h, err := hashFileContent(name) + if err != nil { + return err + } + sum.Write([]byte(h)) + sum.Write([]byte{'\n'}) + } + return nil +} + +func readLoaderArtifact(path string, fset *token.FileSet, imports map[string]*types.Package, pkgPath string) (*types.Package, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, err + } + return readLoaderArtifactData(data, fset, imports, pkgPath) +} + +func readLoaderArtifactData(data []byte, fset *token.FileSet, imports map[string]*types.Package, pkgPath string) (*types.Package, error) { + return gcexportdata.Read(bytes.NewReader(data), fset, imports, pkgPath) +} + +func writeLoaderArtifact(path string, fset *token.FileSet, pkg *types.Package) error { + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + return err + } + var out bytes.Buffer + if err := gcexportdata.Write(&out, fset, pkg); err != nil { + return err + } + return os.WriteFile(path, out.Bytes(), 0o644) +} + +func artifactUpToDate(env []string, artifactPath string, meta *packageMeta, isLocal bool) bool { + _, err := os.Stat(artifactPath) + return err == nil +} + +func isProviderSetTypeForLoader(t types.Type) bool { + named, ok := t.(*types.Named) + if !ok { + return false + } + obj := named.Obj() + if obj == nil || obj.Pkg() == nil { + return false + } + switch obj.Pkg().Path() { + case "github.com/goforj/wire", "github.com/google/wire": + return obj.Name() == "ProviderSet" + default: + return false + } +} diff --git a/internal/loader/custom.go b/internal/loader/custom.go new file mode 100644 index 0000000..b4532af --- /dev/null +++ b/internal/loader/custom.go @@ -0,0 +1,1339 @@ +// Copyright 2026 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package loader + +import ( + "context" + "fmt" + "go/ast" + importerpkg "go/importer" + "go/parser" + "go/scanner" + "go/token" + "go/types" + "os" + "path/filepath" + "runtime" + "runtime/pprof" + "sort" + "strings" + "sync" + "time" + + "golang.org/x/tools/go/gcexportdata" + "golang.org/x/tools/go/packages" +) + +type unsupportedError struct { + reason string +} + +func (e unsupportedError) Error() string { return e.reason } + +type packageMeta struct { + ImportPath string + Name string + Dir string + DepOnly bool + Export string + GoFiles []string + CompiledGoFiles []string + Imports []string + ImportMap map[string]string + Module *goListModule + Error *goListError +} + +type goListModule struct { + Path string + Version string + Main bool + Dir string + GoMod string + Replace *goListModule +} + +type goListError struct { + Err string +} + +type customValidator struct { + fset *token.FileSet + meta map[string]*packageMeta + touched map[string]struct{} + packages map[string]*types.Package + importer types.Importer + loading map[string]bool +} + +type customTypedGraphLoader struct { + workspace string + ctx context.Context + env []string + fset *token.FileSet + meta map[string]*packageMeta + targets map[string]struct{} + parseFile ParseFileFunc + packages map[string]*packages.Package + typesPkgs map[string]*types.Package + importer types.Importer + loading map[string]bool + isLocalCache map[string]bool + artifactPrefetch map[string]artifactPrefetchEntry + stats typedLoadStats +} + +type artifactPrefetchEntry struct { + path string + data []byte + err error + ok bool +} + +type typedLoadStats struct { + read time.Duration + parse time.Duration + typecheck time.Duration + localRead time.Duration + externalRead time.Duration + localParse time.Duration + externalParse time.Duration + localTypecheck time.Duration + externalTypecheck time.Duration + filesRead int + packages int + localPackages int + externalPackages int + localFilesRead int + externalFilesRead int + artifactRead time.Duration + artifactPath time.Duration + artifactDecode time.Duration + artifactImportLink time.Duration + artifactWrite time.Duration + artifactPrefetch time.Duration + rootLoad time.Duration + discovery time.Duration + artifactHits int + artifactMisses int + artifactWrites int +} + +type artifactPolicy struct { + read bool + write bool +} + +func validateTouchedPackagesCustom(ctx context.Context, req TouchedValidationRequest) (*TouchedValidationResult, error) { + if len(req.Touched) == 0 { + return &TouchedValidationResult{Backend: ModeCustom}, nil + } + meta, err := discoverTouchedMetadata(ctx, req) + if err != nil { + return nil, err + } + validator := &customValidator{ + fset: token.NewFileSet(), + meta: meta, + touched: make(map[string]struct{}, len(req.Touched)), + packages: make(map[string]*types.Package, len(meta)), + importer: importerpkg.ForCompiler(token.NewFileSet(), "gc", nil), + loading: make(map[string]bool, len(meta)), + } + for _, path := range req.Touched { + if !metadataMatchesFingerprint(path, meta, req.Local) { + return nil, unsupportedError{reason: "metadata fingerprint mismatch"} + } + validator.touched[path] = struct{}{} + } + out := make([]*packages.Package, 0, len(req.Touched)) + for _, path := range req.Touched { + pkg, err := validator.validatePackage(path) + if err != nil { + return nil, err + } + out = append(out, pkg) + } + return &TouchedValidationResult{ + Packages: out, + Backend: ModeCustom, + }, nil +} + +func loadRootGraphCustom(ctx context.Context, req RootLoadRequest) (*RootLoadResult, error) { + meta, discoveryDuration, err := loadCustomMeta(ctx, goListRequest{ + WD: req.WD, + Env: req.Env, + Tags: req.Tags, + Patterns: req.Patterns, + NeedDeps: req.NeedDeps, + SkipCompiled: true, + }) + if err != nil { + return nil, err + } + logDuration(ctx, "loader.custom.root.discovery", discoveryDuration) + pkgs := packageStubGraphFromMeta(nil, meta) + rootPaths := nonDepRootImportPaths(meta) + roots := make([]*packages.Package, 0, len(rootPaths)) + for _, path := range rootPaths { + if pkg := pkgs[path]; pkg != nil { + roots = append(roots, pkg) + } + } + if len(roots) == 0 { + return nil, unsupportedError{reason: "no root packages from metadata"} + } + sort.Slice(roots, func(i, j int) bool { return roots[i].PkgPath < roots[j].PkgPath }) + return &RootLoadResult{ + Packages: roots, + Backend: ModeCustom, + Discovery: discoverySnapshotForMeta(meta, req.NeedDeps), + }, nil +} + +func loadTypedPackageGraphCustom(ctx context.Context, req LazyLoadRequest) (*LazyLoadResult, error) { + stopProfile, profileErr := startLoaderCPUProfile(req.Env) + if profileErr != nil { + return nil, profileErr + } + if stopProfile != nil { + defer stopProfile() + } + var ( + meta map[string]*packageMeta + err error + ) + meta, discoveryDuration, err := loadCustomLazyMeta(ctx, req) + if err != nil { + return nil, err + } + roots, err := loadCustomPackagesFromMeta(ctx, req.WD, req.Env, req.Fset, meta, map[string]struct{}{req.Package: {}}, []string{req.Package}, req.ParseFile, discoveryDuration, "lazy") + if err != nil { + return nil, err + } + return &LazyLoadResult{ + Packages: roots, + Backend: ModeCustom, + }, nil +} + +func loadPackagesCustom(ctx context.Context, req PackageLoadRequest) (*PackageLoadResult, error) { + var ( + meta map[string]*packageMeta + discoveryDuration time.Duration + err error + ) + if req.Discovery != nil && len(req.Discovery.meta) > 0 { + meta = req.Discovery.meta + } else { + meta, discoveryDuration, err = loadCustomMeta(ctx, goListRequest{ + WD: req.WD, + Env: req.Env, + Tags: req.Tags, + Patterns: req.Patterns, + NeedDeps: true, + }) + } + if err != nil { + return nil, err + } + rootPaths := nonDepRootImportPaths(meta) + targets := make(map[string]struct{}, len(rootPaths)) + for _, path := range rootPaths { + targets[path] = struct{}{} + } + roots, err := loadCustomPackagesFromMeta(ctx, req.WD, req.Env, req.Fset, meta, targets, rootPaths, req.ParseFile, discoveryDuration, "typed") + if err != nil { + return nil, err + } + return &PackageLoadResult{ + Packages: roots, + Backend: ModeCustom, + }, nil +} + +func loadCustomPackagesFromMeta(ctx context.Context, wd string, env []string, fset *token.FileSet, meta map[string]*packageMeta, targets map[string]struct{}, rootPaths []string, parseFile ParseFileFunc, discoveryDuration time.Duration, mode string) ([]*packages.Package, error) { + if len(meta) == 0 { + return nil, unsupportedError{reason: "empty go list result"} + } + if len(rootPaths) == 0 { + return nil, unsupportedError{reason: "no root packages from metadata"} + } + if fset == nil { + fset = token.NewFileSet() + } + l := newCustomTypedGraphLoader(ctx, wd, env, fset, meta, targets, parseFile, discoveryDuration) + roots, err := loadCustomRootPackages(l, rootPaths) + if err != nil { + return nil, err + } + logTypedLoadStats(ctx, mode, l.stats) + return roots, nil +} + +func loadCustomMeta(ctx context.Context, req goListRequest) (map[string]*packageMeta, time.Duration, error) { + start := time.Now() + meta, err := runGoList(ctx, req) + duration := time.Since(start) + if err != nil { + return nil, duration, err + } + if len(meta) == 0 { + return nil, duration, unsupportedError{reason: "empty go list result"} + } + return meta, duration, nil +} + +func loadCustomLazyMeta(ctx context.Context, req LazyLoadRequest) (map[string]*packageMeta, time.Duration, error) { + if req.Discovery != nil && len(req.Discovery.meta) > 0 { + return req.Discovery.meta, 0, nil + } + return loadCustomMeta(ctx, goListRequest{ + WD: req.WD, + Env: req.Env, + Tags: req.Tags, + Patterns: []string{req.Package}, + NeedDeps: true, + }) +} + +func loadCustomRootPackages(l *customTypedGraphLoader, paths []string) ([]*packages.Package, error) { + prefetchStart := time.Now() + l.prefetchArtifacts() + l.stats.artifactPrefetch = time.Since(prefetchStart) + + rootLoadStart := time.Now() + roots := make([]*packages.Package, 0, len(paths)) + for _, path := range paths { + root, err := l.loadPackage(path) + if err != nil { + return nil, err + } + roots = append(roots, root) + } + l.stats.rootLoad = time.Since(rootLoadStart) + sort.Slice(roots, func(i, j int) bool { return roots[i].PkgPath < roots[j].PkgPath }) + return roots, nil +} + +func (v *customValidator) validatePackage(path string) (*packages.Package, error) { + meta := v.meta[path] + if meta == nil { + return nil, unsupportedError{reason: "missing metadata for touched package"} + } + if v.loading[path] { + return nil, unsupportedError{reason: "touched package cycle"} + } + v.loading[path] = true + defer delete(v.loading, path) + pkg := packageStub(v.fset, meta) + if appendPackageMetaError(pkg, meta) { + return pkg, nil + } + files, errs := v.parseFiles(metaFiles(meta)) + pkg.Errors = append(pkg.Errors, errs...) + if len(files) == 0 { + return pkg, nil + } + + tpkg := types.NewPackage(meta.ImportPath, meta.Name) + v.packages[meta.ImportPath] = tpkg + info := newTypesInfo() + importer := importerFunc(func(importPath string) (*types.Package, error) { + if importPath == "unsafe" { + return types.Unsafe, nil + } + target := resolvedImportTarget(meta, importPath) + if _, ok := v.touched[target]; ok { + if typed := v.packages[target]; typed != nil && typed.Complete() { + if depMeta := v.meta[target]; depMeta != nil { + pkg.Imports[importPath] = touchedPackageStub(v.fset, depMeta) + } + return typed, nil + } + checked, err := v.validatePackage(target) + if err != nil { + return nil, err + } + pkg.Imports[importPath] = checked + if len(checked.Errors) > 0 { + return nil, fmt.Errorf("touched dependency %s has errors", target) + } + if typed := v.packages[target]; typed != nil { + return typed, nil + } + return nil, unsupportedError{reason: "missing typed touched dependency"} + } + dep, err := v.importFromExport(target) + if err == nil { + if depMeta := v.meta[target]; depMeta != nil { + pkg.Imports[importPath] = touchedPackageStub(v.fset, depMeta) + } else { + pkg.Imports[importPath] = &packages.Package{PkgPath: target, Name: dep.Name()} + } + } + return dep, err + }) + var typeErrors []packages.Error + cfg := &types.Config{ + Importer: importer, + Sizes: types.SizesFor("gc", runtime.GOARCH), + Error: func(err error) { + typeErrors = append(typeErrors, toPackagesError(v.fset, err)) + }, + } + checker := types.NewChecker(cfg, v.fset, tpkg, info) + if err := checker.Files(files); err != nil && len(typeErrors) == 0 { + typeErrors = append(typeErrors, toPackagesError(v.fset, err)) + } + pkg.Syntax = files + pkg.Types = tpkg + pkg.TypesInfo = info + typeErrors = append(typeErrors, v.validateDeclaredImports(meta, files)...) + pkg.Errors = append(pkg.Errors, typeErrors...) + return pkg, nil +} + +func (l *customTypedGraphLoader) loadPackage(path string) (*packages.Package, error) { + if path == "C" { + if pkg := l.packages[path]; pkg != nil { + return pkg, nil + } + tpkg := l.typesPkgs[path] + if tpkg == nil { + tpkg = types.NewPackage("C", "C") + l.typesPkgs[path] = tpkg + } + pkg := &packages.Package{ + ID: "C", + Name: "C", + PkgPath: "C", + Fset: l.fset, + Imports: make(map[string]*packages.Package), + Types: tpkg, + } + l.packages[path] = pkg + return pkg, nil + } + meta := l.meta[path] + if meta == nil { + return nil, unsupportedError{reason: "missing lazy-load metadata for " + path} + } + pkg := l.packages[path] + if l.loading[path] { + if pkg != nil { + return pkg, nil + } + return nil, unsupportedError{reason: "lazy-load cycle"} + } + if pkg != nil && (pkg.Types != nil || len(pkg.Errors) > 0) { + return pkg, nil + } + l.loading[path] = true + defer delete(l.loading, path) + l.stats.packages++ + _, isTarget := l.targets[path] + isLocal := l.isLocalPackage(path, meta) + if isLocal { + l.stats.localPackages++ + } else { + l.stats.externalPackages++ + } + + if pkg == nil { + pkg = packageStub(l.fset, meta) + l.packages[path] = pkg + } + artifactPolicy := l.artifactPolicy(meta, isTarget, isLocal) + if artifactPolicy.read { + if typed, ok := l.readArtifact(path, meta, isLocal); ok { + linkStart := time.Now() + if err := l.linkPackageImports(pkg, meta); err != nil { + return nil, err + } + l.stats.artifactImportLink += time.Since(linkStart) + pkg.Types = typed + pkg.TypesInfo = nil + pkg.Syntax = nil + return pkg, nil + } + } + files, parseErrs := l.parseFiles(metaFiles(meta), isLocal) + pkg.Errors = append(pkg.Errors, parseErrs...) + if len(files) == 0 { + appendPackageMetaError(pkg, meta) + return pkg, nil + } + + tpkg := l.typesPkgs[path] + if tpkg == nil || tpkg.Complete() || (tpkg.Scope() != nil && len(tpkg.Scope().Names()) > 0) { + tpkg = types.NewPackage(meta.ImportPath, meta.Name) + l.typesPkgs[path] = tpkg + } + needFullState := isTarget || isLocal + var info *types.Info + if needFullState { + info = newTypesInfo() + } + var typeErrors []packages.Error + cfg := &types.Config{ + Sizes: types.SizesFor("gc", runtime.GOARCH), + IgnoreFuncBodies: !isLocal, + Importer: importerFunc(func(importPath string) (*types.Package, error) { + if importPath == "unsafe" { + return types.Unsafe, nil + } + target := resolvedImportTarget(meta, importPath) + dep, err := l.loadPackage(target) + if err != nil { + return nil, err + } + pkg.Imports[importPath] = dep + if dep.Types != nil { + return dep.Types, nil + } + if typed := l.typesPkgs[target]; typed != nil { + return typed, nil + } + if len(dep.Errors) > 0 { + return nil, dependencyImportError(dep) + } + return nil, unsupportedError{reason: "missing typed lazy-load dependency"} + }), + Error: func(err error) { + typeErrors = append(typeErrors, toPackagesError(l.fset, err)) + }, + } + checker := types.NewChecker(cfg, l.fset, tpkg, info) + typecheckStart := time.Now() + if err := l.checkFiles(path, checker, files); err != nil && len(typeErrors) == 0 { + typeErrors = append(typeErrors, toPackagesError(l.fset, err)) + } + typecheckDuration := time.Since(typecheckStart) + l.stats.typecheck += typecheckDuration + if isLocal { + l.stats.localTypecheck += typecheckDuration + } else { + l.stats.externalTypecheck += typecheckDuration + } + if needFullState { + pkg.Syntax = files + } else { + pkg.Syntax = nil + } + pkg.Types = tpkg + pkg.TypesInfo = info + pkg.Errors = append(pkg.Errors, typeErrors...) + if artifactPolicy.write && len(pkg.Errors) == 0 { + _ = l.writeArtifact(meta, tpkg, isLocal) + } + return pkg, nil +} + +func (l *customTypedGraphLoader) artifactPolicy(meta *packageMeta, isTarget, isLocal bool) artifactPolicy { + if !loaderArtifactEnabled(l.env) || isTarget { + return artifactPolicy{} + } + policy := artifactPolicy{write: true} + if !isLocal { + policy.read = true + } + return policy +} + +func (l *customTypedGraphLoader) linkPackageImports(pkg *packages.Package, meta *packageMeta) error { + for _, imp := range meta.Imports { + dep, err := l.loadPackage(resolvedImportTarget(meta, imp)) + if err != nil { + return err + } + pkg.Imports[imp] = dep + } + return nil +} + +func (l *customTypedGraphLoader) checkFiles(path string, checker *types.Checker, files []*ast.File) (err error) { + defer func() { + if r := recover(); r != nil { + err = unsupportedError{reason: fmt.Sprintf("typecheck panic in %s: %v", path, r)} + } + }() + return checker.Files(files) +} + +func (l *customTypedGraphLoader) readArtifact(path string, meta *packageMeta, isLocal bool) (*types.Package, bool) { + start := time.Now() + entry, prefetched := l.artifactPrefetch[path] + artifactPath := "" + if prefetched { + artifactPath = entry.path + if entry.err != nil { + debugf(l.ctx, "loader.artifact.read_miss pkg=%s local=%t reason=prefetch_error err=%v", path, isLocal, entry.err) + l.stats.artifactMisses++ + return nil, false + } + } else { + pathStart := time.Now() + var err error + artifactPath, err = loaderArtifactPath(l.env, meta, isLocal) + l.stats.artifactPath += time.Since(pathStart) + if err != nil { + debugf(l.ctx, "loader.artifact.read_miss pkg=%s local=%t reason=path_error err=%v", path, isLocal, err) + l.stats.artifactRead += time.Since(start) + l.stats.artifactMisses++ + return nil, false + } + } + if isLocal { + preloadStart := time.Now() + for _, imp := range meta.Imports { + target := imp + if mapped := meta.ImportMap[imp]; mapped != "" { + target = mapped + } + dep, err := l.loadPackage(target) + if err != nil { + debugf(l.ctx, "loader.artifact.read_miss pkg=%s local=%t reason=preload_dep_error dep=%s err=%v", path, isLocal, target, err) + l.stats.artifactRead += time.Since(start) + l.stats.artifactMisses++ + return nil, false + } + if dep != nil && dep.Types != nil { + l.typesPkgs[target] = dep.Types + } + } + l.stats.artifactImportLink += time.Since(preloadStart) + } + var tpkg *types.Package + decodeStart := time.Now() + var err error + if prefetched { + tpkg, err = readLoaderArtifactData(entry.data, l.fset, l.typesPkgs, path) + } else { + tpkg, err = readLoaderArtifact(artifactPath, l.fset, l.typesPkgs, path) + } + l.stats.artifactDecode += time.Since(decodeStart) + if err != nil { + debugf(l.ctx, "loader.artifact.read_miss pkg=%s local=%t reason=decode_error err=%v", path, isLocal, err) + l.stats.artifactRead += time.Since(start) + l.stats.artifactMisses++ + return nil, false + } + if !prefetched { + l.stats.artifactRead += time.Since(start) + } + l.stats.artifactHits++ + l.typesPkgs[path] = tpkg + return tpkg, true +} + +func (l *customTypedGraphLoader) prefetchArtifacts() { + if !loaderArtifactEnabled(l.env) { + return + } + candidates := make([]string, 0, len(l.meta)) + for path, meta := range l.meta { + _, isTarget := l.targets[path] + isLocal := l.isLocalPackage(path, meta) + if l.artifactPolicy(meta, isTarget, isLocal).read { + candidates = append(candidates, path) + } + } + sort.Strings(candidates) + if len(candidates) == 0 { + return + } + type result struct { + pkg string + entry artifactPrefetchEntry + dur time.Duration + } + jobs := make(chan string, len(candidates)) + results := make(chan result, len(candidates)) + workers := 8 + if len(candidates) < workers { + workers = len(candidates) + } + var wg sync.WaitGroup + for i := 0; i < workers; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for path := range jobs { + start := time.Now() + meta := l.meta[path] + isLocal := l.isLocalPackage(path, meta) + artifactPath, err := loaderArtifactPath(l.env, meta, isLocal) + entry := artifactPrefetchEntry{path: artifactPath} + if err == nil { + data, readErr := os.ReadFile(artifactPath) + if readErr != nil { + entry.err = readErr + } else { + entry.data = data + entry.ok = true + } + } else { + entry.err = err + } + results <- result{pkg: path, entry: entry, dur: time.Since(start)} + } + }() + } + for _, path := range candidates { + jobs <- path + } + close(jobs) + wg.Wait() + close(results) + for res := range results { + l.artifactPrefetch[res.pkg] = res.entry + l.stats.artifactRead += res.dur + pathStart := time.Now() + _ = res.entry.path + l.stats.artifactPath += time.Since(pathStart) + } +} + +func (l *customTypedGraphLoader) writeArtifact(meta *packageMeta, pkg *types.Package, isLocal bool) error { + start := time.Now() + artifactPath, err := loaderArtifactPath(l.env, meta, isLocal) + if err != nil { + debugf(l.ctx, "loader.artifact.write_skip pkg=%s local=%t reason=path_error err=%v", meta.ImportPath, isLocal, err) + l.stats.artifactWrite += time.Since(start) + return err + } + if artifactUpToDate(l.env, artifactPath, meta, isLocal) { + debugf(l.ctx, "loader.artifact.write_skip pkg=%s local=%t reason=up_to_date", meta.ImportPath, isLocal) + l.stats.artifactWrite += time.Since(start) + return nil + } + writeErr := writeLoaderArtifact(artifactPath, l.fset, pkg) + l.stats.artifactWrite += time.Since(start) + if writeErr == nil { + l.stats.artifactWrites++ + debugf(l.ctx, "loader.artifact.write_ok pkg=%s local=%t path=%s", meta.ImportPath, isLocal, artifactPath) + } else { + debugf(l.ctx, "loader.artifact.write_fail pkg=%s local=%t err=%v", meta.ImportPath, isLocal, writeErr) + } + if writeErr != nil { + return writeErr + } + return nil +} + +func (l *customTypedGraphLoader) isLocalPackage(importPath string, meta *packageMeta) bool { + if local, ok := l.isLocalCache[importPath]; ok { + return local + } + local := isLocalSourcePackage(l.workspace, meta) + l.isLocalCache[importPath] = local + return local +} + +func (v *customValidator) importFromExport(path string) (*types.Package, error) { + if typed := v.packages[path]; typed != nil && typed.Complete() { + return typed, nil + } + if v.importer != nil { + if imported, err := v.importer.Import(path); err == nil { + v.packages[path] = imported + return imported, nil + } + } + meta := v.meta[path] + if meta == nil { + return nil, unsupportedError{reason: "missing dependency metadata"} + } + if meta.Export == "" { + return v.loadDependencyFromSource(path) + } + exportPath := meta.Export + if !filepath.IsAbs(exportPath) { + exportPath = filepath.Join(meta.Dir, exportPath) + } + f, err := os.Open(exportPath) + if err != nil { + return nil, unsupportedError{reason: "open export data"} + } + defer f.Close() + r, err := gcexportdata.NewReader(f) + if err != nil { + return nil, unsupportedError{reason: "read export data"} + } + view := make(map[string]*types.Package, len(v.packages)) + for pkgPath, pkg := range v.packages { + view[pkgPath] = pkg + } + tpkg, err := gcexportdata.Read(r, v.fset, view, path) + if err != nil { + return v.loadDependencyFromSource(path) + } + v.packages[path] = tpkg + return tpkg, nil +} + +func (v *customValidator) loadDependencyFromSource(path string) (*types.Package, error) { + if typed := v.packages[path]; typed != nil && typed.Complete() { + return typed, nil + } + meta := v.meta[path] + if meta == nil { + return nil, unsupportedError{reason: "missing source dependency metadata"} + } + if v.loading[path] { + if typed := v.packages[path]; typed != nil { + return typed, nil + } + return nil, unsupportedError{reason: "dependency cycle"} + } + v.loading[path] = true + defer delete(v.loading, path) + + tpkg := v.packages[path] + if tpkg == nil { + tpkg = types.NewPackage(meta.ImportPath, meta.Name) + v.packages[path] = tpkg + } + files, errs := v.parseFiles(metaFiles(meta)) + if len(errs) > 0 { + return nil, unsupportedError{reason: "dependency parse error"} + } + info := newTypesInfo() + cfg := &types.Config{ + Importer: importerFunc(func(importPath string) (*types.Package, error) { + if importPath == "unsafe" { + return types.Unsafe, nil + } + target := resolvedImportTarget(meta, importPath) + if _, ok := v.touched[target]; ok { + checked, err := v.validatePackage(target) + if err != nil { + return nil, err + } + if len(checked.Errors) > 0 { + return nil, unsupportedError{reason: "touched dependency has validation errors"} + } + return v.packages[target], nil + } + return v.importFromExport(target) + }), + Sizes: types.SizesFor("gc", runtime.GOARCH), + IgnoreFuncBodies: true, + } + if err := types.NewChecker(cfg, v.fset, tpkg, info).Files(files); err != nil { + return nil, unsupportedError{reason: "dependency typecheck error"} + } + return tpkg, nil +} + +func (v *customValidator) parseFiles(names []string) ([]*ast.File, []packages.Error) { + files := make([]*ast.File, 0, len(names)) + var errs []packages.Error + for _, name := range names { + src, err := os.ReadFile(name) + if err != nil { + errs = append(errs, packages.Error{Pos: name + ":1", Msg: err.Error(), Kind: packages.ParseError}) + continue + } + f, err := parseGoSourceFile(v.fset, nil, name, src) + if err != nil { + errs = appendParseErrors(errs, name, err) + } + if f != nil { + files = append(files, f) + } + } + return files, errs +} + +func (l *customTypedGraphLoader) parseFiles(names []string, isLocal bool) ([]*ast.File, []packages.Error) { + files := make([]*ast.File, 0, len(names)) + var errs []packages.Error + for _, name := range names { + readStart := time.Now() + src, err := os.ReadFile(name) + readDuration := time.Since(readStart) + l.stats.read += readDuration + l.stats.filesRead++ + if isLocal { + l.stats.localRead += readDuration + l.stats.localFilesRead++ + } else { + l.stats.externalRead += readDuration + l.stats.externalFilesRead++ + } + if err != nil { + errs = append(errs, packages.Error{Pos: name + ":1", Msg: err.Error(), Kind: packages.ParseError}) + continue + } + var f *ast.File + parseStart := time.Now() + f, err = parseGoSourceFile(l.fset, l.parseFile, name, src) + parseDuration := time.Since(parseStart) + l.stats.parse += parseDuration + if isLocal { + l.stats.localParse += parseDuration + } else { + l.stats.externalParse += parseDuration + } + if err != nil { + errs = appendParseErrors(errs, name, err) + } + if f != nil { + files = append(files, f) + } + } + return files, errs +} + +func toPackagesError(fset *token.FileSet, err error) packages.Error { + switch typed := err.(type) { + case packages.Error: + return typed + case types.Error: + return packages.Error{ + Pos: typed.Fset.Position(typed.Pos).String(), + Msg: typed.Msg, + Kind: packages.TypeError, + } + default: + pos := "-" + if fset != nil { + if te, ok := err.(interface{ Pos() token.Pos }); ok { + pos = fset.Position(te.Pos()).String() + } + } + return packages.Error{Pos: pos, Msg: err.Error(), Kind: packages.UnknownError} + } +} + +func dependencyImportError(pkg *packages.Package) error { + if pkg == nil { + return unsupportedError{reason: "lazy-load dependency has errors"} + } + if pkg.Name == "" { + return fmt.Errorf("invalid package name: %q", pkg.Name) + } + for _, err := range pkg.Errors { + if strings.TrimSpace(err.Msg) == "" { + continue + } + return fmt.Errorf("%s", err.Msg) + } + return unsupportedError{reason: "lazy-load dependency has errors"} +} + +type importerFunc func(path string) (*types.Package, error) + +func (f importerFunc) Import(path string) (*types.Package, error) { return f(path) } + +func (v *customValidator) validateDeclaredImports(meta *packageMeta, files []*ast.File) []packages.Error { + var errs []packages.Error + for _, file := range files { + used := usedImportsInFile(file) + for _, spec := range file.Imports { + if spec == nil || spec.Path == nil { + continue + } + path := strings.Trim(spec.Path.Value, "\"") + if path == "" { + continue + } + target := resolvedImportTarget(meta, path) + name := importName(spec) + if name != "_" && name != "." { + if _, ok := used[name]; !ok { + errs = append(errs, packages.Error{ + Pos: v.fset.Position(spec.Pos()).String(), + Msg: fmt.Sprintf("%q imported and not used", path), + Kind: packages.TypeError, + }) + continue + } + } + if _, err := v.importFromExport(target); err != nil { + errs = append(errs, packages.Error{ + Pos: v.fset.Position(spec.Pos()).String(), + Msg: fmt.Sprintf("could not import %s", path), + Kind: packages.TypeError, + }) + } + } + } + return errs +} + +func usedImportsInFile(file *ast.File) map[string]struct{} { + used := make(map[string]struct{}) + ast.Inspect(file, func(node ast.Node) bool { + sel, ok := node.(*ast.SelectorExpr) + if !ok { + return true + } + ident, ok := sel.X.(*ast.Ident) + if !ok || ident.Name == "" { + return true + } + used[ident.Name] = struct{}{} + return true + }) + return used +} + +func importName(spec *ast.ImportSpec) string { + if spec == nil || spec.Path == nil { + return "" + } + if spec.Name != nil && spec.Name.Name != "" { + return spec.Name.Name + } + path := strings.Trim(spec.Path.Value, "\"") + if path == "" { + return "" + } + if slash := strings.LastIndex(path, "/"); slash >= 0 { + path = path[slash+1:] + } + return path +} + +func discoverTouchedMetadata(ctx context.Context, req TouchedValidationRequest) (map[string]*packageMeta, error) { + metas, _, err := loadCustomMeta(ctx, goListRequest{ + WD: req.WD, + Env: req.Env, + Tags: req.Tags, + Patterns: req.Touched, + NeedDeps: true, + }) + if err != nil { + return nil, err + } + for _, touched := range req.Touched { + if _, ok := metas[touched]; !ok { + return nil, unsupportedError{reason: "missing touched package in metadata"} + } + } + return metas, nil +} + +func normalizeImports(imports []string, importMap map[string]string) []string { + if len(imports) == 0 { + return nil + } + out := make([]string, 0, len(imports)) + for _, imp := range imports { + if mapped := importMap[imp]; mapped != "" { + out = append(out, mapped) + continue + } + out = append(out, imp) + } + sort.Strings(out) + return out +} + +func metaFiles(meta *packageMeta) []string { + if meta == nil { + return nil + } + if len(meta.CompiledGoFiles) > 0 { + return meta.CompiledGoFiles + } + return meta.GoFiles +} + +func discoverySnapshotForMeta(meta map[string]*packageMeta, complete bool) *DiscoverySnapshot { + if !complete || len(meta) == 0 { + return nil + } + return &DiscoverySnapshot{meta: meta} +} + +func isWorkspacePackage(workspaceRoot, dir string) bool { + if workspaceRoot == "" || dir == "" { + return false + } + if dir == workspaceRoot { + return true + } + rel, err := filepath.Rel(workspaceRoot, dir) + if err != nil { + return false + } + return rel != ".." && !strings.HasPrefix(rel, ".."+string(filepath.Separator)) +} + +func isLocalSourcePackage(workspaceRoot string, meta *packageMeta) bool { + if meta == nil { + return false + } + if isWorkspacePackage(workspaceRoot, meta.Dir) { + return true + } + mod := localSourceModule(meta.Module) + if mod == nil { + return false + } + if mod.Main { + return true + } + return canonicalLoaderPath(mod.Dir) == canonicalLoaderPath(meta.Dir) || isWorkspacePackage(canonicalLoaderPath(mod.Dir), meta.Dir) +} + +func localSourceModule(mod *goListModule) *goListModule { + if mod == nil { + return nil + } + if mod.Replace != nil { + if local := localSourceModule(mod.Replace); local != nil { + return local + } + } + if mod.Main && mod.Dir != "" { + return mod + } + if mod.Replace != nil && mod.Replace.Dir != "" { + return mod.Replace + } + return nil +} + +func detectModuleRoot(start string) string { + start = canonicalLoaderPath(start) + for dir := start; dir != "" && dir != "." && dir != string(filepath.Separator); dir = filepath.Dir(dir) { + if info, err := os.Stat(filepath.Join(dir, "go.mod")); err == nil && !info.IsDir() { + return dir + } + next := filepath.Dir(dir) + if next == dir { + break + } + } + return start +} + +func canonicalLoaderPath(path string) string { + path = filepath.Clean(path) + if resolved, err := filepath.EvalSymlinks(path); err == nil && resolved != "" { + return filepath.Clean(resolved) + } + return path +} + +func startLoaderCPUProfile(env []string) (func(), error) { + path := envValue(env, "WIRE_LOADER_CPU_PROFILE") + if strings.TrimSpace(path) == "" { + return nil, nil + } + f, err := os.Create(path) + if err != nil { + return nil, err + } + if err := pprof.StartCPUProfile(f); err != nil { + _ = f.Close() + return nil, err + } + return func() { + pprof.StopCPUProfile() + _ = f.Close() + }, nil +} + +func envValue(env []string, key string) string { + for i := len(env) - 1; i >= 0; i-- { + name, value, ok := strings.Cut(env[i], "=") + if ok && name == key { + return value + } + } + return "" +} + +func newCustomTypedGraphLoader(ctx context.Context, wd string, env []string, fset *token.FileSet, meta map[string]*packageMeta, targets map[string]struct{}, parseFile ParseFileFunc, discoveryDuration time.Duration) *customTypedGraphLoader { + return &customTypedGraphLoader{ + workspace: detectModuleRoot(wd), + ctx: ctx, + env: append([]string(nil), env...), + fset: fset, + meta: meta, + targets: targets, + parseFile: parseFile, + packages: make(map[string]*packages.Package, len(meta)), + typesPkgs: make(map[string]*types.Package, len(meta)), + importer: importerpkg.ForCompiler(token.NewFileSet(), "gc", nil), + loading: make(map[string]bool, len(meta)), + isLocalCache: make(map[string]bool, len(meta)), + artifactPrefetch: make(map[string]artifactPrefetchEntry, len(meta)), + stats: typedLoadStats{discovery: discoveryDuration}, + } +} + +func packageStub(fset *token.FileSet, meta *packageMeta) *packages.Package { + if meta == nil { + return nil + } + return &packages.Package{ + ID: meta.ImportPath, + Name: meta.Name, + PkgPath: meta.ImportPath, + Fset: fset, + GoFiles: append([]string(nil), metaFiles(meta)...), + CompiledGoFiles: append([]string(nil), metaFiles(meta)...), + Imports: make(map[string]*packages.Package), + ExportFile: meta.Export, + } +} + +func packageStubGraphFromMeta(fset *token.FileSet, meta map[string]*packageMeta) map[string]*packages.Package { + pkgs := make(map[string]*packages.Package, len(meta)) + for path, m := range meta { + pkgs[path] = packageStub(fset, m) + appendPackageMetaError(pkgs[path], m) + } + for path, m := range meta { + pkg := pkgs[path] + for _, imp := range m.Imports { + target := resolvedImportTarget(m, imp) + if dep := pkgs[target]; dep != nil { + pkg.Imports[imp] = dep + } + } + } + return pkgs +} + +func appendPackageMetaError(pkg *packages.Package, meta *packageMeta) bool { + if pkg == nil || meta == nil || meta.Error == nil || strings.TrimSpace(meta.Error.Err) == "" { + return false + } + pkg.Errors = append(pkg.Errors, packages.Error{ + Pos: "-", + Msg: meta.Error.Err, + Kind: packages.ListError, + }) + return true +} + +func resolvedImportTarget(meta *packageMeta, importPath string) string { + if meta == nil { + return importPath + } + if mapped := meta.ImportMap[importPath]; mapped != "" { + return mapped + } + return importPath +} + +func newTypesInfo() *types.Info { + return &types.Info{ + Types: make(map[ast.Expr]types.TypeAndValue), + Defs: make(map[*ast.Ident]types.Object), + Uses: make(map[*ast.Ident]types.Object), + Implicits: make(map[ast.Node]types.Object), + Scopes: make(map[ast.Node]*types.Scope), + Selections: make(map[*ast.SelectorExpr]*types.Selection), + } +} + +func parseGoSourceFile(fset *token.FileSet, parseFile ParseFileFunc, name string, src []byte) (*ast.File, error) { + if parseFile != nil { + return parseFile(fset, name, src) + } + return parser.ParseFile(fset, name, src, parser.AllErrors|parser.ParseComments) +} + +func appendParseErrors(errs []packages.Error, name string, err error) []packages.Error { + switch typed := err.(type) { + case scanner.ErrorList: + for _, parseErr := range typed { + errs = append(errs, packages.Error{ + Pos: parseErr.Pos.String(), + Msg: parseErr.Msg, + Kind: packages.ParseError, + }) + } + default: + errs = append(errs, packages.Error{Pos: name + ":1", Msg: err.Error(), Kind: packages.ParseError}) + } + return errs +} + +func nonDepRootImportPaths(meta map[string]*packageMeta) []string { + paths := make([]string, 0, len(meta)) + for _, m := range meta { + if m == nil || m.DepOnly { + continue + } + paths = append(paths, m.ImportPath) + } + sort.Strings(paths) + return paths +} + +func logTypedLoadStats(ctx context.Context, mode string, stats typedLoadStats) { + prefix := "loader.custom." + mode + logDuration(ctx, prefix+".read_files.cumulative", stats.read) + logDuration(ctx, prefix+".parse_files.cumulative", stats.parse) + logDuration(ctx, prefix+".typecheck.cumulative", stats.typecheck) + logDuration(ctx, prefix+".read_files.local.cumulative", stats.localRead) + logDuration(ctx, prefix+".read_files.external.cumulative", stats.externalRead) + logDuration(ctx, prefix+".parse_files.local.cumulative", stats.localParse) + logDuration(ctx, prefix+".parse_files.external.cumulative", stats.externalParse) + logDuration(ctx, prefix+".typecheck.local.cumulative", stats.localTypecheck) + logDuration(ctx, prefix+".typecheck.external.cumulative", stats.externalTypecheck) + logDuration(ctx, prefix+".artifact_read", stats.artifactRead) + logDuration(ctx, prefix+".artifact_path", stats.artifactPath) + logDuration(ctx, prefix+".artifact_decode", stats.artifactDecode) + logDuration(ctx, prefix+".artifact_import_link", stats.artifactImportLink) + logDuration(ctx, prefix+".artifact_write", stats.artifactWrite) + logDuration(ctx, prefix+".artifact_prefetch.wall", stats.artifactPrefetch) + logDuration(ctx, prefix+".root_load.wall", stats.rootLoad) + logDuration(ctx, prefix+".discovery.wall", stats.discovery) + logInt(ctx, prefix+".artifact_hits", stats.artifactHits) + logInt(ctx, prefix+".artifact_misses", stats.artifactMisses) + logInt(ctx, prefix+".artifact_writes", stats.artifactWrites) +} + +func touchedPackageStub(fset *token.FileSet, meta *packageMeta) *packages.Package { + return packageStub(fset, meta) +} + +func metadataMatchesFingerprint(pkgPath string, meta map[string]*packageMeta, local []LocalPackageFingerprint) bool { + for _, fp := range local { + if fp.PkgPath != pkgPath { + continue + } + pm := meta[pkgPath] + if pm == nil { + return false + } + want := append([]string(nil), fp.Files...) + got := append([]string(nil), metaFiles(pm)...) + sort.Strings(want) + sort.Strings(got) + if len(want) != len(got) { + return false + } + for i := range want { + if filepath.Clean(want[i]) != filepath.Clean(got[i]) { + return false + } + } + return true + } + return true +} diff --git a/internal/loader/discovery.go b/internal/loader/discovery.go new file mode 100644 index 0000000..bccfd93 --- /dev/null +++ b/internal/loader/discovery.go @@ -0,0 +1,134 @@ +// Copyright 2026 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package loader + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "os" + "os/exec" + "path/filepath" + "time" +) + +type goListRequest struct { + WD string + Env []string + Tags string + Patterns []string + NeedDeps bool + SkipCompiled bool +} + +func runGoList(ctx context.Context, req goListRequest) (map[string]*packageMeta, error) { + cacheReadStart := time.Now() + if cached, ok := readDiscoveryCache(req); ok { + logDuration(ctx, "loader.discovery.cache_read.wall", time.Since(cacheReadStart)) + logDuration(ctx, "loader.discovery.golist.wall", 0) + logDuration(ctx, "loader.discovery.decode.wall", 0) + logDuration(ctx, "loader.discovery.canonicalize.wall", 0) + logDuration(ctx, "loader.discovery.cache_build.wall", 0) + logDuration(ctx, "loader.discovery.cache_write.wall", 0) + return cached, nil + } + logDuration(ctx, "loader.discovery.cache_read.wall", time.Since(cacheReadStart)) + args := []string{"list", "-json", "-e", "-export"} + if !req.SkipCompiled { + args = append(args, "-compiled") + } + if req.NeedDeps { + args = append(args, "-deps") + } + if req.Tags != "" { + args = append(args, "-tags=wireinject "+req.Tags) + } else { + args = append(args, "-tags=wireinject") + } + args = append(args, "--") + args = append(args, req.Patterns...) + + cmd := exec.CommandContext(ctx, "go", args...) + cmd.Dir = req.WD + if len(req.Env) > 0 { + cmd.Env = req.Env + } else { + cmd.Env = os.Environ() + } + var stdout bytes.Buffer + var stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + goListStart := time.Now() + if err := cmd.Run(); err != nil { + return nil, fmt.Errorf("go list: %w: %s", err, stderr.String()) + } + goListDuration := time.Since(goListStart) + dec := json.NewDecoder(&stdout) + out := make(map[string]*packageMeta) + var decodeDuration time.Duration + var canonicalizeDuration time.Duration + for { + var meta packageMeta + decodeStart := time.Now() + if err := dec.Decode(&meta); err != nil { + decodeDuration += time.Since(decodeStart) + if err == io.EOF { + break + } + return nil, err + } + decodeDuration += time.Since(decodeStart) + if meta.ImportPath == "" { + continue + } + canonicalizeStart := time.Now() + meta.Dir = canonicalLoaderPath(meta.Dir) + for i, name := range meta.GoFiles { + if !filepath.IsAbs(name) { + meta.GoFiles[i] = filepath.Join(meta.Dir, name) + } + } + for i, name := range meta.CompiledGoFiles { + if !filepath.IsAbs(name) { + meta.CompiledGoFiles[i] = filepath.Join(meta.Dir, name) + } + } + if meta.Export != "" && !filepath.IsAbs(meta.Export) { + meta.Export = filepath.Join(meta.Dir, meta.Export) + } + meta.Imports = normalizeImports(meta.Imports, meta.ImportMap) + canonicalizeDuration += time.Since(canonicalizeStart) + copyMeta := meta + out[meta.ImportPath] = ©Meta + } + cacheBuildStart := time.Now() + entry, err := buildDiscoveryCacheEntry(req, out) + cacheBuildDuration := time.Since(cacheBuildStart) + if err == nil && entry != nil { + cacheWriteStart := time.Now() + _ = saveDiscoveryCacheEntry(req, entry) + logDuration(ctx, "loader.discovery.cache_write.wall", time.Since(cacheWriteStart)) + } else { + logDuration(ctx, "loader.discovery.cache_write.wall", 0) + } + logDuration(ctx, "loader.discovery.golist.wall", goListDuration) + logDuration(ctx, "loader.discovery.decode.wall", decodeDuration) + logDuration(ctx, "loader.discovery.canonicalize.wall", canonicalizeDuration) + logDuration(ctx, "loader.discovery.cache_build.wall", cacheBuildDuration) + return out, nil +} diff --git a/internal/loader/discovery_cache.go b/internal/loader/discovery_cache.go new file mode 100644 index 0000000..3381041 --- /dev/null +++ b/internal/loader/discovery_cache.go @@ -0,0 +1,302 @@ +package loader + +import ( + "bytes" + "crypto/sha256" + "encoding/gob" + "encoding/hex" + "go/parser" + "go/token" + "os" + "path/filepath" + "runtime" + "sort" + + "github.com/goforj/wire/internal/cachepaths" +) + +type discoveryCacheEntry struct { + Version int + Meta map[string]*packageMeta + Global []discoveryFileMeta + LocalPkgs []discoveryLocalPackage +} + +type discoveryLocalPackage struct { + ImportPath string + Dir string + DirMeta discoveryDirMeta + Files []discoveryFileFingerprint +} + +type discoveryFileMeta struct { + Path string + Size int64 + ModTime int64 // deprecated: kept for gob compat, not used for matching + ContentHash string // sha256 of file content + IsDir bool +} + +type discoveryDirMeta struct { + Path string + Entries []string +} + +type discoveryFileFingerprint struct { + Path string + Hash string +} + +const discoveryCacheVersion = 4 + +func readDiscoveryCache(req goListRequest) (map[string]*packageMeta, bool) { + entry, err := loadDiscoveryCacheEntry(req) + if err != nil || entry == nil { + return nil, false + } + if !validateDiscoveryCacheEntry(entry) { + return nil, false + } + return entry.Meta, true +} + +func buildDiscoveryCacheEntry(req goListRequest, meta map[string]*packageMeta) (*discoveryCacheEntry, error) { + workspace := detectModuleRoot(req.WD) + entry := &discoveryCacheEntry{ + Version: discoveryCacheVersion, + Meta: meta, + } + global := []string{ + filepath.Join(workspace, "go.mod"), + filepath.Join(workspace, "go.sum"), + filepath.Join(workspace, "go.work"), + filepath.Join(workspace, "go.work.sum"), + } + for _, name := range global { + if fm, ok := statDiscoveryFile(name); ok { + entry.Global = append(entry.Global, fm) + } + } + locals := make([]discoveryLocalPackage, 0) + for _, pkg := range meta { + if pkg == nil || !isLocalSourcePackage(workspace, pkg) { + continue + } + lp := discoveryLocalPackage{ + ImportPath: pkg.ImportPath, + Dir: pkg.Dir, + } + if fm, ok := statDiscoveryDir(pkg.Dir); ok { + lp.DirMeta = fm + } + for _, name := range metaFiles(pkg) { + if fm, ok := fingerprintDiscoveryFile(name); ok { + lp.Files = append(lp.Files, fm) + } + } + sort.Slice(lp.Files, func(i, j int) bool { return lp.Files[i].Path < lp.Files[j].Path }) + locals = append(locals, lp) + } + sort.Slice(locals, func(i, j int) bool { return locals[i].ImportPath < locals[j].ImportPath }) + entry.LocalPkgs = locals + return entry, nil +} + +func validateDiscoveryCacheEntry(entry *discoveryCacheEntry) bool { + if entry == nil || entry.Version != discoveryCacheVersion { + return false + } + for _, fm := range entry.Global { + if !matchesDiscoveryFile(fm) { + return false + } + } + for _, lp := range entry.LocalPkgs { + if !matchesDiscoveryDir(lp.DirMeta) { + return false + } + for _, fm := range lp.Files { + if !matchesDiscoveryFingerprint(fm) { + return false + } + } + } + return true +} + +const discoveryCacheDirEnv = cachepaths.DiscoveryCacheDirEnv + +func discoveryCachePath(req goListRequest) (string, error) { + dir, err := cachepaths.Dir(req.Env, discoveryCacheDirEnv, "discovery-cache") + if err != nil { + return "", err + } + sumReq := struct { + Version int + WD string + Tags string + Patterns []string + NeedDeps bool + SkipCompiled bool + Go string + }{ + Version: discoveryCacheVersion, + WD: canonicalLoaderPath(req.WD), + Tags: req.Tags, + Patterns: append([]string(nil), req.Patterns...), + NeedDeps: req.NeedDeps, + SkipCompiled: req.SkipCompiled, + Go: runtime.Version(), + } + key, err := hashGob(sumReq) + if err != nil { + return "", err + } + return filepath.Join(dir, key+".gob"), nil +} + +func loadDiscoveryCacheEntry(req goListRequest) (*discoveryCacheEntry, error) { + path, err := discoveryCachePath(req) + if err != nil { + return nil, err + } + f, err := os.Open(path) + if err != nil { + return nil, err + } + defer f.Close() + var entry discoveryCacheEntry + if err := gob.NewDecoder(f).Decode(&entry); err != nil { + return nil, err + } + return &entry, nil +} + +func saveDiscoveryCacheEntry(req goListRequest, entry *discoveryCacheEntry) error { + path, err := discoveryCachePath(req) + if err != nil { + return err + } + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + return err + } + f, err := os.Create(path) + if err != nil { + return err + } + defer f.Close() + return gob.NewEncoder(f).Encode(entry) +} + +func statDiscoveryFile(path string) (discoveryFileMeta, bool) { + info, err := os.Stat(path) + if err != nil { + return discoveryFileMeta{}, false + } + h := "" + if !info.IsDir() { + var err error + h, err = hashFileContent(path) + if err != nil { + return discoveryFileMeta{}, false + } + } + return discoveryFileMeta{ + Path: canonicalLoaderPath(path), + Size: info.Size(), + ContentHash: h, + IsDir: info.IsDir(), + }, true +} + +func matchesDiscoveryFile(fm discoveryFileMeta) bool { + cur, ok := statDiscoveryFile(fm.Path) + if !ok { + return false + } + return cur.ContentHash == fm.ContentHash && cur.IsDir == fm.IsDir +} + +func statDiscoveryDir(path string) (discoveryDirMeta, bool) { + entries, err := os.ReadDir(path) + if err != nil { + return discoveryDirMeta{}, false + } + names := make([]string, 0, len(entries)) + for _, entry := range entries { + names = append(names, entry.Name()) + } + sort.Strings(names) + return discoveryDirMeta{ + Path: canonicalLoaderPath(path), + Entries: names, + }, true +} + +func matchesDiscoveryDir(dm discoveryDirMeta) bool { + cur, ok := statDiscoveryDir(dm.Path) + if !ok { + return false + } + if len(cur.Entries) != len(dm.Entries) { + return false + } + for i := range cur.Entries { + if cur.Entries[i] != dm.Entries[i] { + return false + } + } + return true +} + +func fingerprintDiscoveryFile(path string) (discoveryFileFingerprint, bool) { + src, err := os.ReadFile(path) + if err != nil { + return discoveryFileFingerprint{}, false + } + sum := sha256.New() + sum.Write([]byte(filepath.Base(path))) + sum.Write([]byte{0}) + file, err := parser.ParseFile(token.NewFileSet(), path, src, parser.ImportsOnly|parser.ParseComments) + if err != nil { + sum.Write(src) + return discoveryFileFingerprint{ + Path: canonicalLoaderPath(path), + Hash: hex.EncodeToString(sum.Sum(nil)), + }, true + } + if offset := int(file.Package) - 1; offset > 0 && offset <= len(src) { + sum.Write(src[:offset]) + } + sum.Write([]byte(file.Name.Name)) + sum.Write([]byte{0}) + for _, imp := range file.Imports { + if imp.Name != nil { + sum.Write([]byte(imp.Name.Name)) + } + sum.Write([]byte{0}) + sum.Write([]byte(imp.Path.Value)) + sum.Write([]byte{0}) + } + return discoveryFileFingerprint{ + Path: canonicalLoaderPath(path), + Hash: hex.EncodeToString(sum.Sum(nil)), + }, true +} + +func matchesDiscoveryFingerprint(fp discoveryFileFingerprint) bool { + cur, ok := fingerprintDiscoveryFile(fp.Path) + if !ok { + return false + } + return cur.Hash == fp.Hash +} + +func hashGob(v interface{}) (string, error) { + var buf bytes.Buffer + if err := gob.NewEncoder(&buf).Encode(v); err != nil { + return "", err + } + sum := sha256.Sum256(buf.Bytes()) + return hex.EncodeToString(sum[:]), nil +} diff --git a/internal/loader/discovery_cache_test.go b/internal/loader/discovery_cache_test.go new file mode 100644 index 0000000..953824f --- /dev/null +++ b/internal/loader/discovery_cache_test.go @@ -0,0 +1,126 @@ +package loader + +import ( + "os" + "path/filepath" + "testing" +) + +func TestDiscoveryFingerprintIgnoresBodyOnlyEdits(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "pkg.go") + before := `package example + +import "fmt" + +func Provide() string { + return fmt.Sprint("before") +} +` + if err := os.WriteFile(path, []byte(before), 0o644); err != nil { + t.Fatal(err) + } + fpBefore, ok := fingerprintDiscoveryFile(path) + if !ok { + t.Fatalf("fingerprintDiscoveryFile(%q) failed", path) + } + after := `package example + +import "fmt" + +func Provide() string { + return fmt.Sprint("after") +} +` + if err := os.WriteFile(path, []byte(after), 0o644); err != nil { + t.Fatal(err) + } + fpAfter, ok := fingerprintDiscoveryFile(path) + if !ok { + t.Fatalf("fingerprintDiscoveryFile(%q) failed after body edit", path) + } + if fpBefore.Hash != fpAfter.Hash { + t.Fatalf("body-only edit changed fingerprint: %s != %s", fpBefore.Hash, fpAfter.Hash) + } +} + +func TestDiscoveryFingerprintDetectsImportChange(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "pkg.go") + before := `package example + +import "fmt" +` + if err := os.WriteFile(path, []byte(before), 0o644); err != nil { + t.Fatal(err) + } + fpBefore, ok := fingerprintDiscoveryFile(path) + if !ok { + t.Fatalf("fingerprintDiscoveryFile(%q) failed", path) + } + after := `package example + +import "strings" +` + if err := os.WriteFile(path, []byte(after), 0o644); err != nil { + t.Fatal(err) + } + fpAfter, ok := fingerprintDiscoveryFile(path) + if !ok { + t.Fatalf("fingerprintDiscoveryFile(%q) failed after import edit", path) + } + if fpBefore.Hash == fpAfter.Hash { + t.Fatalf("import edit did not change fingerprint") + } +} + +func TestDiscoveryFingerprintDetectsHeaderBuildTagChange(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "pkg.go") + before := `//go:build linux + +package example + +import "fmt" +` + if err := os.WriteFile(path, []byte(before), 0o644); err != nil { + t.Fatal(err) + } + fpBefore, ok := fingerprintDiscoveryFile(path) + if !ok { + t.Fatalf("fingerprintDiscoveryFile(%q) failed", path) + } + after := `//go:build darwin + +package example + +import "fmt" +` + if err := os.WriteFile(path, []byte(after), 0o644); err != nil { + t.Fatal(err) + } + fpAfter, ok := fingerprintDiscoveryFile(path) + if !ok { + t.Fatalf("fingerprintDiscoveryFile(%q) failed after header edit", path) + } + if fpBefore.Hash == fpAfter.Hash { + t.Fatalf("build tag edit did not change fingerprint") + } +} + +func TestDiscoveryDirDetectsFileSetChange(t *testing.T) { + dir := t.TempDir() + if err := os.WriteFile(filepath.Join(dir, "a.go"), []byte("package example\n"), 0o644); err != nil { + t.Fatal(err) + } + before, ok := statDiscoveryDir(dir) + if !ok { + t.Fatalf("statDiscoveryDir(%q) failed", dir) + } + if err := os.WriteFile(filepath.Join(dir, "b.go"), []byte("package example\n"), 0o644); err != nil { + t.Fatal(err) + } + if matchesDiscoveryDir(before) { + t.Fatalf("directory metadata did not detect added file") + } +} diff --git a/internal/loader/fallback.go b/internal/loader/fallback.go new file mode 100644 index 0000000..860bd50 --- /dev/null +++ b/internal/loader/fallback.go @@ -0,0 +1,203 @@ +// Copyright 2026 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package loader + +import ( + "context" + "errors" + "go/token" + + "golang.org/x/tools/go/packages" +) + +type defaultLoader struct{} + +func fallbackReasonDetail(mode Mode, detail string) (FallbackReason, string) { + switch mode { + case ModeFallback: + return FallbackReasonForcedFallback, "" + default: + return FallbackReasonCustomUnsupported, detail + } +} + +func (defaultLoader) LoadPackages(ctx context.Context, req PackageLoadRequest) (*PackageLoadResult, error) { + var unsupported unsupportedError + if req.LoaderMode != ModeFallback { + result, err := loadPackagesCustom(ctx, req) + if err == nil { + return result, nil + } + if !errors.As(err, &unsupported) { + return nil, err + } + } + result := &PackageLoadResult{ + Backend: ModeFallback, + } + result.FallbackReason, result.FallbackDetail = fallbackReasonDetail(req.LoaderMode, unsupported.reason) + cfg := &packages.Config{ + Context: ctx, + Mode: req.Mode, + Dir: req.WD, + Env: req.Env, + BuildFlags: []string{"-tags=wireinject"}, + Fset: req.Fset, + } + if cfg.Fset == nil { + cfg.Fset = token.NewFileSet() + } + if req.ParseFile != nil { + cfg.ParseFile = req.ParseFile + } + if req.Tags != "" { + cfg.BuildFlags[0] += " " + req.Tags + } + escaped := make([]string, len(req.Patterns)) + for i := range req.Patterns { + escaped[i] = "pattern=" + req.Patterns[i] + } + pkgs, err := packages.Load(cfg, escaped...) + if err != nil { + return nil, err + } + result.Packages = pkgs + return result, nil +} + +func (defaultLoader) LoadRootGraph(ctx context.Context, req RootLoadRequest) (*RootLoadResult, error) { + var unsupported unsupportedError + if req.Mode != ModeFallback { + result, err := loadRootGraphCustom(ctx, req) + if err == nil { + return result, nil + } + if !errors.As(err, &unsupported) { + return nil, err + } + } + result := &RootLoadResult{ + Backend: ModeFallback, + } + result.FallbackReason, result.FallbackDetail = fallbackReasonDetail(req.Mode, unsupported.reason) + cfg := &packages.Config{ + Context: ctx, + Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports, + Dir: req.WD, + Env: req.Env, + BuildFlags: []string{"-tags=wireinject"}, + Fset: req.Fset, + } + if req.NeedDeps { + cfg.Mode |= packages.NeedDeps + } + if req.Fset == nil { + cfg.Fset = token.NewFileSet() + } + if req.Tags != "" { + cfg.BuildFlags[0] += " " + req.Tags + } + escaped := make([]string, len(req.Patterns)) + for i := range req.Patterns { + escaped[i] = "pattern=" + req.Patterns[i] + } + pkgs, err := packages.Load(cfg, escaped...) + if err != nil { + return nil, err + } + result.Packages = pkgs + return result, nil +} + +func (defaultLoader) LoadTypedPackageGraph(ctx context.Context, req LazyLoadRequest) (*LazyLoadResult, error) { + var unsupported unsupportedError + if req.LoaderMode != ModeFallback { + result, err := loadTypedPackageGraphCustom(ctx, req) + if err == nil { + return result, nil + } + if !errors.As(err, &unsupported) { + return nil, err + } + } + result := &LazyLoadResult{ + Backend: ModeFallback, + } + result.FallbackReason, result.FallbackDetail = fallbackReasonDetail(req.LoaderMode, unsupported.reason) + cfg := &packages.Config{ + Context: ctx, + Mode: req.Mode, + Dir: req.WD, + Env: req.Env, + BuildFlags: []string{"-tags=wireinject"}, + Fset: req.Fset, + } + if cfg.Fset == nil { + cfg.Fset = token.NewFileSet() + } + if req.ParseFile != nil { + cfg.ParseFile = req.ParseFile + } + if req.Tags != "" { + cfg.BuildFlags[0] += " " + req.Tags + } + pkgs, err := packages.Load(cfg, "pattern="+req.Package) + if err != nil { + return nil, err + } + result.Packages = pkgs + return result, nil +} + +func (defaultLoader) ValidateTouchedPackages(ctx context.Context, req TouchedValidationRequest) (*TouchedValidationResult, error) { + var unsupported unsupportedError + if req.Mode != ModeFallback { + result, err := validateTouchedPackagesCustom(ctx, req) + if err == nil { + return result, nil + } + if !errors.As(err, &unsupported) { + return nil, err + } + } + return validateTouchedPackagesFallback(ctx, req, unsupported.reason) +} + +func validateTouchedPackagesFallback(ctx context.Context, req TouchedValidationRequest, detail string) (*TouchedValidationResult, error) { + result := &TouchedValidationResult{ + Backend: ModeFallback, + } + result.FallbackReason, result.FallbackDetail = fallbackReasonDetail(req.Mode, detail) + if len(req.Touched) == 0 { + return result, nil + } + cfg := &packages.Config{ + Context: ctx, + Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedExportsFile | packages.NeedTypes | packages.NeedSyntax | packages.NeedTypesSizes, + Dir: req.WD, + Env: req.Env, + BuildFlags: []string{"-tags=wireinject"}, + Fset: token.NewFileSet(), + } + if req.Tags != "" { + cfg.BuildFlags[0] += " " + req.Tags + } + pkgs, err := packages.Load(cfg, req.Touched...) + if err != nil { + return nil, err + } + result.Packages = pkgs + return result, nil +} diff --git a/internal/loader/loader.go b/internal/loader/loader.go new file mode 100644 index 0000000..a507758 --- /dev/null +++ b/internal/loader/loader.go @@ -0,0 +1,136 @@ +// Copyright 2026 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package loader + +import ( + "context" + "go/ast" + "go/token" + + "golang.org/x/tools/go/packages" +) + +type Mode string + +const ( + ModeAuto Mode = "auto" + ModeCustom Mode = "custom" + ModeFallback Mode = "fallback" +) + +type FallbackReason string + +const ( + FallbackReasonNone FallbackReason = "" + FallbackReasonForcedFallback FallbackReason = "forced_fallback" + FallbackReasonCustomNotImplemented FallbackReason = "custom_not_implemented" + FallbackReasonCustomUnsupported FallbackReason = "custom_unsupported" +) + +type LocalPackageFingerprint struct { + PkgPath string + ContentHash string + ShapeHash string + Files []string +} + +type DiscoverySnapshot struct { + meta map[string]*packageMeta +} + +type TouchedValidationRequest struct { + WD string + Env []string + Tags string + Touched []string + Local []LocalPackageFingerprint + Mode Mode +} + +type TouchedValidationResult struct { + Packages []*packages.Package + Backend Mode + FallbackReason FallbackReason + FallbackDetail string +} + +type RootLoadRequest struct { + WD string + Env []string + Tags string + Patterns []string + NeedDeps bool + Mode Mode + Fset *token.FileSet +} + +type RootLoadResult struct { + Packages []*packages.Package + Backend Mode + FallbackReason FallbackReason + FallbackDetail string + Discovery *DiscoverySnapshot +} + +type PackageLoadRequest struct { + WD string + Env []string + Tags string + Patterns []string + Mode packages.LoadMode + LoaderMode Mode + Fset *token.FileSet + ParseFile ParseFileFunc + Discovery *DiscoverySnapshot +} + +type PackageLoadResult struct { + Packages []*packages.Package + Backend Mode + FallbackReason FallbackReason + FallbackDetail string +} + +type ParseFileFunc func(*token.FileSet, string, []byte) (*ast.File, error) + +type LazyLoadRequest struct { + WD string + Env []string + Tags string + Package string + Mode packages.LoadMode + LoaderMode Mode + Fset *token.FileSet + ParseFile ParseFileFunc + Discovery *DiscoverySnapshot +} + +type LazyLoadResult struct { + Packages []*packages.Package + Backend Mode + FallbackReason FallbackReason + FallbackDetail string +} + +type Loader interface { + LoadPackages(context.Context, PackageLoadRequest) (*PackageLoadResult, error) + LoadRootGraph(context.Context, RootLoadRequest) (*RootLoadResult, error) + LoadTypedPackageGraph(context.Context, LazyLoadRequest) (*LazyLoadResult, error) + ValidateTouchedPackages(context.Context, TouchedValidationRequest) (*TouchedValidationResult, error) +} + +func New() Loader { + return defaultLoader{} +} diff --git a/internal/loader/loader_test.go b/internal/loader/loader_test.go new file mode 100644 index 0000000..05cfaa7 --- /dev/null +++ b/internal/loader/loader_test.go @@ -0,0 +1,3288 @@ +// Copyright 2026 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package loader + +import ( + "archive/zip" + "bytes" + "context" + "fmt" + "go/ast" + "go/parser" + "go/token" + "go/types" + "os" + "os/exec" + "path/filepath" + "sort" + "strconv" + "strings" + "testing" + "time" + + "golang.org/x/tools/go/gcexportdata" + "golang.org/x/tools/go/packages" +) + +func TestModeFromEnvDefaultAuto(t *testing.T) { + if got := ModeFromEnv(nil); got != ModeAuto { + t.Fatalf("ModeFromEnv(nil) = %q, want %q", got, ModeAuto) + } +} + +func TestModeFromEnvUsesLastMatchingValue(t *testing.T) { + env := []string{ + "WIRE_LOADER_MODE=fallback", + "OTHER=value", + "WIRE_LOADER_MODE=custom", + } + if got := ModeFromEnv(env); got != ModeCustom { + t.Fatalf("ModeFromEnv(...) = %q, want %q", got, ModeCustom) + } +} + +func TestModeFromEnvIgnoresInvalidValues(t *testing.T) { + env := []string{ + "WIRE_LOADER_MODE=invalid", + } + if got := ModeFromEnv(env); got != ModeAuto { + t.Fatalf("ModeFromEnv(...) = %q, want %q", got, ModeAuto) + } +} + +func TestFallbackLoaderReasonFromMode(t *testing.T) { + l := New() + + gotAuto, err := l.ValidateTouchedPackages(context.Background(), TouchedValidationRequest{ + WD: ".", + Env: []string{}, + Touched: []string{}, + Mode: ModeAuto, + }) + if err != nil { + t.Fatalf("ValidateTouchedPackages(auto) error = %v", err) + } + if gotAuto.Backend != ModeCustom { + t.Fatalf("auto backend = %q, want %q", gotAuto.Backend, ModeCustom) + } + if gotAuto.FallbackReason != FallbackReasonNone { + t.Fatalf("auto fallback reason = %q, want none", gotAuto.FallbackReason) + } + + gotForced, err := l.ValidateTouchedPackages(context.Background(), TouchedValidationRequest{ + WD: ".", + Env: []string{}, + Touched: []string{}, + Mode: ModeFallback, + }) + if err != nil { + t.Fatalf("ValidateTouchedPackages(fallback) error = %v", err) + } + if gotForced.FallbackReason != FallbackReasonForcedFallback { + t.Fatalf("forced fallback reason = %q, want %q", gotForced.FallbackReason, FallbackReasonForcedFallback) + } +} + +func TestCustomTouchedValidationSuccess(t *testing.T) { + root := t.TempDir() + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "a", "a.go"), "package a\n\nimport \"fmt\"\n\nfunc Use() string { return fmt.Sprint(\"ok\") }\n") + + l := New() + got, err := l.ValidateTouchedPackages(context.Background(), TouchedValidationRequest{ + WD: root, + Env: os.Environ(), + Touched: []string{"example.com/app/a"}, + Mode: ModeCustom, + }) + if err != nil { + t.Fatalf("ValidateTouchedPackages(custom) error = %v", err) + } + if got.Backend != ModeCustom { + t.Fatalf("backend = %q, want %q", got.Backend, ModeCustom) + } + if got.FallbackReason != FallbackReasonNone { + t.Fatalf("fallback reason = %q, want none", got.FallbackReason) + } + if len(got.Packages) != 1 { + t.Fatalf("packages len = %d, want 1", len(got.Packages)) + } + if len(got.Packages[0].Errors) != 0 { + t.Fatalf("unexpected package errors: %+v", got.Packages[0].Errors) + } +} + +func TestValidateTouchedPackagesAutoUsesCustomWhenSupported(t *testing.T) { + root := t.TempDir() + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "a", "a.go"), "package a\n\nfunc Use() string { return \"ok\" }\n") + + l := New() + got, err := l.ValidateTouchedPackages(context.Background(), TouchedValidationRequest{ + WD: root, + Env: os.Environ(), + Touched: []string{"example.com/app/a"}, + Mode: ModeAuto, + }) + if err != nil { + t.Fatalf("ValidateTouchedPackages(auto) error = %v", err) + } + if got.Backend != ModeCustom { + t.Fatalf("backend = %q, want %q", got.Backend, ModeCustom) + } + if got.FallbackReason != FallbackReasonNone { + t.Fatalf("fallback reason = %q, want none", got.FallbackReason) + } +} + +func TestCustomTouchedValidationTypeError(t *testing.T) { + root := t.TempDir() + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "a", "a.go"), "package a\n\nfunc Broken() int { return missing }\n") + + l := New() + got, err := l.ValidateTouchedPackages(context.Background(), TouchedValidationRequest{ + WD: root, + Env: os.Environ(), + Touched: []string{"example.com/app/a"}, + Mode: ModeCustom, + }) + if err != nil { + t.Fatalf("ValidateTouchedPackages(custom) error = %v", err) + } + if got.Backend != ModeCustom { + t.Fatalf("backend = %q, want %q", got.Backend, ModeCustom) + } + if len(got.Packages) != 1 { + t.Fatalf("packages len = %d, want 1", len(got.Packages)) + } + if len(got.Packages[0].Errors) == 0 { + t.Fatal("expected type-check errors") + } +} + +func TestValidateTouchedPackagesCustomMatchesFallback(t *testing.T) { + root := t.TempDir() + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "dep", "dep.go"), "package dep\n\ntype T struct{}\nfunc New() *T { return &T{} }\n") + writeTestFile(t, filepath.Join(root, "app", "app.go"), "package app\n\nimport \"example.com/app/dep\"\n\nfunc Use() *dep.T { return dep.New() }\n") + + l := New() + custom, err := l.ValidateTouchedPackages(context.Background(), TouchedValidationRequest{ + WD: root, + Env: os.Environ(), + Touched: []string{"example.com/app/app"}, + Mode: ModeCustom, + }) + if err != nil { + t.Fatalf("ValidateTouchedPackages(custom) error = %v", err) + } + fallback, err := l.ValidateTouchedPackages(context.Background(), TouchedValidationRequest{ + WD: root, + Env: os.Environ(), + Touched: []string{"example.com/app/app"}, + Mode: ModeFallback, + }) + if err != nil { + t.Fatalf("ValidateTouchedPackages(fallback) error = %v", err) + } + compareRootPackagesOnly(t, custom.Packages, fallback.Packages, false) +} + +func TestValidateTouchedPackagesCustomMatchesFallbackTypeErrors(t *testing.T) { + root := t.TempDir() + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "a", "a.go"), "package a\n\nfunc Broken() int { return missing }\n") + + l := New() + custom, err := l.ValidateTouchedPackages(context.Background(), TouchedValidationRequest{ + WD: root, + Env: os.Environ(), + Touched: []string{"example.com/app/a"}, + Mode: ModeCustom, + }) + if err != nil { + t.Fatalf("ValidateTouchedPackages(custom) error = %v", err) + } + fallback, err := l.ValidateTouchedPackages(context.Background(), TouchedValidationRequest{ + WD: root, + Env: os.Environ(), + Touched: []string{"example.com/app/a"}, + Mode: ModeFallback, + }) + if err != nil { + t.Fatalf("ValidateTouchedPackages(fallback) error = %v", err) + } + compareRootPackagesOnly(t, custom.Packages, fallback.Packages, false) +} + +func TestValidateTouchedPackagesAutoReportsFallbackDetail(t *testing.T) { + root := t.TempDir() + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "a", "a.go"), "package a\n\nfunc Use() string { return \"ok\" }\n") + + l := New() + got, err := l.ValidateTouchedPackages(context.Background(), TouchedValidationRequest{ + WD: root, + Env: os.Environ(), + Touched: []string{"example.com/app/a"}, + Local: []LocalPackageFingerprint{ + { + PkgPath: "example.com/app/a", + ContentHash: "wrong", + ShapeHash: "wrong", + Files: []string{filepath.Join(root, "a", "a.go")}, + }, + }, + Mode: ModeAuto, + }) + if err != nil { + t.Fatalf("ValidateTouchedPackages(auto) error = %v", err) + } + switch got.Backend { + case ModeCustom: + if got.FallbackReason != FallbackReasonNone { + t.Fatalf("fallback reason = %q, want empty for custom backend", got.FallbackReason) + } + if got.FallbackDetail != "" { + t.Fatalf("fallback detail = %q, want empty for custom backend", got.FallbackDetail) + } + case ModeFallback: + if got.FallbackReason != FallbackReasonCustomUnsupported { + t.Fatalf("fallback reason = %q, want %q", got.FallbackReason, FallbackReasonCustomUnsupported) + } + if got.FallbackDetail != "metadata fingerprint mismatch" { + t.Fatalf("fallback detail = %q, want %q", got.FallbackDetail, "metadata fingerprint mismatch") + } + default: + t.Fatalf("backend = %q, want %q or %q", got.Backend, ModeCustom, ModeFallback) + } +} + +func TestLoadRootGraphFallback(t *testing.T) { + root := t.TempDir() + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "app", "app.go"), "package app\n\nimport _ \"fmt\"\n") + + l := New() + got, err := l.LoadRootGraph(context.Background(), RootLoadRequest{ + WD: root, + Env: os.Environ(), + Patterns: []string{"./app"}, + NeedDeps: true, + Mode: ModeFallback, + }) + if err != nil { + t.Fatalf("LoadRootGraph error = %v", err) + } + if got.Backend != ModeFallback { + t.Fatalf("backend = %q, want %q", got.Backend, ModeFallback) + } + if got.FallbackReason != FallbackReasonForcedFallback { + t.Fatalf("fallback reason = %q, want %q", got.FallbackReason, FallbackReasonForcedFallback) + } + if len(got.Packages) == 0 { + t.Fatal("expected loaded root packages") + } +} + +func TestLoadRootGraphCustom(t *testing.T) { + root := t.TempDir() + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "app", "app.go"), "package app\n\nimport _ \"example.com/app/dep\"\n") + writeTestFile(t, filepath.Join(root, "dep", "dep.go"), "package dep\n") + + l := New() + got, err := l.LoadRootGraph(context.Background(), RootLoadRequest{ + WD: root, + Env: os.Environ(), + Patterns: []string{"./app"}, + NeedDeps: true, + Mode: ModeCustom, + }) + if err != nil { + t.Fatalf("LoadRootGraph(custom) error = %v", err) + } + if got.Backend != ModeCustom { + t.Fatalf("backend = %q, want %q", got.Backend, ModeCustom) + } + if got.FallbackReason != FallbackReasonNone { + t.Fatalf("fallback reason = %q, want none", got.FallbackReason) + } + if len(got.Packages) != 1 { + t.Fatalf("packages len = %d, want 1", len(got.Packages)) + } + if got.Packages[0].Imports["example.com/app/dep"] == nil { + t.Fatal("expected custom root graph to wire local import dependency") + } +} + +func TestLoadRootGraphAutoUsesCustomWhenSupported(t *testing.T) { + root := t.TempDir() + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "app", "app.go"), "package app\n\nimport _ \"example.com/app/dep\"\n") + writeTestFile(t, filepath.Join(root, "dep", "dep.go"), "package dep\n") + + l := New() + got, err := l.LoadRootGraph(context.Background(), RootLoadRequest{ + WD: root, + Env: os.Environ(), + Patterns: []string{"./app"}, + NeedDeps: true, + Mode: ModeAuto, + Fset: token.NewFileSet(), + }) + if err != nil { + t.Fatalf("LoadRootGraph(auto) error = %v", err) + } + if got.Backend != ModeCustom { + t.Fatalf("backend = %q, want %q", got.Backend, ModeCustom) + } +} + +func TestMetaFilesFallsBackToGoFiles(t *testing.T) { + meta := &packageMeta{ + GoFiles: []string{"a.go", "b.go"}, + } + got := metaFiles(meta) + if len(got) != 2 || got[0] != "a.go" || got[1] != "b.go" { + t.Fatalf("metaFiles(go-only) = %v, want GoFiles fallback", got) + } + + meta.CompiledGoFiles = []string{"c.go"} + got = metaFiles(meta) + if len(got) != 1 || got[0] != "c.go" { + t.Fatalf("metaFiles(compiled) = %v, want CompiledGoFiles", got) + } +} + +func TestExportDataPairings(t *testing.T) { + fset := token.NewFileSet() + file, err := parser.ParseFile(fset, "lib.go", "package lib\n\ntype T int\n", 0) + if err != nil { + t.Fatalf("ParseFile() error = %v", err) + } + pkg, err := new(types.Config).Check("lib", fset, []*ast.File{file}, nil) + if err != nil { + t.Fatalf("types.Check() error = %v", err) + } + + t.Run("gcexportdata write/read direct", func(t *testing.T) { + var out bytes.Buffer + if err := gcexportdata.Write(&out, fset, pkg); err != nil { + t.Fatalf("gcexportdata.Write() error = %v", err) + } + got, err := gcexportdata.Read(bytes.NewReader(out.Bytes()), token.NewFileSet(), make(map[string]*types.Package), pkg.Path()) + if err != nil { + t.Fatalf("gcexportdata.Read() error = %v", err) + } + if got.Scope().Lookup("T") == nil { + t.Fatal("reimported package missing T") + } + }) + + t.Run("gcexportdata write with newreader fails", func(t *testing.T) { + var out bytes.Buffer + if err := gcexportdata.Write(&out, fset, pkg); err != nil { + t.Fatalf("gcexportdata.Write() error = %v", err) + } + if _, err := gcexportdata.NewReader(bytes.NewReader(out.Bytes())); err == nil { + t.Fatal("gcexportdata.NewReader() unexpectedly succeeded on direct gcexportdata.Write output") + } + }) +} + +func TestExportDataRoundTripWithImports(t *testing.T) { + fset := token.NewFileSet() + depPkg, err := new(types.Config).Check("example.com/dep", fset, []*ast.File{ + mustParseFile(t, fset, "dep.go", `package dep + +type T int +`), + }, nil) + if err != nil { + t.Fatalf("types.Check(dep) error = %v", err) + } + pkg, err := (&types.Config{ + Importer: importerFuncForTest(func(path string) (*types.Package, error) { + if path == "example.com/dep" { + return depPkg, nil + } + if path == "unsafe" { + return types.Unsafe, nil + } + return nil, nil + }), + }).Check("example.com/lib", fset, []*ast.File{ + mustParseFile(t, fset, "lib.go", `package lib + +import "example.com/dep" + +type T struct { + S dep.T +} +`), + }, nil) + if err != nil { + t.Fatalf("types.Check() error = %v", err) + } + + var out bytes.Buffer + if err := gcexportdata.Write(&out, fset, pkg); err != nil { + t.Fatalf("gcexportdata.Write() error = %v", err) + } + imports := make(map[string]*types.Package) + got, err := gcexportdata.Read(bytes.NewReader(out.Bytes()), token.NewFileSet(), imports, pkg.Path()) + if err != nil { + t.Fatalf("gcexportdata.Read() error = %v", err) + } + obj := got.Scope().Lookup("T") + if obj == nil { + t.Fatal("reimported package missing T") + } + named, ok := obj.Type().(*types.Named) + if !ok { + t.Fatalf("T type = %T, want *types.Named", obj.Type()) + } + field := named.Underlying().(*types.Struct).Field(0) + if field.Type().String() != "example.com/dep.T" { + t.Fatalf("field type = %q, want %q", field.Type().String(), "example.com/dep.T") + } + depImport := imports["example.com/dep"] + if depImport == nil { + t.Fatal("imports map missing dep") + } + if depImport.Scope().Lookup("T") == nil { + t.Fatal("dep import missing T after import") + } +} + +func TestLoadTypedPackageGraphFallback(t *testing.T) { + root := t.TempDir() + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "app", "app.go"), "package app\n\nfunc Value() int { return 42 }\n") + + var parseCalls int + l := New() + got, err := l.LoadTypedPackageGraph(context.Background(), LazyLoadRequest{ + WD: root, + Env: os.Environ(), + Package: "example.com/app/app", + Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile, + LoaderMode: ModeFallback, + Fset: token.NewFileSet(), + ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { + parseCalls++ + return parser.ParseFile(fset, filename, src, parser.SkipObjectResolution) + }, + }) + if err != nil { + t.Fatalf("LoadTypedPackageGraph error = %v", err) + } + if got.Backend != ModeFallback { + t.Fatalf("backend = %q, want %q", got.Backend, ModeFallback) + } + if got.FallbackReason != FallbackReasonForcedFallback { + t.Fatalf("fallback reason = %q, want %q", got.FallbackReason, FallbackReasonForcedFallback) + } + if len(got.Packages) != 1 { + t.Fatalf("packages len = %d, want 1", len(got.Packages)) + } + if parseCalls == 0 { + t.Fatal("expected ParseFile hook to be used") + } +} + +func TestLoadTypedPackageGraphCustom(t *testing.T) { + root := t.TempDir() + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "dep", "dep.go"), "package dep\n\ntype T struct{}\nfunc New() *T { return &T{} }\n") + writeTestFile(t, filepath.Join(root, "app", "wire.go"), "package app\n\nimport \"example.com/app/dep\"\n\nfunc Init() *dep.T { return dep.New() }\n") + + var parseCalls int + l := New() + got, err := l.LoadTypedPackageGraph(context.Background(), LazyLoadRequest{ + WD: root, + Env: os.Environ(), + Package: "example.com/app/app", + Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile, + LoaderMode: ModeCustom, + Fset: token.NewFileSet(), + ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { + parseCalls++ + return parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) + }, + }) + if err != nil { + t.Fatalf("LoadTypedPackageGraph(custom) error = %v", err) + } + if got.Backend != ModeCustom { + t.Fatalf("backend = %q, want %q", got.Backend, ModeCustom) + } + if got.FallbackReason != FallbackReasonNone { + t.Fatalf("fallback reason = %q, want none", got.FallbackReason) + } + if len(got.Packages) != 1 { + t.Fatalf("packages len = %d, want 1", len(got.Packages)) + } + rootPkg := got.Packages[0] + if rootPkg.Types == nil || rootPkg.TypesInfo == nil || len(rootPkg.Syntax) == 0 { + t.Fatalf("root package missing typed syntax: %+v", rootPkg) + } + depPkg := rootPkg.Imports["example.com/app/dep"] + if depPkg == nil || depPkg.Types == nil || len(depPkg.Syntax) == 0 { + t.Fatalf("dep package missing typed syntax: %+v", depPkg) + } + if parseCalls < 2 { + t.Fatalf("parseCalls = %d, want at least 2", parseCalls) + } +} + +func TestLoadTypedPackageGraphAutoUsesCustomWhenSupported(t *testing.T) { + root := t.TempDir() + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "dep", "dep.go"), "package dep\n\ntype T struct{}\nfunc New() *T { return &T{} }\n") + writeTestFile(t, filepath.Join(root, "app", "wire.go"), "package app\n\nimport \"example.com/app/dep\"\n\nfunc Init() *dep.T { return dep.New() }\n") + + l := New() + got, err := l.LoadTypedPackageGraph(context.Background(), LazyLoadRequest{ + WD: root, + Env: os.Environ(), + Package: "example.com/app/app", + Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile, + LoaderMode: ModeAuto, + Fset: token.NewFileSet(), + ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { + return parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) + }, + }) + if err != nil { + t.Fatalf("LoadTypedPackageGraph(auto) error = %v", err) + } + if got.Backend != ModeCustom { + t.Fatalf("backend = %q, want %q", got.Backend, ModeCustom) + } +} + +func TestLoadTypedPackageGraphCustomKeepsExternalPackagesLight(t *testing.T) { + root := t.TempDir() + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "app", "wire.go"), "package app\n\nimport \"fmt\"\n\nfunc Init() string { return fmt.Sprint(\"ok\") }\n") + + l := New() + got, err := l.LoadTypedPackageGraph(context.Background(), LazyLoadRequest{ + WD: root, + Env: os.Environ(), + Package: "example.com/app/app", + Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile, + LoaderMode: ModeCustom, + Fset: token.NewFileSet(), + ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { + return parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) + }, + }) + if err != nil { + t.Fatalf("LoadTypedPackageGraph(custom) error = %v", err) + } + rootPkg := got.Packages[0] + fmtPkg := rootPkg.Imports["fmt"] + if fmtPkg == nil { + t.Fatal("expected fmt import package") + } + if fmtPkg.Types == nil { + t.Fatalf("fmt package missing types: %+v", fmtPkg) + } + if fmtPkg.TypesInfo != nil { + t.Fatalf("fmt package TypesInfo should be nil, got %+v", fmtPkg.TypesInfo) + } + if len(fmtPkg.Syntax) != 0 { + t.Fatalf("fmt package Syntax len = %d, want 0", len(fmtPkg.Syntax)) + } +} + +func TestLoadTypedPackageGraphCustomExternalArtifactCache(t *testing.T) { + root := t.TempDir() + artifactDir := t.TempDir() + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "app", "wire.go"), "package app\n\nimport \"fmt\"\n\nfunc Init() string { return fmt.Sprint(\"ok\") }\n") + + env := append(os.Environ(), + loaderArtifactEnv+"=1", + loaderArtifactDirEnv+"="+artifactDir, + ) + + run := func() int { + var parseCalls int + l := New() + got, err := l.LoadTypedPackageGraph(context.Background(), LazyLoadRequest{ + WD: root, + Env: env, + Package: "example.com/app/app", + Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile, + LoaderMode: ModeCustom, + Fset: token.NewFileSet(), + ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { + parseCalls++ + return parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) + }, + }) + if err != nil { + t.Fatalf("LoadTypedPackageGraph(custom) error = %v", err) + } + rootPkg := got.Packages[0] + if rootPkg.Imports["fmt"] == nil { + t.Fatal("expected fmt import package") + } + return parseCalls + } + + first := run() + entries, err := os.ReadDir(artifactDir) + if err != nil { + t.Fatalf("ReadDir(%q) error = %v", artifactDir, err) + } + if len(entries) == 0 { + t.Fatal("expected artifact cache files after first run") + } + second := run() + if second >= first { + t.Fatalf("second parseCalls = %d, want less than first run %d", second, first) + } +} + +func TestLoadTypedPackageGraphCustomExternalArtifactCacheReportsHits(t *testing.T) { + root := t.TempDir() + artifactDir := t.TempDir() + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "app", "wire.go"), "package app\n\nimport \"fmt\"\n\nfunc Init() string { return fmt.Sprint(\"ok\") }\n") + env := append(os.Environ(), + loaderArtifactEnv+"=1", + loaderArtifactDirEnv+"="+artifactDir, + ) + + run := func() []string { + var labels []string + ctx := WithTiming(context.Background(), func(label string, _ time.Duration) { + labels = append(labels, label) + }) + l := New() + _, err := l.LoadTypedPackageGraph(ctx, LazyLoadRequest{ + WD: root, + Env: env, + Package: "example.com/app/app", + Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile, + LoaderMode: ModeCustom, + Fset: token.NewFileSet(), + ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { + return parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) + }, + }) + if err != nil { + t.Fatalf("LoadTypedPackageGraph(custom) error = %v", err) + } + return labels + } + + _ = run() + second := run() + if !hasPrefixLabel(second, "loader.custom.lazy.artifact_hits=") { + t.Fatalf("second run labels missing artifact hit count: %v", second) + } + if !containsPositiveIntLabel(second, "loader.custom.lazy.artifact_hits=") { + t.Fatalf("second run artifact hit count was not positive: %v", second) + } +} + +func TestLoadTypedPackageGraphCustomArtifactCacheReplacedModuleSourceChange(t *testing.T) { + root := t.TempDir() + depRoot := filepath.Join(root, "depmod") + appRoot := filepath.Join(root, "appmod") + artifactDir := t.TempDir() + homeDir := t.TempDir() + + writeTestFile(t, filepath.Join(depRoot, "go.mod"), "module example.com/dep\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(depRoot, "dep.go"), "package dep\n\nfunc New() string { return \"ok\" }\n") + + writeTestFile(t, filepath.Join(appRoot, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require example.com/dep v0.0.0", + "", + "replace example.com/dep => " + depRoot, + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "app", "app.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/dep\"", + "", + "func Use() string {", + "\treturn dep.New()", + "}", + "", + }, "\n")) + + env := append(os.Environ(), + "HOME="+homeDir, + loaderArtifactEnv+"=1", + loaderArtifactDirEnv+"="+artifactDir, + ) + + load := func(mode Mode) (*LazyLoadResult, error) { + l := New() + return l.LoadTypedPackageGraph(context.Background(), LazyLoadRequest{ + WD: appRoot, + Env: env, + Package: "example.com/app/app", + Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile, + LoaderMode: mode, + Fset: token.NewFileSet(), + ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { + return parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) + }, + }) + } + + first, err := load(ModeCustom) + if err != nil { + t.Fatalf("first LoadTypedPackageGraph(custom) error = %v", err) + } + if len(first.Packages) != 1 || len(first.Packages[0].Errors) != 0 { + t.Fatalf("first custom load returned errors: %+v", first.Packages) + } + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "type Logger interface { Log(string) }", + "", + "type NoopLogger struct{}", + "", + "func (NoopLogger) Log(string) {}", + "", + "func New(Logger) string { return \"ok\" }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "app", "app.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/dep\"", + "", + "func Use() string {", + "\tvar l dep.Logger = dep.NoopLogger{}", + "\treturn dep.New(l)", + "}", + "", + }, "\n")) + + custom, err := load(ModeCustom) + if err != nil { + t.Fatalf("second LoadTypedPackageGraph(custom) error = %v", err) + } + if len(custom.Packages) != 1 { + t.Fatalf("second custom packages len = %d, want 1", len(custom.Packages)) + } + if got := comparableErrors(custom.Packages[0].Errors); len(got) != 0 { + t.Fatalf("second custom load returned errors: %v", got) + } + + fallback, err := load(ModeFallback) + if err != nil { + t.Fatalf("second LoadTypedPackageGraph(fallback) error = %v", err) + } + compareRootPackagesOnly(t, custom.Packages, fallback.Packages, true) +} + +func TestDiscoveryCacheInvalidatesOnGoModResolutionChange(t *testing.T) { + root := t.TempDir() + depOneRoot := filepath.Join(root, "dep-one") + depTwoRoot := filepath.Join(root, "dep-two") + appRoot := filepath.Join(root, "appmod") + homeDir := t.TempDir() + + writeTestFile(t, filepath.Join(depOneRoot, "go.mod"), "module example.com/dep\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(depOneRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "func New() string { return \"one\" }", + "", + }, "\n")) + + writeTestFile(t, filepath.Join(depTwoRoot, "go.mod"), "module example.com/dep\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(depTwoRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "import \"strings\"", + "", + "func New() string { return strings.ToUpper(\"two\") }", + "", + }, "\n")) + + writeTestFile(t, filepath.Join(appRoot, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require example.com/dep v0.0.0", + "", + "replace example.com/dep => " + depOneRoot, + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "app", "app.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/dep\"", + "", + "func Use() string { return dep.New() }", + "", + }, "\n")) + + env := append(os.Environ(), "HOME="+homeDir) + + first := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeCustom) + if len(first.Packages) != 1 || len(first.Packages[0].Errors) != 0 { + t.Fatalf("first custom load returned errors: %+v", first.Packages) + } + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(appRoot, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require example.com/dep v0.0.0", + "", + "replace example.com/dep => " + depTwoRoot, + "", + }, "\n")) + + custom := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeCustom) + fallback := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeFallback) + compareRootPackagesOnly(t, custom.Packages, fallback.Packages, true) + comparePackageByPath(t, custom.Packages, fallback.Packages, "example.com/dep", false) +} + +func TestLoadTypedPackageGraphCustomCrossWorkspaceReplaceTargetIsolation(t *testing.T) { + cacheHome := t.TempDir() + artifactDir := t.TempDir() + repoOne := filepath.Join(t.TempDir(), "repo-one") + repoTwo := filepath.Join(t.TempDir(), "repo-two") + + depOneRoot := filepath.Join(repoOne, "depmod") + appOneRoot := filepath.Join(repoOne, "appmod") + depTwoRoot := filepath.Join(repoTwo, "depmod") + appTwoRoot := filepath.Join(repoTwo, "appmod") + + writeTestFile(t, filepath.Join(depOneRoot, "go.mod"), "module example.com/dep\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(depOneRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "func New() string { return \"one\" }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(appOneRoot, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require example.com/dep v0.0.0", + "", + "replace example.com/dep => " + depOneRoot, + "", + }, "\n")) + writeTestFile(t, filepath.Join(appOneRoot, "app", "app.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/dep\"", + "", + "func Use() string { return dep.New() }", + "", + }, "\n")) + + writeTestFile(t, filepath.Join(depTwoRoot, "go.mod"), "module example.com/dep\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(depTwoRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "import \"strings\"", + "", + "func New() string { return strings.ToUpper(\"two\") }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(appTwoRoot, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require example.com/dep v0.0.0", + "", + "replace example.com/dep => " + depTwoRoot, + "", + }, "\n")) + writeTestFile(t, filepath.Join(appTwoRoot, "app", "app.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/dep\"", + "", + "func Use() string { return dep.New() }", + "", + }, "\n")) + + env := append(os.Environ(), + "HOME="+cacheHome, + loaderArtifactEnv+"=1", + loaderArtifactDirEnv+"="+artifactDir, + ) + + warm := loadTypedPackageGraphForTest(t, appOneRoot, env, "example.com/app/app", ModeCustom) + if len(warm.Packages) != 1 || len(warm.Packages[0].Errors) != 0 { + t.Fatalf("repo one warm custom load returned errors: %+v", warm.Packages) + } + + custom := loadTypedPackageGraphForTest(t, appTwoRoot, env, "example.com/app/app", ModeCustom) + fallback := loadTypedPackageGraphForTest(t, appTwoRoot, env, "example.com/app/app", ModeFallback) + compareRootPackagesOnly(t, custom.Packages, fallback.Packages, true) + comparePackageByPath(t, custom.Packages, fallback.Packages, "example.com/dep", false) +} + +func TestLoadTypedPackageGraphCustomTransitiveShapeChangeWarmParity(t *testing.T) { + root := t.TempDir() + + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "b", "b.go"), strings.Join([]string{ + "package b", + "", + "type T struct{}", + "", + "func New() *T { return &T{} }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(root, "a", "a.go"), strings.Join([]string{ + "package a", + "", + "import \"example.com/app/b\"", + "", + "func New() *b.T { return b.New() }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/app/a\"", + "", + "func Init() any { return a.New() }", + "", + }, "\n")) + + first := loadTypedPackageGraphForTest(t, root, os.Environ(), "example.com/app/app", ModeCustom) + if len(first.Packages) != 1 || len(first.Packages[0].Errors) != 0 { + t.Fatalf("first custom load returned errors: %+v", first.Packages) + } + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(root, "b", "b.go"), strings.Join([]string{ + "package b", + "", + "type T struct{}", + "", + "type Logger interface { Log(string) }", + "", + "type NoopLogger struct{}", + "", + "func (NoopLogger) Log(string) {}", + "", + "func New(Logger) *T { return &T{} }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(root, "a", "a.go"), strings.Join([]string{ + "package a", + "", + "import \"example.com/app/b\"", + "", + "func New() *b.T {", + "\tvar logger b.Logger = b.NoopLogger{}", + "\treturn b.New(logger)", + "}", + "", + }, "\n")) + + custom := loadTypedPackageGraphForTest(t, root, os.Environ(), "example.com/app/app", ModeCustom) + fallback := loadTypedPackageGraphForTest(t, root, os.Environ(), "example.com/app/app", ModeFallback) + comparePackageGraphs(t, custom.Packages, fallback.Packages, true) +} + +func TestLoadTypedPackageGraphCustomReplacePathSwitchInvalidatesCaches(t *testing.T) { + root := t.TempDir() + depOneRoot := filepath.Join(root, "dep-one") + depTwoRoot := filepath.Join(root, "dep-two") + appRoot := filepath.Join(root, "appmod") + artifactDir := t.TempDir() + homeDir := t.TempDir() + + writeTestFile(t, filepath.Join(depOneRoot, "go.mod"), "module example.com/dep\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(depOneRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "func New() string { return \"one\" }", + "", + }, "\n")) + + writeTestFile(t, filepath.Join(depTwoRoot, "go.mod"), "module example.com/dep\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(depTwoRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "import \"strings\"", + "", + "func New() string { return strings.TrimSpace(\" two \") }", + "", + }, "\n")) + + writeTestFile(t, filepath.Join(appRoot, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require example.com/dep v0.0.0", + "", + "replace example.com/dep => " + depOneRoot, + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "app", "app.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/dep\"", + "", + "func Use() string { return dep.New() }", + "", + }, "\n")) + + env := append(os.Environ(), + "HOME="+homeDir, + loaderArtifactEnv+"=1", + loaderArtifactDirEnv+"="+artifactDir, + ) + + first := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeCustom) + if len(first.Packages) != 1 || len(first.Packages[0].Errors) != 0 { + t.Fatalf("first custom load returned errors: %+v", first.Packages) + } + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(appRoot, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require example.com/dep v0.0.0", + "", + "replace example.com/dep => " + depTwoRoot, + "", + }, "\n")) + + custom := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeCustom) + fallback := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeFallback) + compareRootPackagesOnly(t, custom.Packages, fallback.Packages, true) + comparePackageByPath(t, custom.Packages, fallback.Packages, "example.com/dep", false) +} + +func TestLoadTypedPackageGraphCustomDiscoveryCacheReplacedSiblingOutsideWorkspace(t *testing.T) { + root := t.TempDir() + depRoot := filepath.Join(root, "depmod") + appRoot := filepath.Join(root, "appmod") + homeDir := t.TempDir() + + writeTestFile(t, filepath.Join(depRoot, "go.mod"), "module example.com/dep\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "func New() string { return \"ok\" }", + "", + }, "\n")) + + writeTestFile(t, filepath.Join(appRoot, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require example.com/dep v0.0.0", + "", + "replace example.com/dep => " + depRoot, + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "app", "app.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/dep\"", + "", + "func Use() string { return dep.New() }", + "", + }, "\n")) + + env := append(os.Environ(), "HOME="+homeDir) + rootLoad := loadRootGraphForTest(t, appRoot, env, []string{"./app"}, ModeCustom) + if rootLoad.Discovery == nil { + t.Fatal("expected discovery snapshot from custom root load") + } + + first := loadTypedPackageGraphWithDiscoveryForTest(t, appRoot, env, "example.com/app/app", ModeCustom, rootLoad.Discovery) + if len(first.Packages) != 1 || len(first.Packages[0].Errors) != 0 { + t.Fatalf("first custom load returned errors: %+v", first.Packages) + } + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "type Logger interface { Log(string) }", + "", + "type NoopLogger struct{}", + "", + "func (NoopLogger) Log(string) {}", + "", + "func New(Logger) string { return \"ok\" }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "app", "app.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/dep\"", + "", + "func Use() string {", + "\tvar logger dep.Logger = dep.NoopLogger{}", + "\treturn dep.New(logger)", + "}", + "", + }, "\n")) + + custom := loadTypedPackageGraphWithDiscoveryForTest(t, appRoot, env, "example.com/app/app", ModeCustom, rootLoad.Discovery) + fallback := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeFallback) + compareRootPackagesOnly(t, custom.Packages, fallback.Packages, true) + comparePackageByPath(t, custom.Packages, fallback.Packages, "example.com/dep", false) +} + +func TestDiscoveryCacheInvalidatesOnGeneratedFileSetChange(t *testing.T) { + root := t.TempDir() + homeDir := t.TempDir() + + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "func New() string { return \"base\" }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/app/dep\"", + "", + "func Init() string { return dep.New() }", + "", + }, "\n")) + + env := append(os.Environ(), "HOME="+homeDir) + first := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeCustom) + if len(first.Packages) != 1 || len(first.Packages[0].Errors) != 0 { + t.Fatalf("first custom load returned errors: %+v", first.Packages) + } + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(root, "dep", "zz_generated.go"), strings.Join([]string{ + "package dep", + "", + "func Generated() string { return \"generated\" }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/app/dep\"", + "", + "func Init() string { return dep.New() + dep.Generated() }", + "", + }, "\n")) + + custom := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeCustom) + fallback := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeFallback) + comparePackageGraphs(t, custom.Packages, fallback.Packages, true) +} + +func TestLoadTypedPackageGraphCustomBodyOnlyEditWarmParity(t *testing.T) { + root := t.TempDir() + homeDir := t.TempDir() + + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "import \"fmt\"", + "", + "func Message() string {", + "\treturn fmt.Sprint(\"before\")", + "}", + "", + }, "\n")) + writeTestFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/app/dep\"", + "", + "func Init() string { return dep.Message() }", + "", + }, "\n")) + + env := append(os.Environ(), "HOME="+homeDir) + first := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeCustom) + if len(first.Packages) != 1 || len(first.Packages[0].Errors) != 0 { + t.Fatalf("first custom load returned errors: %+v", first.Packages) + } + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "import \"fmt\"", + "", + "func Message() string {", + "\treturn fmt.Sprint(\"after\")", + "}", + "", + }, "\n")) + + custom := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeCustom) + fallback := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeFallback) + compareRootPackagesOnly(t, custom.Packages, fallback.Packages, true) + comparePackageByPath(t, custom.Packages, fallback.Packages, "example.com/app/dep", true) +} + +func TestLoadTypedPackageGraphCustomReplaceNestedModuleParity(t *testing.T) { + root := t.TempDir() + appRoot := filepath.Join(root, "appmod") + depRoot := filepath.Join(appRoot, "third_party", "depmod") + homeDir := t.TempDir() + + writeTestFile(t, filepath.Join(depRoot, "go.mod"), "module example.com/dep\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "func New() string { return \"ok\" }", + "", + }, "\n")) + + writeTestFile(t, filepath.Join(appRoot, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require example.com/dep v0.0.0", + "", + "replace example.com/dep => ./third_party/depmod", + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "app", "app.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/dep\"", + "", + "func Use() string { return dep.New() }", + "", + }, "\n")) + + env := append(os.Environ(), "HOME="+homeDir) + first := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeCustom) + if len(first.Packages) != 1 || len(first.Packages[0].Errors) != 0 { + t.Fatalf("first custom load returned errors: %+v", first.Packages) + } + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "type Logger interface { Log(string) }", + "", + "type NoopLogger struct{}", + "", + "func (NoopLogger) Log(string) {}", + "", + "func New(Logger) string { return \"ok\" }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "app", "app.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/dep\"", + "", + "func Use() string {", + "\tvar logger dep.Logger = dep.NoopLogger{}", + "\treturn dep.New(logger)", + "}", + "", + }, "\n")) + + custom := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeCustom) + fallback := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeFallback) + compareRootPackagesOnly(t, custom.Packages, fallback.Packages, true) + comparePackageByPath(t, custom.Packages, fallback.Packages, "example.com/dep", false) +} + +func TestLoadTypedPackageGraphCustomReplaceChainParity(t *testing.T) { + root := t.TempDir() + depRoot := filepath.Join(root, "depmod") + midRoot := filepath.Join(root, "midmod") + appRoot := filepath.Join(root, "appmod") + homeDir := t.TempDir() + + writeTestFile(t, filepath.Join(depRoot, "go.mod"), "module example.com/dep\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "func New() string { return \"ok\" }", + "", + }, "\n")) + + writeTestFile(t, filepath.Join(midRoot, "go.mod"), "module example.com/mid\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(midRoot, "mid.go"), strings.Join([]string{ + "package mid", + "", + "import \"example.com/dep\"", + "", + "func Use() string { return dep.New() }", + "", + }, "\n")) + + writeTestFile(t, filepath.Join(appRoot, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require (", + "\texample.com/dep v0.0.0", + "\texample.com/mid v0.0.0", + ")", + "", + "replace example.com/dep => " + depRoot, + "replace example.com/mid => " + midRoot, + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "app", "app.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/mid\"", + "", + "func Use() string { return mid.Use() }", + "", + }, "\n")) + + env := append(os.Environ(), "HOME="+homeDir) + first := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeCustom) + if len(first.Packages) != 1 || len(first.Packages[0].Errors) != 0 { + t.Fatalf("first custom load returned errors: %+v", first.Packages) + } + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "type Logger interface { Log(string) }", + "", + "type NoopLogger struct{}", + "", + "func (NoopLogger) Log(string) {}", + "", + "func New(Logger) string { return \"ok\" }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(midRoot, "mid.go"), strings.Join([]string{ + "package mid", + "", + "import \"example.com/dep\"", + "", + "func Use() string {", + "\tvar logger dep.Logger = dep.NoopLogger{}", + "\treturn dep.New(logger)", + "}", + "", + }, "\n")) + + custom := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeCustom) + fallback := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeFallback) + compareRootPackagesOnly(t, custom.Packages, fallback.Packages, true) + comparePackageByPath(t, custom.Packages, fallback.Packages, "example.com/mid", false) + comparePackageByPath(t, custom.Packages, fallback.Packages, "example.com/dep", false) +} + +func TestLoadTypedPackageGraphCustomGoWorkWorkspaceParity(t *testing.T) { + root := t.TempDir() + appRoot := filepath.Join(root, "appmod") + depRoot := filepath.Join(root, "depmod") + homeDir := t.TempDir() + + writeTestFile(t, filepath.Join(root, "go.work"), strings.Join([]string{ + "go 1.19", + "", + "use (", + "\t./appmod", + "\t./depmod", + ")", + "", + }, "\n")) + + writeTestFile(t, filepath.Join(depRoot, "go.mod"), "module example.com/dep\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "func New() string { return \"ok\" }", + "", + }, "\n")) + + writeTestFile(t, filepath.Join(appRoot, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require example.com/dep v0.0.0", + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "app", "app.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/dep\"", + "", + "func Use() string { return dep.New() }", + "", + }, "\n")) + + env := append(os.Environ(), "HOME="+homeDir) + first := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeCustom) + if len(first.Packages) != 1 || len(first.Packages[0].Errors) != 0 { + t.Fatalf("first custom load returned errors: %+v", first.Packages) + } + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "import \"strings\"", + "", + "type Logger interface { Log(string) }", + "", + "type NoopLogger struct{}", + "", + "func (NoopLogger) Log(string) {}", + "", + "func New(Logger) string { return strings.TrimSpace(\" ok \") }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "app", "app.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/dep\"", + "", + "func Use() string {", + "\tvar logger dep.Logger = dep.NoopLogger{}", + "\treturn dep.New(logger)", + "}", + "", + }, "\n")) + + custom := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeCustom) + fallback := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeFallback) + compareRootPackagesOnly(t, custom.Packages, fallback.Packages, true) + comparePackageByPath(t, custom.Packages, fallback.Packages, "example.com/dep", false) +} + +func TestLoadTypedPackageGraphCustomCrossWorkspaceModuleIsolation(t *testing.T) { + cacheHome := t.TempDir() + repoOne := filepath.Join(t.TempDir(), "repo-one") + repoTwo := filepath.Join(t.TempDir(), "repo-two") + + writeTestFile(t, filepath.Join(repoOne, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(repoOne, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "func Message() string { return \"one\" }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(repoOne, "app", "wire.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/app/dep\"", + "", + "func Init() string { return dep.Message() }", + "", + }, "\n")) + + writeTestFile(t, filepath.Join(repoTwo, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(repoTwo, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "import \"strings\"", + "", + "func Message() string { return strings.ToUpper(\"two\") }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(repoTwo, "app", "wire.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/app/dep\"", + "", + "func Init() string { return dep.Message() }", + "", + }, "\n")) + + env := append(os.Environ(), "HOME="+cacheHome) + warm := loadTypedPackageGraphForTest(t, repoOne, env, "example.com/app/app", ModeCustom) + if len(warm.Packages) != 1 || len(warm.Packages[0].Errors) != 0 { + t.Fatalf("repo one warm custom load returned errors: %+v", warm.Packages) + } + + custom := loadTypedPackageGraphForTest(t, repoTwo, env, "example.com/app/app", ModeCustom) + fallback := loadTypedPackageGraphForTest(t, repoTwo, env, "example.com/app/app", ModeFallback) + compareRootPackagesOnly(t, custom.Packages, fallback.Packages, true) + comparePackageByPath(t, custom.Packages, fallback.Packages, "example.com/app/dep", true) +} + +func TestDiscoveryCacheInvalidatesOnLocalImportChangeEndToEnd(t *testing.T) { + root := t.TempDir() + homeDir := t.TempDir() + + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "func Base() string { return \"base\" }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(root, "extra", "extra.go"), strings.Join([]string{ + "package extra", + "", + "func Value() string { return \"extra\" }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/app/dep\"", + "", + "func Init() string { return dep.Base() }", + "", + }, "\n")) + + env := append(os.Environ(), "HOME="+homeDir) + first := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeCustom) + if len(first.Packages) != 1 || len(first.Packages[0].Errors) != 0 { + t.Fatalf("first custom load returned errors: %+v", first.Packages) + } + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "import \"example.com/app/extra\"", + "", + "func Base() string { return extra.Value() }", + "", + }, "\n")) + + custom := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeCustom) + fallback := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeFallback) + comparePackageGraphs(t, custom.Packages, fallback.Packages, true) +} + +func TestLoadTypedPackageGraphCustomLocalShapeChangeWarmParity(t *testing.T) { + root := t.TempDir() + homeDir := t.TempDir() + + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "type T struct{}", + "", + "func New() *T { return &T{} }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/app/dep\"", + "", + "func Init() *dep.T { return dep.New() }", + "", + }, "\n")) + + env := append(os.Environ(), "HOME="+homeDir) + first := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeCustom) + if len(first.Packages) != 1 || len(first.Packages[0].Errors) != 0 { + t.Fatalf("first custom load returned errors: %+v", first.Packages) + } + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "type Config struct{}", + "", + "type T struct{}", + "", + "func New(Config) *T { return &T{} }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/app/dep\"", + "", + "func Init() *dep.T { return dep.New(dep.Config{}) }", + "", + }, "\n")) + + custom := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeCustom) + fallback := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeFallback) + comparePackageGraphs(t, custom.Packages, fallback.Packages, true) +} + +func TestLoadTypedPackageGraphCustomTransitiveBodyOnlyWarmParity(t *testing.T) { + root := t.TempDir() + homeDir := t.TempDir() + + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "b", "b.go"), strings.Join([]string{ + "package b", + "", + "import \"fmt\"", + "", + "func Message() string { return fmt.Sprint(\"before\") }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(root, "a", "a.go"), strings.Join([]string{ + "package a", + "", + "import \"example.com/app/b\"", + "", + "func Message() string { return b.Message() }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/app/a\"", + "", + "func Init() string { return a.Message() }", + "", + }, "\n")) + + env := append(os.Environ(), "HOME="+homeDir) + first := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeCustom) + if len(first.Packages) != 1 || len(first.Packages[0].Errors) != 0 { + t.Fatalf("first custom load returned errors: %+v", first.Packages) + } + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(root, "b", "b.go"), strings.Join([]string{ + "package b", + "", + "import \"fmt\"", + "", + "func Message() string { return fmt.Sprint(\"after\") }", + "", + }, "\n")) + + custom := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeCustom) + fallback := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeFallback) + compareRootPackagesOnly(t, custom.Packages, fallback.Packages, true) + comparePackageByPath(t, custom.Packages, fallback.Packages, "example.com/app/a", true) + comparePackageByPath(t, custom.Packages, fallback.Packages, "example.com/app/b", true) +} + +func TestLoadTypedPackageGraphCustomKnownShapeToggleWarmParity(t *testing.T) { + root := t.TempDir() + homeDir := t.TempDir() + + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "type Config struct { Name string }", + "", + "func New(Config) string { return \"config\" }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/app/dep\"", + "", + "func Init() string { return dep.New(dep.Config{Name: \"a\"}) }", + "", + }, "\n")) + + env := append(os.Environ(), "HOME="+homeDir) + first := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeCustom) + if len(first.Packages) != 1 || len(first.Packages[0].Errors) != 0 { + t.Fatalf("first custom load returned errors: %+v", first.Packages) + } + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "type Logger interface { Log(string) }", + "", + "type NoopLogger struct{}", + "", + "func (NoopLogger) Log(string) {}", + "", + "func New(Logger) string { return \"logger\" }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/app/dep\"", + "", + "func Init() string {", + "\tvar logger dep.Logger = dep.NoopLogger{}", + "\treturn dep.New(logger)", + "}", + "", + }, "\n")) + + custom := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeCustom) + fallback := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeFallback) + comparePackageGraphs(t, custom.Packages, fallback.Packages, true) +} + +func TestLoadTypedPackageGraphCustomNewShapeWarmParity(t *testing.T) { + root := t.TempDir() + homeDir := t.TempDir() + + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "func New() string { return \"ok\" }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/app/dep\"", + "", + "func Init() string { return dep.New() }", + "", + }, "\n")) + + env := append(os.Environ(), "HOME="+homeDir) + first := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeCustom) + if len(first.Packages) != 1 || len(first.Packages[0].Errors) != 0 { + t.Fatalf("first custom load returned errors: %+v", first.Packages) + } + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "type Config struct{}", + "", + "func New() string { return \"ok\" }", + "", + "func NewWithConfig(Config) string { return \"ok\" }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/app/dep\"", + "", + "func Init() string { return dep.NewWithConfig(dep.Config{}) }", + "", + }, "\n")) + + custom := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeCustom) + fallback := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeFallback) + comparePackageGraphs(t, custom.Packages, fallback.Packages, true) +} + +func TestLoadTypedPackageGraphCustomReplaceTargetBodyOnlyWarmParity(t *testing.T) { + root := t.TempDir() + depRoot := filepath.Join(root, "depmod") + appRoot := filepath.Join(root, "appmod") + homeDir := t.TempDir() + + writeTestFile(t, filepath.Join(depRoot, "go.mod"), "module example.com/dep\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "import \"fmt\"", + "", + "func Message() string { return fmt.Sprint(\"before\") }", + "", + }, "\n")) + + writeTestFile(t, filepath.Join(appRoot, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require example.com/dep v0.0.0", + "", + "replace example.com/dep => " + depRoot, + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "app", "app.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/dep\"", + "", + "func Use() string { return dep.Message() }", + "", + }, "\n")) + + env := append(os.Environ(), "HOME="+homeDir) + first := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeCustom) + if len(first.Packages) != 1 || len(first.Packages[0].Errors) != 0 { + t.Fatalf("first custom load returned errors: %+v", first.Packages) + } + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "import \"fmt\"", + "", + "func Message() string { return fmt.Sprint(\"after\") }", + "", + }, "\n")) + + custom := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeCustom) + fallback := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeFallback) + compareRootPackagesOnly(t, custom.Packages, fallback.Packages, true) + comparePackageByPath(t, custom.Packages, fallback.Packages, "example.com/dep", false) +} + +func TestLoadTypedPackageGraphCustomReplaceTargetShapeChangeWarmParity(t *testing.T) { + root := t.TempDir() + depRoot := filepath.Join(root, "depmod") + appRoot := filepath.Join(root, "appmod") + homeDir := t.TempDir() + + writeTestFile(t, filepath.Join(depRoot, "go.mod"), "module example.com/dep\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "func New() string { return \"ok\" }", + "", + }, "\n")) + + writeTestFile(t, filepath.Join(appRoot, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require example.com/dep v0.0.0", + "", + "replace example.com/dep => " + depRoot, + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "app", "app.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/dep\"", + "", + "func Use() string { return dep.New() }", + "", + }, "\n")) + + env := append(os.Environ(), "HOME="+homeDir) + first := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeCustom) + if len(first.Packages) != 1 || len(first.Packages[0].Errors) != 0 { + t.Fatalf("first custom load returned errors: %+v", first.Packages) + } + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "type Logger interface { Log(string) }", + "", + "type NoopLogger struct{}", + "", + "func (NoopLogger) Log(string) {}", + "", + "func New(Logger) string { return \"ok\" }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "app", "app.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/dep\"", + "", + "func Use() string {", + "\tvar logger dep.Logger = dep.NoopLogger{}", + "\treturn dep.New(logger)", + "}", + "", + }, "\n")) + + custom := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeCustom) + fallback := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeFallback) + compareRootPackagesOnly(t, custom.Packages, fallback.Packages, true) + comparePackageByPath(t, custom.Packages, fallback.Packages, "example.com/dep", false) +} + +func TestLoadTypedPackageGraphCustomFixtureAppWarmMutationParity(t *testing.T) { + root := t.TempDir() + appRoot := filepath.Join(root, "appmod") + depRoot := filepath.Join(root, "depmod") + homeDir := t.TempDir() + + writeTestFile(t, filepath.Join(depRoot, "go.mod"), "module example.com/dep\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "import \"fmt\"", + "", + "func Message() string { return fmt.Sprint(\"dep\") }", + "", + }, "\n")) + + writeTestFile(t, filepath.Join(appRoot, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require example.com/dep v0.0.0", + "", + "replace example.com/dep => " + depRoot, + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "base", "base.go"), strings.Join([]string{ + "package base", + "", + "import \"fmt\"", + "", + "func Prefix() string { return fmt.Sprint(\"base:\") }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "gen", "zz_generated.go"), strings.Join([]string{ + "package gen", + "", + "func Value() string { return \"generated\" }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "feature", "feature.go"), strings.Join([]string{ + "package feature", + "", + "import (", + "\t\"example.com/app/base\"", + "\t\"example.com/app/gen\"", + "\t\"example.com/dep\"", + ")", + "", + "func Message() string {", + "\treturn base.Prefix() + dep.Message() + gen.Value()", + "}", + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "app", "wire.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/app/feature\"", + "", + "func Init() string { return feature.Message() }", + "", + }, "\n")) + + env := append(os.Environ(), "HOME="+homeDir) + coldCustom := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeCustom) + if len(coldCustom.Packages) != 1 || len(coldCustom.Packages[0].Errors) != 0 { + t.Fatalf("cold custom load returned errors: %+v", coldCustom.Packages) + } + coldFallback := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeFallback) + compareRootPackagesOnly(t, coldCustom.Packages, coldFallback.Packages, true) + comparePackageByPath(t, coldCustom.Packages, coldFallback.Packages, "example.com/app/feature", true) + comparePackageByPath(t, coldCustom.Packages, coldFallback.Packages, "example.com/app/gen", true) + comparePackageByPath(t, coldCustom.Packages, coldFallback.Packages, "example.com/dep", false) + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "import \"fmt\"", + "", + "type Logger interface { Log(string) }", + "", + "type NoopLogger struct{}", + "", + "func (NoopLogger) Log(string) {}", + "", + "func Message(Logger) string { return fmt.Sprint(\"dep2\") }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "gen", "zz_generated.go"), strings.Join([]string{ + "package gen", + "", + "func Value() string { return \"generated2\" }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "feature", "feature.go"), strings.Join([]string{ + "package feature", + "", + "import (", + "\t\"example.com/app/base\"", + "\t\"example.com/app/gen\"", + "\t\"example.com/dep\"", + ")", + "", + "func Message() string {", + "\tvar logger dep.Logger = dep.NoopLogger{}", + "\treturn base.Prefix() + dep.Message(logger) + gen.Value()", + "}", + "", + }, "\n")) + + warmCustom := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeCustom) + warmFallback := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeFallback) + compareRootPackagesOnly(t, warmCustom.Packages, warmFallback.Packages, true) + comparePackageByPath(t, warmCustom.Packages, warmFallback.Packages, "example.com/app/feature", true) + comparePackageByPath(t, warmCustom.Packages, warmFallback.Packages, "example.com/app/gen", true) + comparePackageByPath(t, warmCustom.Packages, warmFallback.Packages, "example.com/dep", false) +} + +func TestLoadTypedPackageGraphCustomSequentialMutationsParity(t *testing.T) { + root := t.TempDir() + depRoot := filepath.Join(root, "depmod") + appRoot := filepath.Join(root, "appmod") + homeDir := t.TempDir() + + writeTestFile(t, filepath.Join(depRoot, "go.mod"), "module example.com/dep\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "import \"fmt\"", + "", + "func Message() string { return fmt.Sprint(\"dep\") }", + "", + }, "\n")) + + writeTestFile(t, filepath.Join(appRoot, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require example.com/dep v0.0.0", + "", + "replace example.com/dep => " + depRoot, + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "helper", "helper.go"), strings.Join([]string{ + "package helper", + "", + "func Prefix() string { return \"prefix:\" }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "app", "wire.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/dep\"", + "", + "func Init() string { return dep.Message() }", + "", + }, "\n")) + + env := append(os.Environ(), "HOME="+homeDir) + + assertParity := func() { + custom := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeCustom) + fallback := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeFallback) + compareRootPackagesOnly(t, custom.Packages, fallback.Packages, true) + comparePackageByPath(t, custom.Packages, fallback.Packages, "example.com/dep", false) + } + + initial := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeCustom) + if len(initial.Packages) != 1 || len(initial.Packages[0].Errors) != 0 { + t.Fatalf("initial custom load returned errors: %+v", initial.Packages) + } + assertParity() + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "import \"fmt\"", + "", + "func Message() string { return fmt.Sprint(\"dep-body\") }", + "", + }, "\n")) + assertParity() + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(appRoot, "app", "wire.go"), strings.Join([]string{ + "package app", + "", + "import (", + "\t\"example.com/app/helper\"", + "\t\"example.com/dep\"", + ")", + "", + "func Init() string { return helper.Prefix() + dep.Message() }", + "", + }, "\n")) + assertParity() + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "import \"fmt\"", + "", + "type Logger interface { Log(string) }", + "", + "type NoopLogger struct{}", + "", + "func (NoopLogger) Log(string) {}", + "", + "func Message(Logger) string { return fmt.Sprint(\"dep-shape\") }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "app", "wire.go"), strings.Join([]string{ + "package app", + "", + "import (", + "\t\"example.com/app/helper\"", + "\t\"example.com/dep\"", + ")", + "", + "func Init() string {", + "\tvar logger dep.Logger = dep.NoopLogger{}", + "\treturn helper.Prefix() + dep.Message(logger)", + "}", + "", + }, "\n")) + assertParity() + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "import \"fmt\"", + "", + "func Message() string { return fmt.Sprint(\"dep\") }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "app", "wire.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/dep\"", + "", + "func Init() string { return dep.Message() }", + "", + }, "\n")) + assertParity() +} + +func TestDiscoveryCacheInvalidatesOnGoSumResolutionChange(t *testing.T) { + root := t.TempDir() + proxyDir := t.TempDir() + homeDir := t.TempDir() + goCacheDir := tempCacheDirForTest(t, "wire-gocache-") + goModCacheDir := tempCacheDirForTest(t, "wire-gomodcache-") + + writeModuleProxyVersion(t, proxyDir, "example.com/extdep", "v1.0.0", map[string]string{ + "pkg/pkg.go": "package pkg\n\nfunc Version() string { return \"v1.0.0\" }\n", + }) + + writeTestFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require example.com/extdep v1.0.0", + "", + }, "\n")) + writeTestFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/extdep/pkg\"", + "", + "func Init() string { return pkg.Version() }", + "", + }, "\n")) + + env := append(os.Environ(), + "HOME="+homeDir, + "GOPROXY="+fileURLForTest(t, proxyDir), + "GOSUMDB=off", + "GOCACHE="+goCacheDir, + "GOMODCACHE="+goModCacheDir, + ) + runGoModTidyForTest(t, root, env) + + first := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeCustom) + if len(first.Packages) != 1 { + t.Fatalf("first custom packages len = %d, want 1", len(first.Packages)) + } + if got := comparableErrors(first.Packages[0].Errors); len(got) != 0 { + t.Fatalf("first custom load returned errors: %v", got) + } + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(root, "go.sum"), "") + + custom := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeCustom) + fallback := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeFallback) + compareRootPackagesOnly(t, custom.Packages, fallback.Packages, true) +} + +func TestLoadTypedPackageGraphCustomExternalVersionChangeBustsCache(t *testing.T) { + root := t.TempDir() + proxyDir := t.TempDir() + artifactDir := t.TempDir() + homeDir := t.TempDir() + goCacheDir := tempCacheDirForTest(t, "wire-gocache-") + goModCacheDir := tempCacheDirForTest(t, "wire-gomodcache-") + + writeModuleProxyVersion(t, proxyDir, "example.com/extdep", "v1.0.0", map[string]string{ + "pkg/pkg.go": "package pkg\n\nfunc Version() string { return \"v1.0.0\" }\n", + }) + writeModuleProxyVersion(t, proxyDir, "example.com/extdep", "v1.1.0", map[string]string{ + "pkg/pkg.go": "package pkg\n\nimport \"strings\"\n\nfunc Version() string { return strings.TrimSpace(\"v1.1.0\") }\n", + }) + + writeTestFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require example.com/extdep v1.0.0", + "", + }, "\n")) + writeTestFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/extdep/pkg\"", + "", + "func Init() string { return pkg.Version() }", + "", + }, "\n")) + + env := append(os.Environ(), + "HOME="+homeDir, + "GOPROXY="+fileURLForTest(t, proxyDir), + "GOSUMDB=off", + "GOCACHE="+goCacheDir, + "GOMODCACHE="+goModCacheDir, + loaderArtifactEnv+"=1", + loaderArtifactDirEnv+"="+artifactDir, + ) + runGoModTidyForTest(t, root, env) + + first := loadPackagesForTest(t, root, env, []string{"./app"}, ModeCustom) + if len(first.Packages) != 1 || len(first.Packages[0].Errors) != 0 { + t.Fatalf("first custom load returned errors: %+v", first.Packages) + } + firstDep := collectGraph(first.Packages)["example.com/extdep/pkg"] + if firstDep == nil { + t.Fatal("expected dependency package for example.com/extdep/pkg") + } + if !containsPathSubstring(firstDep.CompiledGoFiles, "example.com/extdep@v1.0.0") { + t.Fatalf("first dependency files = %v, want version v1.0.0", firstDep.CompiledGoFiles) + } + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require example.com/extdep v1.1.0", + "", + }, "\n")) + runGoModTidyForTest(t, root, env) + + custom := loadPackagesForTest(t, root, env, []string{"./app"}, ModeCustom) + fallback := loadPackagesForTest(t, root, env, []string{"./app"}, ModeFallback) + compareRootPackagesOnly(t, custom.Packages, fallback.Packages, true) + comparePackageByPath(t, custom.Packages, fallback.Packages, "example.com/extdep/pkg", false) + + secondDep := collectGraph(custom.Packages)["example.com/extdep/pkg"] + if secondDep == nil { + t.Fatal("expected dependency package for example.com/extdep/pkg after version change") + } + if !containsPathSubstring(secondDep.CompiledGoFiles, "example.com/extdep@v1.1.0") { + t.Fatalf("second dependency files = %v, want version v1.1.0", secondDep.CompiledGoFiles) + } +} + +func TestLoaderArtifactKeyExternalChangesWhenExportFileChanges(t *testing.T) { + exportPath := filepath.Join(t.TempDir(), "dep.a") + writeTestFile(t, exportPath, "first export payload") + + meta := &packageMeta{ + ImportPath: "example.com/dep", + Name: "dep", + Export: exportPath, + } + + first, err := loaderArtifactKey(meta, false) + if err != nil { + t.Fatalf("loaderArtifactKey(first) error = %v", err) + } + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, exportPath, "second export payload with different contents") + + second, err := loaderArtifactKey(meta, false) + if err != nil { + t.Fatalf("loaderArtifactKey(second) error = %v", err) + } + + if first == second { + t.Fatalf("loaderArtifactKey did not change after export file update: %q", first) + } +} + +func TestLoaderArtifactKeyExternalWithoutExportChangesWhenSourceChanges(t *testing.T) { + sourcePath := filepath.Join(t.TempDir(), "dep.go") + writeTestFile(t, sourcePath, "package dep\n\nconst Name = \"first\"\n") + + meta := &packageMeta{ + ImportPath: "example.com/dep", + Name: "dep", + GoFiles: []string{sourcePath}, + CompiledGoFiles: []string{sourcePath}, + } + + first, err := loaderArtifactKey(meta, false) + if err != nil { + t.Fatalf("loaderArtifactKey(first) error = %v", err) + } + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, sourcePath, "package dep\n\nconst Name = \"second\"\n") + + second, err := loaderArtifactKey(meta, false) + if err != nil { + t.Fatalf("loaderArtifactKey(second) error = %v", err) + } + + if first == second { + t.Fatalf("loaderArtifactKey did not change after external source update without export data: %q", first) + } +} + +func TestRunGoListIncludesExportDataForReplacedModule(t *testing.T) { + root := t.TempDir() + depRoot := filepath.Join(root, "depmod") + appRoot := filepath.Join(root, "appmod") + goCacheDir := tempCacheDirForTest(t, "wire-gocache-") + goModCacheDir := tempCacheDirForTest(t, "wire-gomodcache-") + + writeTestFile(t, filepath.Join(depRoot, "go.mod"), "module example.com/dep\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(depRoot, "dep.go"), "package dep\n\nfunc New() string { return \"ok\" }\n") + + writeTestFile(t, filepath.Join(appRoot, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require example.com/dep v0.0.0", + "", + "replace example.com/dep => " + depRoot, + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "app", "app.go"), "package app\n\nimport \"example.com/dep\"\n\nfunc Use() string { return dep.New() }\n") + + meta, err := runGoList(context.Background(), goListRequest{ + WD: appRoot, + Env: append(os.Environ(), "GOCACHE="+goCacheDir, "GOMODCACHE="+goModCacheDir), + Patterns: []string{"example.com/app/app"}, + NeedDeps: true, + }) + if err != nil { + t.Fatalf("runGoList() error = %v", err) + } + depMeta := meta["example.com/dep"] + if depMeta == nil { + t.Fatal("expected metadata for example.com/dep") + } + if depMeta.Export == "" { + t.Fatalf("expected export data path for replaced module metadata: %+v", depMeta) + } +} + +func TestLoadTypedPackageGraphCustomReplaceTargetWithExportDataWarmParity(t *testing.T) { + root := t.TempDir() + depRoot := filepath.Join(root, "depmod") + appRoot := filepath.Join(root, "appmod") + artifactDir := t.TempDir() + homeDir := t.TempDir() + goCacheDir := tempCacheDirForTest(t, "wire-gocache-") + goModCacheDir := tempCacheDirForTest(t, "wire-gomodcache-") + + writeTestFile(t, filepath.Join(depRoot, "go.mod"), "module example.com/dep\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "func New() string { return \"ok\" }", + "", + }, "\n")) + + writeTestFile(t, filepath.Join(appRoot, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require example.com/dep v0.0.0", + "", + "replace example.com/dep => " + depRoot, + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "app", "app.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/dep\"", + "", + "func Use() string { return dep.New() }", + "", + }, "\n")) + + env := append(os.Environ(), + "HOME="+homeDir, + "GOCACHE="+goCacheDir, + "GOMODCACHE="+goModCacheDir, + loaderArtifactEnv+"=1", + loaderArtifactDirEnv+"="+artifactDir, + ) + + meta, err := runGoList(context.Background(), goListRequest{ + WD: appRoot, + Env: env, + Patterns: []string{"example.com/app/app"}, + NeedDeps: true, + }) + if err != nil { + t.Fatalf("runGoList() error = %v", err) + } + depMeta := meta["example.com/dep"] + if depMeta == nil || depMeta.Export == "" { + t.Fatalf("expected export-backed metadata for example.com/dep: %+v", depMeta) + } + + first := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeCustom) + if len(first.Packages) != 1 || len(first.Packages[0].Errors) != 0 { + t.Fatalf("first custom load returned errors: %+v", first.Packages) + } + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "type Logger interface { Log(string) }", + "", + "type NoopLogger struct{}", + "", + "func (NoopLogger) Log(string) {}", + "", + "func New(Logger) string { return \"ok\" }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "app", "app.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/dep\"", + "", + "func Use() string {", + "\tvar logger dep.Logger = dep.NoopLogger{}", + "\treturn dep.New(logger)", + "}", + "", + }, "\n")) + + custom := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeCustom) + fallback := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeFallback) + compareRootPackagesOnly(t, custom.Packages, fallback.Packages, true) + comparePackageByPath(t, custom.Packages, fallback.Packages, "example.com/dep", false) +} + +func TestLoadTypedPackageGraphCustomReplaceTargetWithoutExportDataWarmParity(t *testing.T) { + root := t.TempDir() + depRoot := filepath.Join(root, "depmod") + appRoot := filepath.Join(root, "appmod") + homeDir := t.TempDir() + goCacheDir := tempCacheDirForTest(t, "wire-gocache-") + goModCacheDir := tempCacheDirForTest(t, "wire-gomodcache-") + + writeTestFile(t, filepath.Join(depRoot, "go.mod"), "module example.com/dep\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "//go:build never", + "", + "func New() string { return \"ok\" }", + "", + }, "\n")) + + writeTestFile(t, filepath.Join(appRoot, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require example.com/dep v0.0.0", + "", + "replace example.com/dep => " + depRoot, + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "app", "app.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/dep\"", + "", + "func Use() string { return dep.New() }", + "", + }, "\n")) + + env := append(os.Environ(), + "HOME="+homeDir, + "GOCACHE="+goCacheDir, + "GOMODCACHE="+goModCacheDir, + ) + + meta, err := runGoList(context.Background(), goListRequest{ + WD: appRoot, + Env: env, + Patterns: []string{"example.com/app/app"}, + NeedDeps: true, + }) + if err != nil { + t.Fatalf("runGoList(first) error = %v", err) + } + depMeta := meta["example.com/dep"] + if depMeta == nil || depMeta.Export != "" { + t.Fatalf("expected no export data for incomplete replaced module: %+v", depMeta) + } + + first := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeCustom) + if len(first.Packages) != 1 { + t.Fatalf("first custom packages len = %d, want 1", len(first.Packages)) + } + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "var _ missing", + "", + "func New() string { return \"ok\" }", + "", + }, "\n")) + + meta, err = runGoList(context.Background(), goListRequest{ + WD: appRoot, + Env: env, + Patterns: []string{"example.com/app/app"}, + NeedDeps: true, + }) + if err != nil { + t.Fatalf("runGoList(second) error = %v", err) + } + depMeta = meta["example.com/dep"] + if depMeta == nil || depMeta.Export != "" { + t.Fatalf("expected no export data for second incomplete replaced module state: %+v", depMeta) + } + + custom := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeCustom) + fallback := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeFallback) + compareRootPackagesOnly(t, custom.Packages, fallback.Packages, true) +} + +func TestLoadTypedPackageGraphCustomExternalArtifactCacheRealAppParity(t *testing.T) { + root := os.Getenv("WIRE_REAL_APP_ROOT") + if root == "" { + t.Skip("WIRE_REAL_APP_ROOT not set") + } + artifactDir := t.TempDir() + load := func(env []string) (map[string]*packages.Package, error) { + l := New() + got, err := l.LoadPackages(context.Background(), PackageLoadRequest{ + WD: root, + Env: env, + Patterns: []string{"."}, + Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile, + LoaderMode: ModeCustom, + Fset: token.NewFileSet(), + ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { + return parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) + }, + }) + if err != nil { + return nil, err + } + return collectGraph(got.Packages), nil + } + + base, err := load(os.Environ()) + if err != nil { + t.Fatalf("base load error = %v", err) + } + withArtifactsEnv := append(os.Environ(), + loaderArtifactEnv+"=1", + loaderArtifactDirEnv+"="+artifactDir, + ) + firstArtifact, err := load(withArtifactsEnv) + if err != nil { + t.Fatalf("first artifact load error = %v", err) + } + secondArtifact, err := load(withArtifactsEnv) + if err != nil { + t.Fatalf("second artifact load error = %v", err) + } + if len(base) != len(firstArtifact) { + t.Fatalf("first artifact graph size = %d, want %d", len(firstArtifact), len(base)) + } + if len(base) != len(secondArtifact) { + var missing []string + for path := range base { + if secondArtifact[path] == nil { + missing = append(missing, path) + } + } + sort.Strings(missing) + parents := make(map[string][]string) + for parentPath, pkg := range base { + for impPath := range pkg.Imports { + if secondArtifact[impPath] == nil { + parents[impPath] = append(parents[impPath], parentPath) + } + } + } + parentSummary := make([]string, 0, 5) + for _, path := range missing { + if len(parentSummary) == 5 { + break + } + importers := append([]string(nil), parents[path]...) + sort.Strings(importers) + if len(importers) > 3 { + importers = importers[:3] + } + parentSummary = append(parentSummary, path+" <- "+strings.Join(importers, ",")) + } + if len(missing) > 20 { + missing = missing[:20] + } + secondParent := secondArtifact["github.com/shirou/gopsutil/v4/internal/common"] + secondParentImports := []string(nil) + if secondParent != nil { + secondParentImports = sortedImportPaths(secondParent.Imports) + } + internalCommonParents := append([]string(nil), parents["github.com/shirou/gopsutil/v4/internal/common"]...) + sort.Strings(internalCommonParents) + t.Fatalf("second artifact graph size = %d, want %d; missing sample=%v; parent sample=%v; gopsutil/internal/common parents=%v; gopsutil/internal/common imports on second run=%v", len(secondArtifact), len(base), missing, parentSummary, internalCommonParents, secondParentImports) + } + if compiledFileCount(base) != compiledFileCount(secondArtifact) { + t.Fatalf("second artifact compiled file count = %d, want %d", compiledFileCount(secondArtifact), compiledFileCount(base)) + } +} + +func TestLoadRootGraphCustomMatchesFallback(t *testing.T) { + root := t.TempDir() + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "dep", "dep.go"), "package dep\n") + writeTestFile(t, filepath.Join(root, "app", "wire.go"), "package app\n\nimport _ \"example.com/app/dep\"\n") + + l := New() + custom, err := l.LoadRootGraph(context.Background(), RootLoadRequest{ + WD: root, + Env: os.Environ(), + Patterns: []string{"./app"}, + NeedDeps: true, + Mode: ModeCustom, + Fset: token.NewFileSet(), + }) + if err != nil { + t.Fatalf("LoadRootGraph(custom) error = %v", err) + } + fallback, err := l.LoadRootGraph(context.Background(), RootLoadRequest{ + WD: root, + Env: os.Environ(), + Patterns: []string{"./app"}, + NeedDeps: true, + Mode: ModeFallback, + Fset: token.NewFileSet(), + }) + if err != nil { + t.Fatalf("LoadRootGraph(fallback) error = %v", err) + } + comparePackageGraphs(t, custom.Packages, fallback.Packages, false) +} + +func TestLoadTypedPackageGraphCustomMatchesFallback(t *testing.T) { + root := t.TempDir() + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "dep", "dep.go"), "package dep\n\ntype T struct{}\nfunc New() *T { return &T{} }\n") + writeTestFile(t, filepath.Join(root, "app", "wire.go"), "package app\n\nimport \"example.com/app/dep\"\n\nfunc Init() *dep.T { return dep.New() }\n") + + l := New() + custom, err := l.LoadTypedPackageGraph(context.Background(), LazyLoadRequest{ + WD: root, + Env: os.Environ(), + Package: "example.com/app/app", + Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile, + LoaderMode: ModeCustom, + Fset: token.NewFileSet(), + ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { + return parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) + }, + }) + if err != nil { + t.Fatalf("LoadTypedPackageGraph(custom) error = %v", err) + } + fallback, err := l.LoadTypedPackageGraph(context.Background(), LazyLoadRequest{ + WD: root, + Env: os.Environ(), + Package: "example.com/app/app", + Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile, + LoaderMode: ModeFallback, + Fset: token.NewFileSet(), + ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { + return parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) + }, + }) + if err != nil { + t.Fatalf("LoadTypedPackageGraph(fallback) error = %v", err) + } + comparePackageGraphs(t, custom.Packages, fallback.Packages, true) +} + +func TestLoadTypedPackageGraphCustomMatchesFallbackTypeErrors(t *testing.T) { + root := t.TempDir() + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "app", "wire.go"), "package app\n\nfunc Broken() int { return missing }\n") + + l := New() + custom, err := l.LoadTypedPackageGraph(context.Background(), LazyLoadRequest{ + WD: root, + Env: os.Environ(), + Package: "example.com/app/app", + Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile, + LoaderMode: ModeCustom, + Fset: token.NewFileSet(), + ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { + return parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) + }, + }) + if err != nil { + t.Fatalf("LoadTypedPackageGraph(custom) error = %v", err) + } + fallback, err := l.LoadTypedPackageGraph(context.Background(), LazyLoadRequest{ + WD: root, + Env: os.Environ(), + Package: "example.com/app/app", + Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile, + LoaderMode: ModeFallback, + Fset: token.NewFileSet(), + ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { + return parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) + }, + }) + if err != nil { + t.Fatalf("LoadTypedPackageGraph(fallback) error = %v", err) + } + comparePackageGraphs(t, custom.Packages, fallback.Packages, true) +} + +func comparePackageGraphs(t *testing.T, got []*packages.Package, want []*packages.Package, requireTyped bool) { + t.Helper() + gotAll := collectGraph(got) + wantAll := collectGraph(want) + if len(gotAll) != len(wantAll) { + t.Fatalf("package graph size = %d, want %d", len(gotAll), len(wantAll)) + } + for path, wantPkg := range wantAll { + gotPkg := gotAll[path] + if gotPkg == nil { + t.Fatalf("missing package %q in custom graph", path) + } + if gotPkg.Name != wantPkg.Name { + t.Fatalf("package %q name = %q, want %q", path, gotPkg.Name, wantPkg.Name) + } + if !equalStrings(gotPkg.CompiledGoFiles, wantPkg.CompiledGoFiles) { + t.Fatalf("package %q compiled files = %v, want %v", path, gotPkg.CompiledGoFiles, wantPkg.CompiledGoFiles) + } + if !equalImportPaths(gotPkg.Imports, wantPkg.Imports) { + t.Fatalf("package %q imports = %v, want %v", path, sortedImportPaths(gotPkg.Imports), sortedImportPaths(wantPkg.Imports)) + } + gotErrs := comparableErrors(gotPkg.Errors) + wantErrs := comparableErrors(wantPkg.Errors) + if len(gotErrs) != len(wantErrs) { + t.Fatalf("package %q comparable errors len = %d, want %d; got=%v want=%v", path, len(gotErrs), len(wantErrs), gotErrs, wantErrs) + } + for i := range gotErrs { + if gotErrs[i] != wantErrs[i] { + t.Fatalf("package %q comparable error[%d] = %q, want %q", path, i, gotErrs[i], wantErrs[i]) + } + } + if requireTyped { + gotTyped := gotPkg.Types != nil && gotPkg.TypesInfo != nil && len(gotPkg.Syntax) > 0 + wantTyped := wantPkg.Types != nil && wantPkg.TypesInfo != nil && len(wantPkg.Syntax) > 0 + if gotTyped != wantTyped { + t.Fatalf("package %q typed state = %v, want %v", path, gotTyped, wantTyped) + } + } + } +} + +func compareRootPackagesOnly(t *testing.T, got []*packages.Package, want []*packages.Package, requireTyped bool) { + t.Helper() + if len(got) != len(want) { + t.Fatalf("root package count = %d, want %d", len(got), len(want)) + } + gotByPath := make(map[string]*packages.Package, len(got)) + for _, pkg := range got { + gotByPath[pkg.PkgPath] = pkg + } + for _, wantPkg := range want { + gotPkg := gotByPath[wantPkg.PkgPath] + if gotPkg == nil { + t.Fatalf("missing root package %q", wantPkg.PkgPath) + } + if gotPkg.Name != wantPkg.Name { + t.Fatalf("package %q name = %q, want %q", wantPkg.PkgPath, gotPkg.Name, wantPkg.Name) + } + if !equalStrings(gotPkg.CompiledGoFiles, wantPkg.CompiledGoFiles) { + t.Fatalf("package %q compiled files = %v, want %v", wantPkg.PkgPath, gotPkg.CompiledGoFiles, wantPkg.CompiledGoFiles) + } + if !equalImportPaths(gotPkg.Imports, wantPkg.Imports) { + t.Fatalf("package %q imports = %v, want %v", wantPkg.PkgPath, sortedImportPaths(gotPkg.Imports), sortedImportPaths(wantPkg.Imports)) + } + gotErrs := comparableErrors(gotPkg.Errors) + wantErrs := comparableErrors(wantPkg.Errors) + if len(gotErrs) != len(wantErrs) { + t.Fatalf("package %q comparable errors len = %d, want %d; got=%v want=%v", wantPkg.PkgPath, len(gotErrs), len(wantErrs), gotErrs, wantErrs) + } + for i := range gotErrs { + if gotErrs[i] != wantErrs[i] { + t.Fatalf("package %q comparable error[%d] = %q, want %q", wantPkg.PkgPath, i, gotErrs[i], wantErrs[i]) + } + } + if requireTyped { + gotTyped := gotPkg.Types != nil && gotPkg.TypesInfo != nil && len(gotPkg.Syntax) > 0 + wantTyped := wantPkg.Types != nil && wantPkg.TypesInfo != nil && len(wantPkg.Syntax) > 0 + if gotTyped != wantTyped { + t.Fatalf("package %q typed state = %v, want %v", wantPkg.PkgPath, gotTyped, wantTyped) + } + } + } +} + +func comparePackageByPath(t *testing.T, got []*packages.Package, want []*packages.Package, pkgPath string, requireTyped bool) { + t.Helper() + gotPkg := collectGraph(got)[pkgPath] + if gotPkg == nil { + t.Fatalf("missing package %q in custom graph", pkgPath) + } + wantPkg := collectGraph(want)[pkgPath] + if wantPkg == nil { + t.Fatalf("missing package %q in fallback graph", pkgPath) + } + if gotPkg.Name != wantPkg.Name { + t.Fatalf("package %q name = %q, want %q", pkgPath, gotPkg.Name, wantPkg.Name) + } + if !equalStrings(gotPkg.CompiledGoFiles, wantPkg.CompiledGoFiles) { + t.Fatalf("package %q compiled files = %v, want %v", pkgPath, gotPkg.CompiledGoFiles, wantPkg.CompiledGoFiles) + } + if !equalImportPaths(gotPkg.Imports, wantPkg.Imports) { + t.Fatalf("package %q imports = %v, want %v", pkgPath, sortedImportPaths(gotPkg.Imports), sortedImportPaths(wantPkg.Imports)) + } + gotErrs := comparableErrors(gotPkg.Errors) + wantErrs := comparableErrors(wantPkg.Errors) + if len(gotErrs) != len(wantErrs) { + t.Fatalf("package %q comparable errors len = %d, want %d; got=%v want=%v", pkgPath, len(gotErrs), len(wantErrs), gotErrs, wantErrs) + } + for i := range gotErrs { + if gotErrs[i] != wantErrs[i] { + t.Fatalf("package %q comparable error[%d] = %q, want %q", pkgPath, i, gotErrs[i], wantErrs[i]) + } + } + if requireTyped { + gotTyped := gotPkg.Types != nil && gotPkg.TypesInfo != nil && len(gotPkg.Syntax) > 0 + wantTyped := wantPkg.Types != nil && wantPkg.TypesInfo != nil && len(wantPkg.Syntax) > 0 + if gotTyped != wantTyped { + t.Fatalf("package %q typed state = %v, want %v", pkgPath, gotTyped, wantTyped) + } + } +} + +func collectGraph(roots []*packages.Package) map[string]*packages.Package { + out := make(map[string]*packages.Package) + stack := append([]*packages.Package(nil), roots...) + for len(stack) > 0 { + pkg := stack[len(stack)-1] + stack = stack[:len(stack)-1] + if pkg == nil || out[pkg.PkgPath] != nil { + continue + } + out[pkg.PkgPath] = pkg + for _, imp := range pkg.Imports { + stack = append(stack, imp) + } + } + return out +} + +func loadTypedPackageGraphForTest(t *testing.T, wd string, env []string, pkg string, mode Mode) *LazyLoadResult { + return loadTypedPackageGraphWithDiscoveryForTest(t, wd, env, pkg, mode, nil) +} + +func loadPackagesForTest(t *testing.T, wd string, env []string, patterns []string, mode Mode) *PackageLoadResult { + t.Helper() + l := New() + got, err := l.LoadPackages(context.Background(), PackageLoadRequest{ + WD: wd, + Env: env, + Patterns: patterns, + Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile, + LoaderMode: mode, + Fset: token.NewFileSet(), + ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { + return parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) + }, + }) + if err != nil { + t.Fatalf("LoadPackages(%q, %q) error = %v", wd, mode, err) + } + return got +} + +func loadTypedPackageGraphWithDiscoveryForTest(t *testing.T, wd string, env []string, pkg string, mode Mode, discovery *DiscoverySnapshot) *LazyLoadResult { + t.Helper() + l := New() + got, err := l.LoadTypedPackageGraph(context.Background(), LazyLoadRequest{ + WD: wd, + Env: env, + Package: pkg, + Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile, + LoaderMode: mode, + Fset: token.NewFileSet(), + ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { + return parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) + }, + Discovery: discovery, + }) + if err != nil { + t.Fatalf("LoadTypedPackageGraph(%q, %q) error = %v", wd, mode, err) + } + return got +} + +func loadRootGraphForTest(t *testing.T, wd string, env []string, patterns []string, mode Mode) *RootLoadResult { + t.Helper() + l := New() + got, err := l.LoadRootGraph(context.Background(), RootLoadRequest{ + WD: wd, + Env: env, + Patterns: patterns, + NeedDeps: true, + Mode: mode, + Fset: token.NewFileSet(), + }) + if err != nil { + t.Fatalf("LoadRootGraph(%q, %q) error = %v", wd, mode, err) + } + return got +} + +func compiledFileCount(pkgs map[string]*packages.Package) int { + total := 0 + for _, pkg := range pkgs { + total += len(pkg.CompiledGoFiles) + } + return total +} + +func equalStrings(a, b []string) bool { + if len(a) != len(b) { + return false + } + aCopy := append([]string(nil), a...) + bCopy := append([]string(nil), b...) + for i := range aCopy { + aCopy[i] = normalizePathForCompare(aCopy[i]) + } + for i := range bCopy { + bCopy[i] = normalizePathForCompare(bCopy[i]) + } + sort.Strings(aCopy) + sort.Strings(bCopy) + for i := range aCopy { + if aCopy[i] != bCopy[i] { + return false + } + } + return true +} + +func equalImportPaths(a, b map[string]*packages.Package) bool { + return equalStrings(sortedImportPaths(a), sortedImportPaths(b)) +} + +func sortedImportPaths(m map[string]*packages.Package) []string { + out := make([]string, 0, len(m)) + for path := range m { + out = append(out, path) + } + sort.Strings(out) + return out +} + +func containsPathSubstring(paths []string, needle string) bool { + for _, path := range paths { + if strings.Contains(normalizePathForCompare(path), needle) { + return true + } + } + return false +} + +func runGoModTidyForTest(t *testing.T, wd string, env []string) { + t.Helper() + cmd := exec.Command("go", "mod", "tidy") + cmd.Dir = wd + cmd.Env = env + out, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("go mod tidy in %q error = %v: %s", wd, err, out) + } +} + +func writeModuleProxyVersion(t *testing.T, proxyDir string, modulePath string, version string, files map[string]string) { + t.Helper() + base := filepath.Join(proxyDir, filepath.FromSlash(modulePath), "@v") + if err := os.MkdirAll(base, 0o755); err != nil { + t.Fatalf("mkdir proxy dir: %v", err) + } + listPath := filepath.Join(base, "list") + appendLineIfMissing(t, listPath, version) + + modFile := "module " + modulePath + "\n\ngo 1.19\n" + writeTestFile(t, filepath.Join(base, version+".mod"), modFile) + writeTestFile(t, filepath.Join(base, version+".info"), fmt.Sprintf("{\"Version\":%q,\"Time\":\"2024-01-01T00:00:00Z\"}\n", version)) + + zipPath := filepath.Join(base, version+".zip") + zipFile, err := os.Create(zipPath) + if err != nil { + t.Fatalf("create proxy zip: %v", err) + } + defer zipFile.Close() + + zw := zip.NewWriter(zipFile) + moduleRoot := modulePath + "@" + version + writeZipFile := func(name string, contents string) { + w, err := zw.Create(moduleRoot + "/" + filepath.ToSlash(name)) + if err != nil { + t.Fatalf("create zip entry %q: %v", name, err) + } + if _, err := w.Write([]byte(contents)); err != nil { + t.Fatalf("write zip entry %q: %v", name, err) + } + } + writeZipFile("go.mod", modFile) + for name, contents := range files { + writeZipFile(name, contents) + } + if err := zw.Close(); err != nil { + t.Fatalf("close proxy zip: %v", err) + } +} + +func appendLineIfMissing(t *testing.T, path string, line string) { + t.Helper() + existing, err := os.ReadFile(path) + if err != nil && !os.IsNotExist(err) { + t.Fatalf("read %q: %v", path, err) + } + for _, existingLine := range strings.Split(strings.TrimSpace(string(existing)), "\n") { + if existingLine == line { + return + } + } + content := string(existing) + if strings.TrimSpace(content) != "" && !strings.HasSuffix(content, "\n") { + content += "\n" + } + content += line + "\n" + writeTestFile(t, path, content) +} + +func tempCacheDirForTest(t *testing.T, pattern string) string { + t.Helper() + dir, err := os.MkdirTemp("", pattern) + if err != nil { + t.Fatalf("MkdirTemp(%q) error = %v", pattern, err) + } + t.Cleanup(func() { + _ = filepath.Walk(dir, func(path string, info os.FileInfo, err error) error { + if err != nil { + return nil + } + if info.IsDir() { + _ = os.Chmod(path, 0o755) + return nil + } + _ = os.Chmod(path, 0o644) + return nil + }) + _ = os.RemoveAll(dir) + }) + return dir +} + +func fileURLForTest(t *testing.T, path string) string { + t.Helper() + slashed := filepath.ToSlash(path) + if !strings.HasPrefix(slashed, "/") { + slashed = "/" + slashed + } + return "file://" + slashed +} + +type importerFuncForTest func(string) (*types.Package, error) + +func (f importerFuncForTest) Import(path string) (*types.Package, error) { + return f(path) +} + +func mustParseFile(t *testing.T, fset *token.FileSet, filename, src string) *ast.File { + t.Helper() + file, err := parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) + if err != nil { + t.Fatalf("ParseFile(%q) error = %v", filename, err) + } + return file +} + +func normalizePathForCompare(path string) string { + if path == "" { + return "" + } + if resolved, err := filepath.EvalSymlinks(path); err == nil && resolved != "" { + return filepath.ToSlash(filepath.Clean(resolved)) + } + return filepath.ToSlash(filepath.Clean(path)) +} + +func comparableErrors(errs []packages.Error) []string { + seen := make(map[string]struct{}, len(errs)) + out := make([]string, 0, len(errs)) + add := func(value string) { + if _, ok := seen[value]; ok { + return + } + seen[value] = struct{}{} + out = append(out, value) + } + for _, err := range errs { + if strings.HasPrefix(err.Msg, "# ") { + for _, value := range expandSummaryDiagnostics(err.Msg) { + add(value) + } + continue + } + pos := normalizeErrorPos(err.Pos) + add(pos + "|" + err.Msg) + } + sort.Strings(out) + return out +} + +func hasPrefixLabel(labels []string, prefix string) bool { + for _, label := range labels { + if strings.HasPrefix(label, prefix) { + return true + } + } + return false +} + +func containsPositiveIntLabel(labels []string, prefix string) bool { + for _, label := range labels { + if !strings.HasPrefix(label, prefix) { + continue + } + value := strings.TrimPrefix(label, prefix) + n, err := strconv.Atoi(value) + if err == nil && n > 0 { + return true + } + } + return false +} + +func normalizeErrorPos(pos string) string { + if pos == "" || pos == "-" { + return pos + } + last := strings.LastIndex(pos, ":") + if last == -1 { + return shortenComparablePath(normalizePathForCompare(pos)) + } + prev := strings.LastIndex(pos[:last], ":") + if prev == -1 { + return shortenComparablePath(normalizePathForCompare(pos)) + } + path := shortenComparablePath(normalizePathForCompare(pos[:prev])) + return path + pos[prev:] +} + +func expandSummaryDiagnostics(msg string) []string { + lines := strings.Split(msg, "\n") + out := make([]string, 0, len(lines)) + for _, line := range lines[1:] { + line = strings.TrimSpace(line) + if line == "" { + continue + } + if parts := strings.SplitN(line, ": ", 2); len(parts) == 2 { + pos := normalizeErrorPos(parts[0]) + out = append(out, pos+"|"+parts[1]) + continue + } + out = append(out, line) + } + return out +} + +func shortenComparablePath(path string) string { + path = filepath.Clean(path) + parts := strings.Split(path, string(filepath.Separator)) + if len(parts) >= 2 { + return filepath.Join(parts[len(parts)-2], parts[len(parts)-1]) + } + return path +} + +func writeTestFile(t *testing.T, path string, content string) { + t.Helper() + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + t.Fatalf("MkdirAll(%q) error = %v", path, err) + } + if err := os.WriteFile(path, []byte(content), 0o644); err != nil { + t.Fatalf("WriteFile(%q) error = %v", path, err) + } +} diff --git a/internal/loader/mode.go b/internal/loader/mode.go new file mode 100644 index 0000000..b08710b --- /dev/null +++ b/internal/loader/mode.go @@ -0,0 +1,38 @@ +// Copyright 2026 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package loader + +import "strings" + +const ModeEnvVar = "WIRE_LOADER_MODE" + +func ModeFromEnv(env []string) Mode { + mode := ModeAuto + for _, entry := range env { + name, value, ok := strings.Cut(entry, "=") + if !ok || name != ModeEnvVar { + continue + } + switch strings.ToLower(strings.TrimSpace(value)) { + case string(ModeCustom): + mode = ModeCustom + case string(ModeFallback): + mode = ModeFallback + case "", string(ModeAuto): + mode = ModeAuto + } + } + return mode +} diff --git a/internal/loader/timing.go b/internal/loader/timing.go new file mode 100644 index 0000000..1ae9ccd --- /dev/null +++ b/internal/loader/timing.go @@ -0,0 +1,54 @@ +package loader + +import ( + "context" + "fmt" + "time" +) + +type timingLogger func(string, time.Duration) + +type timingKey struct{} + +func WithTiming(ctx context.Context, logf func(string, time.Duration)) context.Context { + if logf == nil { + return ctx + } + return context.WithValue(ctx, timingKey{}, timingLogger(logf)) +} + +func timing(ctx context.Context) timingLogger { + if ctx == nil { + return nil + } + if v := ctx.Value(timingKey{}); v != nil { + if t, ok := v.(timingLogger); ok { + return t + } + } + return nil +} + +func logTiming(ctx context.Context, label string, start time.Time) { + if t := timing(ctx); t != nil { + t(label, time.Since(start)) + } +} + +func logDuration(ctx context.Context, label string, d time.Duration) { + if t := timing(ctx); t != nil { + t(label, d) + } +} + +func logInt(ctx context.Context, label string, v int) { + if t := timing(ctx); t != nil { + t(fmt.Sprintf("%s=%d", label, v), 0) + } +} + +func debugf(ctx context.Context, format string, args ...interface{}) { + if t := timing(ctx); t != nil { + t(fmt.Sprintf(format, args...), 0) + } +} diff --git a/internal/runtests.sh b/internal/runtests.sh index 28877c1..905e319 100755 --- a/internal/runtests.sh +++ b/internal/runtests.sh @@ -16,6 +16,14 @@ # https://coderwall.com/p/fkfaqq/safer-bash-scripts-with-set-euxo-pipefail set -euo pipefail +tmp_root="${TMPDIR:-${RUNNER_TEMP:-}}" +if [[ -z "${tmp_root}" ]]; then + tmp_root="$(mktemp -d)" +fi + +export GOCACHE="${GOCACHE:-${tmp_root}/gocache}" +export GOMODCACHE="${GOMODCACHE:-${tmp_root}/gomodcache}" + if [[ $# -gt 0 ]]; then echo "usage: runtests.sh" 1>&2 exit 64 @@ -34,7 +42,10 @@ fi echo echo "Ensuring .go files are formatted with gofmt -s..." -mapfile -t go_files < <(find . -name '*.go' -type f | grep -v testdata) +go_files=() +while IFS= read -r file; do + go_files+=("$file") +done < <(find . -name '*.go' -type f | grep -v testdata) DIFF="$(gofmt -s -d "${go_files[@]}")" if [ -n "$DIFF" ]; then echo "FAIL: please run gofmt -s and commit the result" diff --git a/internal/wire/cache_coverage_test.go b/internal/wire/cache_coverage_test.go deleted file mode 100644 index d65de26..0000000 --- a/internal/wire/cache_coverage_test.go +++ /dev/null @@ -1,1067 +0,0 @@ -// Copyright 2026 The Wire Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package wire - -import ( - "bytes" - "errors" - "io/fs" - "os" - "path/filepath" - "sort" - "sync" - "testing" - - "golang.org/x/tools/go/packages" -) - -type cacheHookState struct { - osCreateTemp func(string, string) (*os.File, error) - osMkdirAll func(string, os.FileMode) error - osReadFile func(string) ([]byte, error) - osRemove func(string) error - osRemoveAll func(string) error - osRename func(string, string) error - osStat func(string) (os.FileInfo, error) - osTempDir func() string - jsonMarshal func(any) ([]byte, error) - jsonUnmarshal func([]byte, any) error - extraCachePathsFunc func(string) []string - cacheKeyForPackage func(*packages.Package, *GenerateOptions) (string, error) - detectOutputDir func([]string) (string, error) - buildCacheFiles func([]string) ([]cacheFile, error) - buildCacheFilesFrom func([]cacheFile) ([]cacheFile, error) - rootPackageFiles func(*packages.Package) []string - hashFiles func([]string) (string, error) -} - -var cacheHooksMu sync.Mutex - -func lockCacheHooks(t *testing.T) { - t.Helper() - cacheHooksMu.Lock() - t.Cleanup(func() { - cacheHooksMu.Unlock() - }) -} - -func saveCacheHooks() cacheHookState { - return cacheHookState{ - osCreateTemp: osCreateTemp, - osMkdirAll: osMkdirAll, - osReadFile: osReadFile, - osRemove: osRemove, - osRemoveAll: osRemoveAll, - osRename: osRename, - osStat: osStat, - osTempDir: osTempDir, - jsonMarshal: jsonMarshal, - jsonUnmarshal: jsonUnmarshal, - extraCachePathsFunc: extraCachePathsFunc, - cacheKeyForPackage: cacheKeyForPackageFunc, - detectOutputDir: detectOutputDirFunc, - buildCacheFiles: buildCacheFilesFunc, - buildCacheFilesFrom: buildCacheFilesFromMetaFunc, - rootPackageFiles: rootPackageFilesFunc, - hashFiles: hashFilesFunc, - } -} - -func restoreCacheHooks(state cacheHookState) { - osCreateTemp = state.osCreateTemp - osMkdirAll = state.osMkdirAll - osReadFile = state.osReadFile - osRemove = state.osRemove - osRemoveAll = state.osRemoveAll - osRename = state.osRename - osStat = state.osStat - osTempDir = state.osTempDir - jsonMarshal = state.jsonMarshal - jsonUnmarshal = state.jsonUnmarshal - extraCachePathsFunc = state.extraCachePathsFunc - cacheKeyForPackageFunc = state.cacheKeyForPackage - detectOutputDirFunc = state.detectOutputDir - buildCacheFilesFunc = state.buildCacheFiles - buildCacheFilesFromMetaFunc = state.buildCacheFilesFrom - rootPackageFilesFunc = state.rootPackageFiles - hashFilesFunc = state.hashFiles -} - -func writeTempFile(t *testing.T, dir, name, content string) string { - t.Helper() - path := filepath.Join(dir, name) - if err := os.WriteFile(path, []byte(content), 0644); err != nil { - t.Fatalf("WriteFile(%s) failed: %v", path, err) - } - return path -} - -func cloneManifest(src *cacheManifest) *cacheManifest { - if src == nil { - return nil - } - dst := *src - if src.Patterns != nil { - dst.Patterns = append([]string(nil), src.Patterns...) - } - if src.ExtraFiles != nil { - dst.ExtraFiles = append([]cacheFile(nil), src.ExtraFiles...) - } - if src.Packages != nil { - dst.Packages = make([]manifestPackage, len(src.Packages)) - for i, pkg := range src.Packages { - dstPkg := pkg - if pkg.Files != nil { - dstPkg.Files = append([]cacheFile(nil), pkg.Files...) - } - if pkg.RootFiles != nil { - dstPkg.RootFiles = append([]cacheFile(nil), pkg.RootFiles...) - } - dst.Packages[i] = dstPkg - } - } - return &dst -} - -func TestCacheStoreReadWrite(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - tempDir := t.TempDir() - osTempDir = func() string { return tempDir } - - if got := CacheDir(); got == "" { - t.Fatal("expected CacheDir to return a value") - } - - key := "cache-store" - want := []byte("content") - writeCache(key, want) - - got, ok := readCache(key) - if !ok { - t.Fatal("expected cache hit") - } - if !bytes.Equal(got, want) { - t.Fatalf("cache content mismatch: got %q, want %q", got, want) - } - if err := ClearCache(); err != nil { - t.Fatalf("ClearCache failed: %v", err) - } - if _, ok := readCache(key); ok { - t.Fatal("expected cache miss after clear") - } -} - -func TestCacheStoreReadError(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - osReadFile = func(string) ([]byte, error) { - return nil, errors.New("boom") - } - if _, ok := readCache("missing"); ok { - t.Fatal("expected cache miss on read error") - } -} - -func TestCacheStoreWriteErrors(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - tempDir := t.TempDir() - osTempDir = func() string { return tempDir } - - t.Run("mkdir", func(t *testing.T) { - osMkdirAll = func(string, os.FileMode) error { return errors.New("mkdir") } - writeCache("mkdir", []byte("data")) - }) - - t.Run("create", func(t *testing.T) { - restoreCacheHooks(state) - osTempDir = func() string { return tempDir } - osCreateTemp = func(string, string) (*os.File, error) { - return nil, errors.New("create") - } - writeCache("create", []byte("data")) - }) - - t.Run("write", func(t *testing.T) { - restoreCacheHooks(state) - osTempDir = func() string { return tempDir } - osCreateTemp = func(dir, pattern string) (*os.File, error) { - tmp, err := os.CreateTemp(dir, pattern) - if err != nil { - return nil, err - } - name := tmp.Name() - if err := tmp.Close(); err != nil { - return nil, err - } - return os.Open(name) - } - writeCache("write", []byte("data")) - }) - - t.Run("rename-exist", func(t *testing.T) { - restoreCacheHooks(state) - osTempDir = func() string { return tempDir } - osRename = func(string, string) error { - return fs.ErrExist - } - writeCache("exist", []byte("data")) - }) - - t.Run("rename", func(t *testing.T) { - restoreCacheHooks(state) - osTempDir = func() string { return tempDir } - osRename = func(string, string) error { - return errors.New("rename") - } - writeCache("rename", []byte("data")) - }) -} - -func TestCacheDirError(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - osRemoveAll = func(string) error { return errors.New("remove") } - if err := ClearCache(); err == nil { - t.Fatal("expected ClearCache error") - } -} - -func TestPackageFiles(t *testing.T) { - tempDir := t.TempDir() - rootFile := writeTempFile(t, tempDir, "root.go", "package root\n") - childFile := writeTempFile(t, tempDir, "child.go", "package child\n") - - child := &packages.Package{ - PkgPath: "example.com/child", - CompiledGoFiles: []string{childFile}, - } - root := &packages.Package{ - PkgPath: "example.com/root", - GoFiles: []string{rootFile}, - Imports: map[string]*packages.Package{ - "child": child, - "dup": child, - "nil": nil, - }, - } - got := packageFiles(root) - sort.Strings(got) - if len(got) != 2 { - t.Fatalf("expected 2 files, got %d", len(got)) - } - if got[0] != childFile || got[1] != rootFile { - t.Fatalf("unexpected files: %v", got) - } -} - -func TestCacheKeyEmptyPackage(t *testing.T) { - key, err := cacheKeyForPackage(&packages.Package{PkgPath: "example.com/empty"}, &GenerateOptions{}) - if err != nil { - t.Fatalf("cacheKeyForPackage error: %v", err) - } - if key != "" { - t.Fatalf("expected empty cache key, got %q", key) - } -} - -func TestCacheKeyMetaHit(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - tempDir := t.TempDir() - osTempDir = func() string { return tempDir } - - file := writeTempFile(t, tempDir, "hit.go", "package hit\n") - pkg := &packages.Package{ - PkgPath: "example.com/hit", - GoFiles: []string{file}, - } - opts := &GenerateOptions{} - files := packageFiles(pkg) - sort.Strings(files) - contentHash, err := contentHashForFiles(pkg, opts, files) - if err != nil { - t.Fatalf("contentHashForFiles error: %v", err) - } - rootFiles := rootPackageFiles(pkg) - sort.Strings(rootFiles) - rootHash, err := hashFiles(rootFiles) - if err != nil { - t.Fatalf("hashFiles error: %v", err) - } - metaFiles, err := buildCacheFiles(files) - if err != nil { - t.Fatalf("buildCacheFiles error: %v", err) - } - meta := &cacheMeta{ - Version: cacheVersion, - PkgPath: pkg.PkgPath, - Tags: opts.Tags, - Prefix: opts.PrefixOutputFile, - HeaderHash: headerHash(opts.Header), - Files: metaFiles, - ContentHash: contentHash, - RootHash: rootHash, - } - metaKey := cacheMetaKey(pkg, opts) - writeCacheMeta(metaKey, meta) - - got, err := cacheKeyForPackage(pkg, opts) - if err != nil { - t.Fatalf("cacheKeyForPackage error: %v", err) - } - if got != contentHash { - t.Fatalf("cache key mismatch: got %q, want %q", got, contentHash) - } -} - -func TestCacheKeyErrorPaths(t *testing.T) { - pkg := &packages.Package{ - PkgPath: "example.com/missing", - GoFiles: []string{filepath.Join(t.TempDir(), "missing.go")}, - } - if _, err := cacheKeyForPackage(pkg, &GenerateOptions{}); err == nil { - t.Fatal("expected cacheKeyForPackage error") - } - if _, err := buildCacheFiles([]string{filepath.Join(t.TempDir(), "missing.go")}); err == nil { - t.Fatal("expected buildCacheFiles error") - } - if _, err := contentHashForPaths("example.com/missing", &GenerateOptions{}, []string{filepath.Join(t.TempDir(), "missing.go")}); err == nil { - t.Fatal("expected contentHashForPaths error") - } - if _, err := hashFiles([]string{filepath.Join(t.TempDir(), "missing.go")}); err == nil { - t.Fatal("expected hashFiles error") - } - if got, err := hashFiles(nil); err != nil || got != "" { - t.Fatalf("expected empty hashFiles result, got %q err=%v", got, err) - } -} - -func TestCacheMetaMatches(t *testing.T) { - tempDir := t.TempDir() - file := writeTempFile(t, tempDir, "meta.go", "package meta\n") - pkg := &packages.Package{ - PkgPath: "example.com/meta", - GoFiles: []string{file}, - } - opts := &GenerateOptions{} - files := packageFiles(pkg) - sort.Strings(files) - metaFiles, err := buildCacheFiles(files) - if err != nil { - t.Fatalf("buildCacheFiles error: %v", err) - } - rootFiles := rootPackageFiles(pkg) - sort.Strings(rootFiles) - rootHash, err := hashFiles(rootFiles) - if err != nil { - t.Fatalf("hashFiles error: %v", err) - } - contentHash, err := contentHashForFiles(pkg, opts, files) - if err != nil { - t.Fatalf("contentHashForFiles error: %v", err) - } - meta := &cacheMeta{ - Version: cacheVersion, - PkgPath: pkg.PkgPath, - Tags: opts.Tags, - Prefix: opts.PrefixOutputFile, - HeaderHash: headerHash(opts.Header), - Files: metaFiles, - ContentHash: contentHash, - RootHash: rootHash, - } - if !cacheMetaMatches(meta, pkg, opts, files) { - t.Fatal("expected cacheMetaMatches to succeed") - } - badVersion := *meta - badVersion.Version = "nope" - if cacheMetaMatches(&badVersion, pkg, opts, files) { - t.Fatal("expected version mismatch") - } - badPkg := *meta - badPkg.PkgPath = "example.com/other" - if cacheMetaMatches(&badPkg, pkg, opts, files) { - t.Fatal("expected pkg mismatch") - } - badHeader := *meta - badHeader.HeaderHash = "bad" - if cacheMetaMatches(&badHeader, pkg, opts, files) { - t.Fatal("expected header mismatch") - } - shortFiles := *meta - shortFiles.Files = nil - if cacheMetaMatches(&shortFiles, pkg, opts, files) { - t.Fatal("expected file count mismatch") - } - fileMismatch := *meta - fileMismatch.Files = append([]cacheFile(nil), meta.Files...) - fileMismatch.Files[0].Size++ - if cacheMetaMatches(&fileMismatch, pkg, opts, files) { - t.Fatal("expected file metadata mismatch") - } - pkgNoRoot := &packages.Package{PkgPath: pkg.PkgPath} - if cacheMetaMatches(meta, pkgNoRoot, opts, files) { - t.Fatal("expected missing root files") - } - noRootHash := *meta - noRootHash.RootHash = "" - if cacheMetaMatches(&noRootHash, pkg, opts, files) { - t.Fatal("expected empty root hash mismatch") - } - missingRootPkg := &packages.Package{ - PkgPath: "example.com/meta", - GoFiles: []string{filepath.Join(tempDir, "missing.go")}, - } - if cacheMetaMatches(meta, missingRootPkg, opts, files) { - t.Fatal("expected root hash error") - } - badRoot := *meta - badRoot.RootHash = "bad" - if cacheMetaMatches(&badRoot, pkg, opts, files) { - t.Fatal("expected root hash mismatch") - } - emptyContent := *meta - emptyContent.ContentHash = "" - if cacheMetaMatches(&emptyContent, pkg, opts, files) { - t.Fatal("expected empty content hash mismatch") - } - - if cacheMetaMatches(meta, pkg, opts, []string{filepath.Join(tempDir, "missing.go")}) { - t.Fatal("expected buildCacheFiles error") - } -} - -func TestCacheMetaReadWriteErrors(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - tempDir := t.TempDir() - osTempDir = func() string { return tempDir } - - if _, ok := readCacheMeta("missing"); ok { - t.Fatal("expected cache meta miss") - } - - osReadFile = func(string) ([]byte, error) { - return []byte("{bad json"), nil - } - if _, ok := readCacheMeta("bad-json"); ok { - t.Fatal("expected cache meta miss on invalid json") - } - - restoreCacheHooks(state) - osTempDir = func() string { return tempDir } - osMkdirAll = func(string, os.FileMode) error { return errors.New("mkdir") } - writeCacheMeta("mkdir", &cacheMeta{}) - - restoreCacheHooks(state) - osTempDir = func() string { return tempDir } - jsonMarshal = func(any) ([]byte, error) { return nil, errors.New("marshal") } - writeCacheMeta("marshal", &cacheMeta{}) - - restoreCacheHooks(state) - osTempDir = func() string { return tempDir } - osCreateTemp = func(string, string) (*os.File, error) { return nil, errors.New("create") } - writeCacheMeta("create", &cacheMeta{}) - - restoreCacheHooks(state) - osTempDir = func() string { return tempDir } - osCreateTemp = func(dir, pattern string) (*os.File, error) { - tmp, err := os.CreateTemp(dir, pattern) - if err != nil { - return nil, err - } - name := tmp.Name() - if err := tmp.Close(); err != nil { - return nil, err - } - return os.Open(name) - } - writeCacheMeta("write", &cacheMeta{}) - - restoreCacheHooks(state) - osTempDir = func() string { return tempDir } - osRename = func(string, string) error { return errors.New("rename") } - writeCacheMeta("rename", &cacheMeta{}) -} - -func TestManifestReadWriteErrors(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - tempDir := t.TempDir() - osTempDir = func() string { return tempDir } - - if _, ok := readManifest("missing"); ok { - t.Fatal("expected manifest miss") - } - - osReadFile = func(string) ([]byte, error) { - return []byte("{bad json"), nil - } - if _, ok := readManifest("bad-json"); ok { - t.Fatal("expected manifest miss on invalid json") - } - - restoreCacheHooks(state) - osTempDir = func() string { return tempDir } - osMkdirAll = func(string, os.FileMode) error { return errors.New("mkdir") } - writeManifestFile("mkdir", &cacheManifest{}) - - restoreCacheHooks(state) - osTempDir = func() string { return tempDir } - jsonMarshal = func(any) ([]byte, error) { return nil, errors.New("marshal") } - writeManifestFile("marshal", &cacheManifest{}) - - restoreCacheHooks(state) - osTempDir = func() string { return tempDir } - osCreateTemp = func(string, string) (*os.File, error) { return nil, errors.New("create") } - writeManifestFile("create", &cacheManifest{}) - - restoreCacheHooks(state) - osTempDir = func() string { return tempDir } - osCreateTemp = func(dir, pattern string) (*os.File, error) { - tmp, err := os.CreateTemp(dir, pattern) - if err != nil { - return nil, err - } - name := tmp.Name() - if err := tmp.Close(); err != nil { - return nil, err - } - return os.Open(name) - } - writeManifestFile("write", &cacheManifest{}) - - restoreCacheHooks(state) - osTempDir = func() string { return tempDir } - osRename = func(string, string) error { return errors.New("rename") } - writeManifestFile("rename", &cacheManifest{}) -} - -func TestManifestKeyHelpers(t *testing.T) { - if got := manifestKeyFromManifest(nil); got != "" { - t.Fatalf("expected empty manifest key, got %q", got) - } - env := []string{"A=B"} - opts := &GenerateOptions{ - Tags: "tags", - PrefixOutputFile: "prefix", - Header: []byte("header"), - } - manifest := &cacheManifest{ - WD: t.TempDir(), - EnvHash: envHash(env), - Tags: opts.Tags, - Prefix: opts.PrefixOutputFile, - HeaderHash: headerHash(opts.Header), - Patterns: []string{"./a", "./b"}, - } - got := manifestKeyFromManifest(manifest) - want := manifestKey(manifest.WD, env, manifest.Patterns, opts) - if got != want { - t.Fatalf("manifest key mismatch: got %q, want %q", got, want) - } -} - -func TestReadManifestResultsPaths(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - tempDir := t.TempDir() - osTempDir = func() string { return tempDir } - - wd := t.TempDir() - env := []string{"A=B"} - patterns := []string{"./..."} - opts := &GenerateOptions{} - - if _, ok := readManifestResults(wd, env, patterns, opts); ok { - t.Fatal("expected no manifest") - } - - key := manifestKey(wd, env, patterns, opts) - invalid := &cacheManifest{Version: cacheVersion, WD: wd, EnvHash: "", Packages: nil} - writeManifestFile(key, invalid) - if _, ok := readManifestResults(wd, env, patterns, opts); ok { - t.Fatal("expected invalid manifest miss") - } - - file := writeTempFile(t, wd, "wire.go", "package app\n") - pkg := &packages.Package{ - PkgPath: "example.com/app", - GoFiles: []string{file}, - } - files := packageFiles(pkg) - sort.Strings(files) - contentHash, err := contentHashForFiles(pkg, opts, files) - if err != nil { - t.Fatalf("contentHashForFiles error: %v", err) - } - metaFiles, err := buildCacheFiles(files) - if err != nil { - t.Fatalf("buildCacheFiles error: %v", err) - } - rootFiles := rootPackageFiles(pkg) - sort.Strings(rootFiles) - rootMeta, err := buildCacheFiles(rootFiles) - if err != nil { - t.Fatalf("buildCacheFiles error: %v", err) - } - rootHash, err := hashFiles(rootFiles) - if err != nil { - t.Fatalf("hashFiles error: %v", err) - } - valid := &cacheManifest{ - Version: cacheVersion, - WD: wd, - Tags: opts.Tags, - Prefix: opts.PrefixOutputFile, - HeaderHash: headerHash(opts.Header), - EnvHash: envHash(env), - Patterns: sortedStrings(patterns), - Packages: []manifestPackage{ - { - PkgPath: pkg.PkgPath, - OutputPath: filepath.Join(wd, "wire_gen.go"), - Files: metaFiles, - ContentHash: contentHash, - RootFiles: rootMeta, - RootHash: rootHash, - }, - }, - } - writeManifestFile(key, valid) - if _, ok := readManifestResults(wd, env, patterns, opts); ok { - t.Fatal("expected cache miss without content") - } - writeCache(contentHash, []byte("wire")) - if results, ok := readManifestResults(wd, env, patterns, opts); !ok || len(results) != 1 { - t.Fatalf("expected manifest cache hit, got ok=%v results=%d", ok, len(results)) - } -} - -func TestWriteManifestBranches(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - tempDir := t.TempDir() - osTempDir = func() string { return tempDir } - - wd := t.TempDir() - env := []string{"A=B"} - patterns := []string{"./..."} - opts := &GenerateOptions{} - - writeManifest(wd, env, patterns, opts, nil) - - writeManifest(wd, env, patterns, opts, []*packages.Package{nil}) - - writeManifest(wd, env, patterns, opts, []*packages.Package{{PkgPath: "example.com/empty"}}) - - missingFilePkg := &packages.Package{ - PkgPath: "example.com/missing", - GoFiles: []string{filepath.Join(wd, "missing.go")}, - } - writeManifest(wd, env, patterns, opts, []*packages.Package{missingFilePkg}) - - conflictDir := t.TempDir() - fileA := writeTempFile(t, conflictDir, "a.go", "package a\n") - fileB := writeTempFile(t, t.TempDir(), "b.go", "package b\n") - conflictPkg := &packages.Package{ - PkgPath: "example.com/conflict", - GoFiles: []string{fileA, fileB}, - } - writeManifest(wd, env, patterns, opts, []*packages.Package{conflictPkg}) - - okFile := writeTempFile(t, wd, "ok.go", "package ok\n") - okPkg := &packages.Package{ - PkgPath: "example.com/ok", - GoFiles: []string{okFile}, - } - cacheKeyForPackageFunc = func(*packages.Package, *GenerateOptions) (string, error) { - return "", errors.New("cache key") - } - writeManifest(wd, env, patterns, opts, []*packages.Package{okPkg}) - - cacheKeyForPackageFunc = func(*packages.Package, *GenerateOptions) (string, error) { - return "", nil - } - writeManifest(wd, env, patterns, opts, []*packages.Package{okPkg}) - - cacheKeyForPackageFunc = func(*packages.Package, *GenerateOptions) (string, error) { - return "hash", nil - } - detectOutputDirFunc = func([]string) (string, error) { - return "", errors.New("output") - } - writeManifest(wd, env, patterns, opts, []*packages.Package{okPkg}) - - detectOutputDirFunc = state.detectOutputDir - buildCacheFilesFunc = func([]string) ([]cacheFile, error) { - return nil, errors.New("build") - } - writeManifest(wd, env, patterns, opts, []*packages.Package{okPkg}) - - call := 0 - buildCacheFilesFunc = func([]string) ([]cacheFile, error) { - call++ - if call > 1 { - return nil, errors.New("root") - } - return []cacheFile{{Path: okFile}}, nil - } - rootPackageFilesFunc = func(*packages.Package) []string { - return []string{okFile} - } - writeManifest(wd, env, patterns, opts, []*packages.Package{okPkg}) - - buildCacheFilesFunc = state.buildCacheFiles - hashFilesFunc = func([]string) (string, error) { - return "", errors.New("hash") - } - writeManifest(wd, env, patterns, opts, []*packages.Package{okPkg}) - - restoreCacheHooks(state) - statCalls := 0 - osStat = func(name string) (os.FileInfo, error) { - statCalls++ - if statCalls > 3 { - return nil, errors.New("stat") - } - return state.osStat(name) - } - writeManifest(wd, env, patterns, opts, []*packages.Package{okPkg}) - - restoreCacheHooks(state) - osTempDir = func() string { return tempDir } - readCalls := 0 - osReadFile = func(name string) ([]byte, error) { - readCalls++ - if readCalls > 2 { - return nil, errors.New("read") - } - return state.osReadFile(name) - } - writeManifest(wd, env, patterns, opts, []*packages.Package{okPkg}) -} - -func TestManifestValidationAndExtras(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - if manifestValid(nil) { - t.Fatal("expected nil manifest invalid") - } - if manifestValid(&cacheManifest{Version: "bad"}) { - t.Fatal("expected version mismatch") - } - if manifestValid(&cacheManifest{Version: cacheVersion}) { - t.Fatal("expected missing env hash") - } - - tempDir := t.TempDir() - file := writeTempFile(t, tempDir, "valid.go", "package valid\n") - files, err := buildCacheFiles([]string{file}) - if err != nil { - t.Fatalf("buildCacheFiles error: %v", err) - } - rootHash, err := hashFiles([]string{file}) - if err != nil { - t.Fatalf("hashFiles error: %v", err) - } - valid := &cacheManifest{ - Version: cacheVersion, - WD: tempDir, - EnvHash: "env", - Packages: []manifestPackage{{PkgPath: "example.com/valid", Files: files, RootFiles: files, ContentHash: "hash", RootHash: rootHash}}, - ExtraFiles: nil, - } - if !manifestValid(valid) { - t.Fatal("expected valid manifest") - } - - invalidExtra := cloneManifest(valid) - invalidExtra.ExtraFiles = []cacheFile{{Path: filepath.Join(tempDir, "missing.go")}} - if manifestValid(invalidExtra) { - t.Fatal("expected invalid extra files") - } - - extraMismatch := cloneManifest(valid) - extraMismatch.ExtraFiles = []cacheFile{files[0]} - extraMismatch.ExtraFiles[0].Size++ - if manifestValid(extraMismatch) { - t.Fatal("expected extra file metadata mismatch") - } - - invalidPkg := cloneManifest(valid) - invalidPkg.Packages[0].ContentHash = "" - if manifestValid(invalidPkg) { - t.Fatal("expected invalid content hash") - } - - invalidRoot := cloneManifest(valid) - invalidRoot.Packages[0].RootHash = "" - if manifestValid(invalidRoot) { - t.Fatal("expected invalid root hash") - } - - invalidFiles := cloneManifest(valid) - invalidFiles.Packages[0].Files = []cacheFile{{Path: filepath.Join(tempDir, "missing.go")}} - if manifestValid(invalidFiles) { - t.Fatal("expected invalid package files") - } - - fileMismatch := cloneManifest(valid) - fileMismatch.Packages[0].Files = []cacheFile{files[0]} - fileMismatch.Packages[0].Files[0].Size++ - if manifestValid(fileMismatch) { - t.Fatal("expected package file mismatch") - } - - invalidRootFiles := cloneManifest(valid) - invalidRootFiles.Packages[0].RootFiles = []cacheFile{{Path: filepath.Join(tempDir, "missing.go")}} - if manifestValid(invalidRootFiles) { - t.Fatal("expected invalid root files") - } - - rootMismatch := cloneManifest(valid) - rootMismatch.Packages[0].RootFiles = []cacheFile{files[0]} - rootMismatch.Packages[0].RootFiles[0].Size++ - if manifestValid(rootMismatch) { - t.Fatal("expected root file mismatch") - } - - emptyRoot := cloneManifest(valid) - emptyRoot.Packages[0].RootFiles = nil - if manifestValid(emptyRoot) { - t.Fatal("expected empty root files") - } - - badHash := cloneManifest(valid) - badHash.Packages[0].RootHash = "bad" - if manifestValid(badHash) { - t.Fatal("expected root hash mismatch") - } - - if _, err := buildCacheFilesFromMeta([]cacheFile{{Path: filepath.Join(tempDir, "missing.go")}}); err == nil { - t.Fatal("expected buildCacheFilesFromMeta error") - } - - extraCachePathsFunc = func(string) []string { - return []string{file, file, filepath.Join(tempDir, "missing.go")} - } - extras := extraCacheFiles(tempDir) - if len(extras) != 1 { - t.Fatalf("expected 1 extra file, got %d", len(extras)) - } - - extraCachePathsFunc = func(string) []string { return nil } - if extras := extraCacheFiles(tempDir); extras != nil { - t.Fatal("expected nil extras") - } - - extraCachePathsFunc = func(string) []string { return []string{file, writeTempFile(t, tempDir, "go.sum", "sum\n")} } - if extras := extraCacheFiles(tempDir); len(extras) < 2 { - t.Fatalf("expected extras to include two files, got %v", extras) - } -} - -func TestExtraCachePaths(t *testing.T) { - tempDir := t.TempDir() - rootMod := writeTempFile(t, tempDir, "go.mod", "module example.com/root\n") - writeTempFile(t, tempDir, "go.sum", "sum\n") - nested := filepath.Join(tempDir, "nested", "dir") - if err := os.MkdirAll(nested, 0755); err != nil { - t.Fatalf("MkdirAll failed: %v", err) - } - paths := extraCachePaths(nested) - if len(paths) < 2 { - t.Fatalf("expected extra cache paths, got %v", paths) - } - found := false - for _, path := range paths { - if path == rootMod { - found = true - break - } - } - if !found { - t.Fatalf("expected %s in paths: %v", rootMod, paths) - } - if got := sortedStrings(nil); got != nil { - t.Fatal("expected nil for empty sortedStrings") - } - if got := envHash(nil); got != "" { - t.Fatal("expected empty env hash") - } -} - -func TestRootPackageFiles(t *testing.T) { - if rootPackageFiles(nil) != nil { - t.Fatal("expected nil root files for nil package") - } - tempDir := t.TempDir() - compiled := writeTempFile(t, tempDir, "compiled.go", "package compiled\n") - pkg := &packages.Package{ - PkgPath: "example.com/compiled", - CompiledGoFiles: []string{compiled}, - } - got := rootPackageFiles(pkg) - if len(got) != 1 || got[0] != compiled { - t.Fatalf("unexpected compiled files: %v", got) - } -} - -func TestAddExtraCachePath(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - tempDir := t.TempDir() - file := writeTempFile(t, tempDir, "go.mod", "module example.com\n") - var paths []string - seen := make(map[string]struct{}) - addExtraCachePath(&paths, seen, file) - addExtraCachePath(&paths, seen, file) - if len(paths) != 1 { - t.Fatalf("expected 1 path, got %d", len(paths)) - } - addExtraCachePath(&paths, seen, filepath.Join(tempDir, "missing.go")) - if len(paths) != 1 { - t.Fatalf("unexpected extra path append: %v", paths) - } -} - -func TestManifestValidHookBranches(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - tempDir := t.TempDir() - file := writeTempFile(t, tempDir, "hook.go", "package hook\n") - files, err := buildCacheFiles([]string{file}) - if err != nil { - t.Fatalf("buildCacheFiles error: %v", err) - } - rootHash, err := hashFiles([]string{file}) - if err != nil { - t.Fatalf("hashFiles error: %v", err) - } - base := &cacheManifest{ - Version: cacheVersion, - WD: tempDir, - EnvHash: "env", - Packages: []manifestPackage{{PkgPath: "example.com/hook", Files: files, RootFiles: files, ContentHash: "hash", RootHash: rootHash}}, - ExtraFiles: []cacheFile{files[0]}, - } - - buildCacheFilesFromMetaFunc = func(in []cacheFile) ([]cacheFile, error) { - if len(in) == 1 && in[0].Path == files[0].Path { - return []cacheFile{}, nil - } - return buildCacheFilesFromMeta(in) - } - if manifestValid(base) { - t.Fatal("expected extra file length mismatch") - } - - restoreCacheHooks(state) - emptyRoot := cloneManifest(base) - emptyRoot.Packages[0].RootFiles = nil - if manifestValid(emptyRoot) { - t.Fatal("expected empty root files") - } - - restoreCacheHooks(state) - buildCacheFilesFromMetaFunc = func(in []cacheFile) ([]cacheFile, error) { - if len(in) == 1 && in[0].Path == file { - return nil, errors.New("pkg files") - } - return buildCacheFilesFromMeta(in) - } - noExtra := cloneManifest(base) - noExtra.ExtraFiles = nil - if manifestValid(noExtra) { - t.Fatal("expected pkg files error") - } - - restoreCacheHooks(state) - buildCacheFilesFromMetaFunc = func(in []cacheFile) ([]cacheFile, error) { - if len(in) == 1 && in[0].Path == file { - return []cacheFile{}, nil - } - return buildCacheFilesFromMeta(in) - } - if manifestValid(noExtra) { - t.Fatal("expected pkg files length mismatch") - } - - restoreCacheHooks(state) - call := 0 - buildCacheFilesFromMetaFunc = func(in []cacheFile) ([]cacheFile, error) { - call++ - if call == 2 { - return nil, errors.New("root files") - } - return buildCacheFilesFromMeta(in) - } - if manifestValid(noExtra) { - t.Fatal("expected root files error") - } - - restoreCacheHooks(state) - call = 0 - buildCacheFilesFromMetaFunc = func(in []cacheFile) ([]cacheFile, error) { - call++ - if call == 2 { - return []cacheFile{}, nil - } - return buildCacheFilesFromMeta(in) - } - if manifestValid(noExtra) { - t.Fatal("expected root files length mismatch") - } - - restoreCacheHooks(state) - call = 0 - buildCacheFilesFromMetaFunc = func(in []cacheFile) ([]cacheFile, error) { - call++ - if call == 2 { - return []cacheFile{{Path: file, Size: files[0].Size + 1}}, nil - } - return buildCacheFilesFromMeta(in) - } - if manifestValid(noExtra) { - t.Fatal("expected root files mismatch") - } -} diff --git a/internal/wire/cache_generate_test.go b/internal/wire/cache_generate_test.go deleted file mode 100644 index d009f73..0000000 --- a/internal/wire/cache_generate_test.go +++ /dev/null @@ -1,100 +0,0 @@ -// Copyright 2026 The Wire Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package wire - -import ( - "context" - "os" - "path/filepath" - "sort" - "testing" - - "golang.org/x/tools/go/packages" -) - -func TestGenerateUsesManifestCache(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - tempDir := t.TempDir() - osTempDir = func() string { return tempDir } - - wd := t.TempDir() - file := filepath.Join(wd, "provider.go") - if err := os.WriteFile(file, []byte("package p\n"), 0644); err != nil { - t.Fatalf("WriteFile failed: %v", err) - } - - env := []string{"A=B"} - patterns := []string{"./..."} - opts := &GenerateOptions{} - key := manifestKey(wd, env, patterns, opts) - - pkg := &packages.Package{ - PkgPath: "example.com/p", - GoFiles: []string{file}, - } - files := packageFiles(pkg) - sort.Strings(files) - contentHash, err := contentHashForFiles(pkg, opts, files) - if err != nil { - t.Fatalf("contentHashForFiles error: %v", err) - } - metaFiles, err := buildCacheFiles(files) - if err != nil { - t.Fatalf("buildCacheFiles error: %v", err) - } - rootFiles := rootPackageFiles(pkg) - sort.Strings(rootFiles) - rootMeta, err := buildCacheFiles(rootFiles) - if err != nil { - t.Fatalf("buildCacheFiles root error: %v", err) - } - rootHash, err := hashFiles(rootFiles) - if err != nil { - t.Fatalf("hashFiles error: %v", err) - } - - manifest := &cacheManifest{ - Version: cacheVersion, - WD: wd, - Tags: opts.Tags, - Prefix: opts.PrefixOutputFile, - HeaderHash: headerHash(opts.Header), - EnvHash: envHash(env), - Patterns: sortedStrings(patterns), - Packages: []manifestPackage{ - { - PkgPath: pkg.PkgPath, - OutputPath: filepath.Join(wd, "wire_gen.go"), - Files: metaFiles, - ContentHash: contentHash, - RootFiles: rootMeta, - RootHash: rootHash, - }, - }, - } - writeManifestFile(key, manifest) - writeCache(contentHash, []byte("wire")) - - results, errs := Generate(context.Background(), wd, env, patterns, opts) - if len(errs) > 0 { - t.Fatalf("Generate returned errors: %v", errs) - } - if len(results) != 1 || string(results[0].Content) != "wire" { - t.Fatalf("unexpected cached results: %+v", results) - } -} diff --git a/internal/wire/cache_key.go b/internal/wire/cache_key.go deleted file mode 100644 index 2aa8881..0000000 --- a/internal/wire/cache_key.go +++ /dev/null @@ -1,298 +0,0 @@ -// Copyright 2018 The Wire Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package wire - -import ( - "crypto/sha256" - "fmt" - "path/filepath" - "sort" - - "golang.org/x/tools/go/packages" -) - -// cacheVersion is the schema/version identifier for cache entries. -const cacheVersion = "wire-cache-v3" - -// cacheFile captures file metadata used to validate cached content. -type cacheFile struct { - Path string `json:"path"` - Size int64 `json:"size"` - ModTime int64 `json:"mod_time"` -} - -// cacheMeta tracks inputs and outputs for a single package cache entry. -type cacheMeta struct { - Version string `json:"version"` - PkgPath string `json:"pkg_path"` - Tags string `json:"tags"` - Prefix string `json:"prefix"` - HeaderHash string `json:"header_hash"` - Files []cacheFile `json:"files"` - ContentHash string `json:"content_hash"` - RootHash string `json:"root_hash"` -} - -// cacheKeyForPackage returns the content hash for a package, if cacheable. -func cacheKeyForPackage(pkg *packages.Package, opts *GenerateOptions) (string, error) { - files := packageFiles(pkg) - if len(files) == 0 { - return "", nil - } - sort.Strings(files) - metaKey := cacheMetaKey(pkg, opts) - if meta, ok := readCacheMeta(metaKey); ok { - if cacheMetaMatches(meta, pkg, opts, files) { - return meta.ContentHash, nil - } - } - contentHash, err := contentHashForFiles(pkg, opts, files) - if err != nil { - return "", err - } - rootFiles := rootPackageFiles(pkg) - sort.Strings(rootFiles) - rootHash, err := hashFiles(rootFiles) - if err != nil { - return "", err - } - metaFiles, err := buildCacheFiles(files) - if err != nil { - return "", err - } - meta := &cacheMeta{ - Version: cacheVersion, - PkgPath: pkg.PkgPath, - Tags: opts.Tags, - Prefix: opts.PrefixOutputFile, - HeaderHash: headerHash(opts.Header), - Files: metaFiles, - ContentHash: contentHash, - RootHash: rootHash, - } - writeCacheMeta(metaKey, meta) - return contentHash, nil -} - -// packageFiles returns the transitive Go files for a package graph. -func packageFiles(root *packages.Package) []string { - seen := make(map[string]struct{}) - var files []string - stack := []*packages.Package{root} - for len(stack) > 0 { - p := stack[len(stack)-1] - stack = stack[:len(stack)-1] - if p == nil { - continue - } - if _, ok := seen[p.PkgPath]; ok { - continue - } - seen[p.PkgPath] = struct{}{} - if len(p.CompiledGoFiles) > 0 { - files = append(files, p.CompiledGoFiles...) - } else if len(p.GoFiles) > 0 { - files = append(files, p.GoFiles...) - } - for _, imp := range p.Imports { - stack = append(stack, imp) - } - } - return files -} - -// cacheMetaKey builds the key for a package's cache metadata entry. -func cacheMetaKey(pkg *packages.Package, opts *GenerateOptions) string { - h := sha256.New() - h.Write([]byte(cacheVersion)) - h.Write([]byte{0}) - h.Write([]byte(pkg.PkgPath)) - h.Write([]byte{0}) - h.Write([]byte(opts.Tags)) - h.Write([]byte{0}) - h.Write([]byte(opts.PrefixOutputFile)) - h.Write([]byte{0}) - h.Write([]byte(headerHash(opts.Header))) - return fmt.Sprintf("%x", h.Sum(nil)) -} - -// cacheMetaPath returns the on-disk path for a cache metadata key. -func cacheMetaPath(key string) string { - return filepath.Join(cacheDir(), key+".json") -} - -// readCacheMeta loads a cached metadata entry if it exists. -func readCacheMeta(key string) (*cacheMeta, bool) { - data, err := osReadFile(cacheMetaPath(key)) - if err != nil { - return nil, false - } - var meta cacheMeta - if err := jsonUnmarshal(data, &meta); err != nil { - return nil, false - } - return &meta, true -} - -// writeCacheMeta persists cache metadata to disk. -func writeCacheMeta(key string, meta *cacheMeta) { - dir := cacheDir() - if err := osMkdirAll(dir, 0755); err != nil { - return - } - data, err := jsonMarshal(meta) - if err != nil { - return - } - tmp, err := osCreateTemp(dir, key+".meta-") - if err != nil { - return - } - _, writeErr := tmp.Write(data) - closeErr := tmp.Close() - if writeErr != nil || closeErr != nil { - osRemove(tmp.Name()) - return - } - path := cacheMetaPath(key) - if err := osRename(tmp.Name(), path); err != nil { - osRemove(tmp.Name()) - } -} - -// cacheMetaMatches reports whether metadata matches the current package inputs. -func cacheMetaMatches(meta *cacheMeta, pkg *packages.Package, opts *GenerateOptions, files []string) bool { - if meta.Version != cacheVersion { - return false - } - if meta.PkgPath != pkg.PkgPath || meta.Tags != opts.Tags || meta.Prefix != opts.PrefixOutputFile { - return false - } - if meta.HeaderHash != headerHash(opts.Header) { - return false - } - if len(meta.Files) != len(files) { - return false - } - current, err := buildCacheFiles(files) - if err != nil { - return false - } - for i := range meta.Files { - if meta.Files[i] != current[i] { - return false - } - } - rootFiles := rootPackageFiles(pkg) - if len(rootFiles) == 0 || meta.RootHash == "" { - return false - } - sort.Strings(rootFiles) - rootHash, err := hashFiles(rootFiles) - if err != nil || rootHash != meta.RootHash { - return false - } - return meta.ContentHash != "" -} - -// buildCacheFiles converts file paths into cache metadata entries. -func buildCacheFiles(files []string) ([]cacheFile, error) { - out := make([]cacheFile, 0, len(files)) - for _, name := range files { - info, err := osStat(name) - if err != nil { - return nil, err - } - out = append(out, cacheFile{ - Path: filepath.Clean(name), - Size: info.Size(), - ModTime: info.ModTime().UnixNano(), - }) - } - return out, nil -} - -// headerHash returns a stable hash of the generated header content. -func headerHash(header []byte) string { - if len(header) == 0 { - return "" - } - sum := sha256.Sum256(header) - return fmt.Sprintf("%x", sum[:]) -} - -// contentHashForFiles hashes the current package inputs using file paths. -func contentHashForFiles(pkg *packages.Package, opts *GenerateOptions, files []string) (string, error) { - return contentHashForPaths(pkg.PkgPath, opts, files) -} - -// contentHashForPaths hashes the provided file contents and options. -func contentHashForPaths(pkgPath string, opts *GenerateOptions, files []string) (string, error) { - h := sha256.New() - h.Write([]byte(cacheVersion)) - h.Write([]byte{0}) - h.Write([]byte(pkgPath)) - h.Write([]byte{0}) - h.Write([]byte(opts.Tags)) - h.Write([]byte{0}) - h.Write([]byte(opts.PrefixOutputFile)) - h.Write([]byte{0}) - h.Write([]byte(headerHash(opts.Header))) - h.Write([]byte{0}) - for _, name := range files { - h.Write([]byte(name)) - h.Write([]byte{0}) - data, err := osReadFile(name) - if err != nil { - return "", err - } - h.Write(data) - h.Write([]byte{0}) - } - return fmt.Sprintf("%x", h.Sum(nil)), nil -} - -// rootPackageFiles returns the direct Go files for the root package. -func rootPackageFiles(pkg *packages.Package) []string { - if pkg == nil { - return nil - } - if len(pkg.CompiledGoFiles) > 0 { - return append([]string(nil), pkg.CompiledGoFiles...) - } - if len(pkg.GoFiles) > 0 { - return append([]string(nil), pkg.GoFiles...) - } - return nil -} - -// hashFiles returns a combined content hash for the provided paths. -func hashFiles(files []string) (string, error) { - if len(files) == 0 { - return "", nil - } - h := sha256.New() - for _, name := range files { - h.Write([]byte(name)) - h.Write([]byte{0}) - data, err := osReadFile(name) - if err != nil { - return "", err - } - h.Write(data) - h.Write([]byte{0}) - } - return fmt.Sprintf("%x", h.Sum(nil)), nil -} diff --git a/internal/wire/cache_manifest.go b/internal/wire/cache_manifest.go deleted file mode 100644 index 127aa55..0000000 --- a/internal/wire/cache_manifest.go +++ /dev/null @@ -1,394 +0,0 @@ -// Copyright 2018 The Wire Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package wire - -import ( - "crypto/sha256" - "fmt" - "path/filepath" - "sort" - - "golang.org/x/tools/go/packages" -) - -// cacheManifest stores per-run cache metadata for generated packages. -type cacheManifest struct { - Version string `json:"version"` - WD string `json:"wd"` - Tags string `json:"tags"` - Prefix string `json:"prefix"` - HeaderHash string `json:"header_hash"` - EnvHash string `json:"env_hash"` - Patterns []string `json:"patterns"` - Packages []manifestPackage `json:"packages"` - ExtraFiles []cacheFile `json:"extra_files"` -} - -// manifestPackage captures cached output for a single package. -type manifestPackage struct { - PkgPath string `json:"pkg_path"` - OutputPath string `json:"output_path"` - Files []cacheFile `json:"files"` - ContentHash string `json:"content_hash"` - RootFiles []cacheFile `json:"root_files"` - RootHash string `json:"root_hash"` -} - -var extraCachePathsFunc = extraCachePaths - -// readManifestResults loads cached generation results if still valid. -func readManifestResults(wd string, env []string, patterns []string, opts *GenerateOptions) ([]GenerateResult, bool) { - key := manifestKey(wd, env, patterns, opts) - manifest, ok := readManifest(key) - if !ok { - return nil, false - } - if !manifestValid(manifest) { - return nil, false - } - results := make([]GenerateResult, 0, len(manifest.Packages)) - for _, pkg := range manifest.Packages { - content, ok := readCache(pkg.ContentHash) - if !ok { - return nil, false - } - results = append(results, GenerateResult{ - PkgPath: pkg.PkgPath, - OutputPath: pkg.OutputPath, - Content: content, - }) - } - return results, true -} - -// writeManifest persists cache metadata for a successful run. -func writeManifest(wd string, env []string, patterns []string, opts *GenerateOptions, pkgs []*packages.Package) { - if len(pkgs) == 0 { - return - } - key := manifestKey(wd, env, patterns, opts) - manifest := &cacheManifest{ - Version: cacheVersion, - WD: wd, - Tags: opts.Tags, - Prefix: opts.PrefixOutputFile, - HeaderHash: headerHash(opts.Header), - EnvHash: envHash(env), - Patterns: sortedStrings(patterns), - } - manifest.ExtraFiles = extraCacheFiles(wd) - for _, pkg := range pkgs { - if pkg == nil { - continue - } - files := packageFiles(pkg) - if len(files) == 0 { - continue - } - sort.Strings(files) - contentHash, err := cacheKeyForPackageFunc(pkg, opts) - if err != nil || contentHash == "" { - continue - } - outDir, err := detectOutputDirFunc(pkg.GoFiles) - if err != nil { - continue - } - outputPath := filepath.Join(outDir, opts.PrefixOutputFile+"wire_gen.go") - metaFiles, err := buildCacheFilesFunc(files) - if err != nil { - continue - } - rootFiles := rootPackageFilesFunc(pkg) - sort.Strings(rootFiles) - rootMeta, err := buildCacheFilesFunc(rootFiles) - if err != nil { - continue - } - rootHash, err := hashFilesFunc(rootFiles) - if err != nil { - continue - } - manifest.Packages = append(manifest.Packages, manifestPackage{ - PkgPath: pkg.PkgPath, - OutputPath: outputPath, - Files: metaFiles, - ContentHash: contentHash, - RootFiles: rootMeta, - RootHash: rootHash, - }) - } - writeManifestFile(key, manifest) -} - -// manifestKey builds the cache key for a given run configuration. -func manifestKey(wd string, env []string, patterns []string, opts *GenerateOptions) string { - h := sha256.New() - h.Write([]byte(cacheVersion)) - h.Write([]byte{0}) - h.Write([]byte(filepath.Clean(wd))) - h.Write([]byte{0}) - h.Write([]byte(envHash(env))) - h.Write([]byte{0}) - h.Write([]byte(opts.Tags)) - h.Write([]byte{0}) - h.Write([]byte(opts.PrefixOutputFile)) - h.Write([]byte{0}) - h.Write([]byte(headerHash(opts.Header))) - h.Write([]byte{0}) - for _, p := range sortedStrings(patterns) { - h.Write([]byte(p)) - h.Write([]byte{0}) - } - return fmt.Sprintf("%x", h.Sum(nil)) -} - -// manifestKeyFromManifest rebuilds the cache key from stored metadata. -func manifestKeyFromManifest(manifest *cacheManifest) string { - if manifest == nil { - return "" - } - h := sha256.New() - h.Write([]byte(cacheVersion)) - h.Write([]byte{0}) - h.Write([]byte(filepath.Clean(manifest.WD))) - h.Write([]byte{0}) - h.Write([]byte(manifest.EnvHash)) - h.Write([]byte{0}) - h.Write([]byte(manifest.Tags)) - h.Write([]byte{0}) - h.Write([]byte(manifest.Prefix)) - h.Write([]byte{0}) - h.Write([]byte(manifest.HeaderHash)) - h.Write([]byte{0}) - for _, p := range sortedStrings(manifest.Patterns) { - h.Write([]byte(p)) - h.Write([]byte{0}) - } - return fmt.Sprintf("%x", h.Sum(nil)) -} - -// readManifest loads the cached manifest by key. -func readManifest(key string) (*cacheManifest, bool) { - data, err := osReadFile(cacheManifestPath(key)) - if err != nil { - return nil, false - } - var manifest cacheManifest - if err := jsonUnmarshal(data, &manifest); err != nil { - return nil, false - } - return &manifest, true -} - -// writeManifestFile writes the manifest to disk. -func writeManifestFile(key string, manifest *cacheManifest) { - dir := cacheDir() - if err := osMkdirAll(dir, 0755); err != nil { - return - } - data, err := jsonMarshal(manifest) - if err != nil { - return - } - tmp, err := osCreateTemp(dir, key+".manifest-") - if err != nil { - return - } - _, writeErr := tmp.Write(data) - closeErr := tmp.Close() - if writeErr != nil || closeErr != nil { - osRemove(tmp.Name()) - return - } - path := cacheManifestPath(key) - if err := osRename(tmp.Name(), path); err != nil { - osRemove(tmp.Name()) - } -} - -// cacheManifestPath returns the on-disk path for a manifest key. -func cacheManifestPath(key string) string { - return filepath.Join(cacheDir(), key+".manifest.json") -} - -// manifestValid reports whether the manifest still matches current inputs. -func manifestValid(manifest *cacheManifest) bool { - if manifest == nil || manifest.Version != cacheVersion { - return false - } - if manifest.EnvHash == "" || len(manifest.Packages) == 0 { - return false - } - if len(manifest.ExtraFiles) > 0 { - current, err := buildCacheFilesFromMetaFunc(manifest.ExtraFiles) - if err != nil { - return false - } - if len(current) != len(manifest.ExtraFiles) { - return false - } - for i := range manifest.ExtraFiles { - if manifest.ExtraFiles[i] != current[i] { - return false - } - } - } - for i := range manifest.Packages { - pkg := manifest.Packages[i] - if pkg.ContentHash == "" { - return false - } - if len(pkg.RootFiles) == 0 || pkg.RootHash == "" { - return false - } - current, err := buildCacheFilesFromMetaFunc(pkg.Files) - if err != nil { - return false - } - if len(current) != len(pkg.Files) { - return false - } - for j := range pkg.Files { - if pkg.Files[j] != current[j] { - return false - } - } - rootCurrent, err := buildCacheFilesFromMetaFunc(pkg.RootFiles) - if err != nil { - return false - } - if len(rootCurrent) != len(pkg.RootFiles) { - return false - } - for j := range pkg.RootFiles { - if pkg.RootFiles[j] != rootCurrent[j] { - return false - } - } - rootPaths := make([]string, 0, len(pkg.RootFiles)) - for _, file := range pkg.RootFiles { - rootPaths = append(rootPaths, file.Path) - } - sort.Strings(rootPaths) - rootHash, err := hashFiles(rootPaths) - if err != nil || rootHash != pkg.RootHash { - return false - } - } - return true -} - -// buildCacheFilesFromMeta re-stats files to compare metadata. -func buildCacheFilesFromMeta(files []cacheFile) ([]cacheFile, error) { - out := make([]cacheFile, 0, len(files)) - for _, file := range files { - info, err := osStat(file.Path) - if err != nil { - return nil, err - } - out = append(out, cacheFile{ - Path: filepath.Clean(file.Path), - Size: info.Size(), - ModTime: info.ModTime().UnixNano(), - }) - } - return out, nil -} - -// extraCacheFiles returns Go module/workspace files affecting builds. -func extraCacheFiles(wd string) []cacheFile { - paths := extraCachePathsFunc(wd) - if len(paths) == 0 { - return nil - } - out := make([]cacheFile, 0, len(paths)) - seen := make(map[string]struct{}) - for _, path := range paths { - path = filepath.Clean(path) - if _, ok := seen[path]; ok { - continue - } - info, err := osStat(path) - if err != nil { - continue - } - seen[path] = struct{}{} - out = append(out, cacheFile{ - Path: path, - Size: info.Size(), - ModTime: info.ModTime().UnixNano(), - }) - } - sort.Slice(out, func(i, j int) bool { - return out[i].Path < out[j].Path - }) - return out -} - -// extraCachePaths finds go.mod/go.sum/go.work files for a working dir. -func extraCachePaths(wd string) []string { - var paths []string - dir := filepath.Clean(wd) - seen := make(map[string]struct{}) - for { - for _, name := range []string{"go.work", "go.work.sum", "go.mod", "go.sum"} { - full := filepath.Join(dir, name) - addExtraCachePath(&paths, seen, full) - } - parent := filepath.Dir(dir) - if parent == dir { - break - } - dir = parent - } - return paths -} - -// addExtraCachePath appends an existing file if it has not been seen. -func addExtraCachePath(paths *[]string, seen map[string]struct{}, full string) { - if _, ok := seen[full]; ok { - return - } - if _, err := osStat(full); err != nil { - return - } - *paths = append(*paths, full) - seen[full] = struct{}{} -} - -// sortedStrings returns a sorted copy of the input slice. -func sortedStrings(values []string) []string { - if len(values) == 0 { - return nil - } - out := append([]string(nil), values...) - sort.Strings(out) - return out -} - -// envHash returns a stable hash of environment variables. -func envHash(env []string) string { - if len(env) == 0 { - return "" - } - sorted := sortedStrings(env) - h := sha256.New() - for _, v := range sorted { - h.Write([]byte(v)) - h.Write([]byte{0}) - } - return fmt.Sprintf("%x", h.Sum(nil)) -} diff --git a/internal/wire/cache_store.go b/internal/wire/cache_store.go deleted file mode 100644 index dce5565..0000000 --- a/internal/wire/cache_store.go +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright 2018 The Wire Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package wire - -import ( - "errors" - "io/fs" - "path/filepath" -) - -// cacheDir returns the base directory for Wire cache files. -func cacheDir() string { - return filepath.Join(osTempDir(), "wire-cache") -} - -// CacheDir returns the directory used for Wire's cache. -func CacheDir() string { - return cacheDir() -} - -// ClearCache removes all cached data. -func ClearCache() error { - return osRemoveAll(cacheDir()) -} - -// cachePath builds the on-disk path for a cached content hash. -func cachePath(key string) string { - return filepath.Join(cacheDir(), key+".bin") -} - -// readCache reads a cached content blob by key. -func readCache(key string) ([]byte, bool) { - data, err := osReadFile(cachePath(key)) - if err != nil { - return nil, false - } - return data, true -} - -// writeCache persists a content blob for the provided cache key. -func writeCache(key string, content []byte) { - dir := cacheDir() - if err := osMkdirAll(dir, 0755); err != nil { - return - } - path := cachePath(key) - tmp, err := osCreateTemp(dir, key+".tmp-") - if err != nil { - return - } - _, writeErr := tmp.Write(content) - closeErr := tmp.Close() - if writeErr != nil || closeErr != nil { - osRemove(tmp.Name()) - return - } - if err := osRename(tmp.Name(), path); err != nil { - if errors.Is(err, fs.ErrExist) { - osRemove(tmp.Name()) - return - } - osRemove(tmp.Name()) - } -} diff --git a/internal/wire/cache_test.go b/internal/wire/cache_test.go deleted file mode 100644 index bc55bae..0000000 --- a/internal/wire/cache_test.go +++ /dev/null @@ -1,385 +0,0 @@ -// Copyright 2026 The Wire Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package wire - -import ( - "context" - "os" - "path/filepath" - "strings" - "testing" -) - -func TestCacheInvalidation(t *testing.T) { - repoRoot := mustRepoRoot(t) - root := t.TempDir() - - prevTmp := os.Getenv("TMPDIR") - if err := os.Setenv("TMPDIR", t.TempDir()); err != nil { - t.Fatalf("Setenv TMPDIR failed: %v", err) - } - t.Cleanup(func() { - os.Setenv("TMPDIR", prevTmp) - }) - - writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ - "module example.com/app", - "", - "go 1.19", - "", - "require github.com/goforj/wire v0.0.0", - "replace github.com/goforj/wire => " + repoRoot, - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ - "//go:build wireinject", - "// +build wireinject", - "", - "package app", - "", - "import (", - "\t\"example.com/app/dep\"", - "\t\"github.com/goforj/wire\"", - ")", - "", - "func Init() string {", - "\twire.Build(dep.ProvideMessage)", - "\treturn \"\"", - "}", - "", - }, "\n")) - - depPath := filepath.Join(root, "dep", "dep.go") - writeFile(t, depPath, strings.Join([]string{ - "package dep", - "", - "func ProvideMessage() string {", - "\treturn \"hello\"", - "}", - "", - }, "\n")) - - env := append(os.Environ(), "GOWORK=off") - ctx := context.Background() - opts := &GenerateOptions{} - - first, errs := Generate(ctx, root, env, []string{"./app"}, opts) - if len(errs) > 0 { - t.Fatalf("first Generate errors: %v", errs) - } - if len(first) != 1 || len(first[0].Content) == 0 { - t.Fatalf("first Generate returned unexpected result: %+v", first) - } - - pkgs, _, errs := load(ctx, root, env, opts.Tags, []string{"./app"}) - if len(errs) > 0 || len(pkgs) != 1 { - t.Fatalf("load failed: %v", errs) - } - key, err := cacheKeyForPackage(pkgs[0], opts) - if err != nil { - t.Fatalf("cacheKeyForPackage failed: %v", err) - } - if cached, ok := readCache(key); !ok || len(cached) == 0 { - t.Fatal("expected cache entry after first Generate") - } - - writeFile(t, depPath, strings.Join([]string{ - "package dep", - "", - "func ProvideMessage() string {", - "\treturn \"goodbye\"", - "}", - "", - }, "\n")) - - second, errs := Generate(ctx, root, env, []string{"./app"}, opts) - if len(errs) > 0 { - t.Fatalf("second Generate errors: %v", errs) - } - if len(second) != 1 || len(second[0].Content) == 0 { - t.Fatalf("second Generate returned unexpected result: %+v", second) - } - pkgs, _, errs = load(ctx, root, env, opts.Tags, []string{"./app"}) - if len(errs) > 0 || len(pkgs) != 1 { - t.Fatalf("reload failed: %v", errs) - } - key2, err := cacheKeyForPackage(pkgs[0], opts) - if err != nil { - t.Fatalf("cacheKeyForPackage after update failed: %v", err) - } - if key2 == key { - t.Fatal("expected cache key to change after source update") - } - if cached, ok := readCache(key2); !ok || len(cached) == 0 { - t.Fatal("expected cache entry after second Generate") - } -} - -func TestManifestInvalidation(t *testing.T) { - repoRoot := mustRepoRoot(t) - root := t.TempDir() - - prevTmp := os.Getenv("TMPDIR") - if err := os.Setenv("TMPDIR", t.TempDir()); err != nil { - t.Fatalf("Setenv TMPDIR failed: %v", err) - } - t.Cleanup(func() { - os.Setenv("TMPDIR", prevTmp) - }) - - writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ - "module example.com/app", - "", - "go 1.19", - "", - "require github.com/goforj/wire v0.0.0", - "replace github.com/goforj/wire => " + repoRoot, - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ - "//go:build wireinject", - "// +build wireinject", - "", - "package app", - "", - "import (", - "\t\"example.com/app/dep\"", - "\t\"github.com/goforj/wire\"", - ")", - "", - "func Init() string {", - "\twire.Build(dep.ProvideMessage)", - "\treturn \"\"", - "}", - "", - }, "\n")) - - depPath := filepath.Join(root, "dep", "dep.go") - writeFile(t, depPath, strings.Join([]string{ - "package dep", - "", - "func ProvideMessage() string {", - "\treturn \"hello\"", - "}", - "", - }, "\n")) - - env := append(os.Environ(), "GOWORK=off") - ctx := context.Background() - opts := &GenerateOptions{} - - if _, errs := Generate(ctx, root, env, []string{"./app"}, opts); len(errs) > 0 { - t.Fatalf("Generate errors: %v", errs) - } - - key := manifestKey(root, env, []string{"./app"}, opts) - manifest, ok := readManifest(key) - if !ok { - t.Fatal("expected manifest after Generate") - } - if !manifestValid(manifest) { - t.Fatal("expected manifest to be valid") - } - - writeFile(t, depPath, strings.Join([]string{ - "package dep", - "", - "func ProvideMessage() string {", - "\treturn \"goodbye\"", - "}", - "", - }, "\n")) - - if manifestValid(manifest) { - t.Fatal("expected manifest to be invalid after source update") - } -} - -func TestManifestInvalidationGoMod(t *testing.T) { - repoRoot := mustRepoRoot(t) - root := t.TempDir() - - prevTmp := os.Getenv("TMPDIR") - if err := os.Setenv("TMPDIR", t.TempDir()); err != nil { - t.Fatalf("Setenv TMPDIR failed: %v", err) - } - t.Cleanup(func() { - os.Setenv("TMPDIR", prevTmp) - }) - - goModPath := filepath.Join(root, "go.mod") - writeFile(t, goModPath, strings.Join([]string{ - "module example.com/app", - "", - "go 1.19", - "", - "require github.com/goforj/wire v0.0.0", - "replace github.com/goforj/wire => " + repoRoot, - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ - "//go:build wireinject", - "// +build wireinject", - "", - "package app", - "", - "import (", - "\t\"example.com/app/dep\"", - "\t\"github.com/goforj/wire\"", - ")", - "", - "func Init() string {", - "\twire.Build(dep.ProvideMessage)", - "\treturn \"\"", - "}", - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ - "package dep", - "", - "func ProvideMessage() string {", - "\treturn \"hello\"", - "}", - "", - }, "\n")) - - env := append(os.Environ(), "GOWORK=off") - ctx := context.Background() - opts := &GenerateOptions{} - - if _, errs := Generate(ctx, root, env, []string{"./app"}, opts); len(errs) > 0 { - t.Fatalf("Generate errors: %v", errs) - } - - key := manifestKey(root, env, []string{"./app"}, opts) - manifest, ok := readManifest(key) - if !ok { - t.Fatal("expected manifest after Generate") - } - if !manifestValid(manifest) { - t.Fatal("expected manifest to be valid") - } - - writeFile(t, goModPath, strings.Join([]string{ - "module example.com/app", - "", - "go 1.19", - "", - "require github.com/goforj/wire v0.0.0 // updated", - "replace github.com/goforj/wire => " + repoRoot, - "", - }, "\n")) - - if manifestValid(manifest) { - t.Fatal("expected manifest to be invalid after go.mod update") - } -} - -func TestManifestInvalidationSameTimestamp(t *testing.T) { - repoRoot := mustRepoRoot(t) - root := t.TempDir() - - prevTmp := os.Getenv("TMPDIR") - if err := os.Setenv("TMPDIR", t.TempDir()); err != nil { - t.Fatalf("Setenv TMPDIR failed: %v", err) - } - t.Cleanup(func() { - os.Setenv("TMPDIR", prevTmp) - }) - - writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ - "module example.com/app", - "", - "go 1.19", - "", - "require github.com/goforj/wire v0.0.0", - "replace github.com/goforj/wire => " + repoRoot, - "", - }, "\n")) - - wirePath := filepath.Join(root, "app", "wire.go") - writeFile(t, wirePath, strings.Join([]string{ - "//go:build wireinject", - "// +build wireinject", - "", - "package app", - "", - "import (", - "\t\"example.com/app/dep\"", - "\t\"github.com/goforj/wire\"", - ")", - "", - "func Init() string {", - "\twire.Build(dep.ProvideMessage)", - "\treturn \"\"", - "}", - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ - "package dep", - "", - "func ProvideMessage() string {", - "\treturn \"hello\"", - "}", - "", - }, "\n")) - - env := append(os.Environ(), "GOWORK=off") - ctx := context.Background() - opts := &GenerateOptions{} - - if _, errs := Generate(ctx, root, env, []string{"./app"}, opts); len(errs) > 0 { - t.Fatalf("Generate errors: %v", errs) - } - - key := manifestKey(root, env, []string{"./app"}, opts) - manifest, ok := readManifest(key) - if !ok { - t.Fatal("expected manifest after Generate") - } - if !manifestValid(manifest) { - t.Fatal("expected manifest to be valid") - } - - info, err := os.Stat(wirePath) - if err != nil { - t.Fatalf("Stat failed: %v", err) - } - originalMod := info.ModTime() - - original, err := os.ReadFile(wirePath) - if err != nil { - t.Fatalf("ReadFile failed: %v", err) - } - updated := strings.Replace(string(original), "ProvideMessage", "ProvideMassage", 1) - if len(updated) != len(original) { - t.Fatalf("expected updated content to keep length; got %d vs %d", len(updated), len(original)) - } - if err := os.WriteFile(wirePath, []byte(updated), 0644); err != nil { - t.Fatalf("WriteFile failed: %v", err) - } - if err := os.Chtimes(wirePath, originalMod, originalMod); err != nil { - t.Fatalf("Chtimes failed: %v", err) - } - - if manifestValid(manifest) { - t.Fatal("expected manifest to be invalid after same-timestamp content update") - } -} diff --git a/internal/wire/generate_package.go b/internal/wire/generate_package.go deleted file mode 100644 index de34aa6..0000000 --- a/internal/wire/generate_package.go +++ /dev/null @@ -1,126 +0,0 @@ -// Copyright 2018 The Wire Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package wire - -import ( - "context" - "errors" - "fmt" - "go/format" - "path/filepath" - "time" - - "golang.org/x/tools/go/packages" -) - -// generateForPackage runs Wire code generation for a single package. -func generateForPackage(ctx context.Context, pkg *packages.Package, loader *lazyLoader, opts *GenerateOptions) GenerateResult { - if opts == nil { - opts = &GenerateOptions{} - } - pkgStart := time.Now() - res := GenerateResult{ - PkgPath: pkg.PkgPath, - } - dirStart := time.Now() - outDir, err := detectOutputDir(pkg.GoFiles) - logTiming(ctx, "generate.package."+pkg.PkgPath+".output_dir", dirStart) - if err != nil { - res.Errs = append(res.Errs, err) - return res - } - res.OutputPath = filepath.Join(outDir, opts.PrefixOutputFile+"wire_gen.go") - cacheKey, err := cacheKeyForPackage(pkg, opts) - if err != nil { - res.Errs = append(res.Errs, err) - return res - } - if cacheKey != "" { - cacheHitStart := time.Now() - if cached, ok := readCache(cacheKey); ok { - res.Content = cached - logTiming(ctx, "generate.package."+pkg.PkgPath+".cache_hit", cacheHitStart) - logTiming(ctx, "generate.package."+pkg.PkgPath+".total", pkgStart) - return res - } - } - oc := newObjectCache([]*packages.Package{pkg}, loader) - if loaded, errs := oc.ensurePackage(pkg.PkgPath); len(errs) > 0 { - res.Errs = append(res.Errs, errs...) - return res - } else if loaded != nil { - pkg = loaded - } - g := newGen(pkg) - injectorStart := time.Now() - injectorFiles, errs := generateInjectors(oc, g, pkg) - logTiming(ctx, "generate.package."+pkg.PkgPath+".injectors", injectorStart) - if len(errs) > 0 { - res.Errs = errs - return res - } - copyStart := time.Now() - copyNonInjectorDecls(g, injectorFiles, pkg.TypesInfo) - logTiming(ctx, "generate.package."+pkg.PkgPath+".copy_non_injectors", copyStart) - frameStart := time.Now() - goSrc := g.frame(opts.Tags) - logTiming(ctx, "generate.package."+pkg.PkgPath+".frame", frameStart) - if len(opts.Header) > 0 { - goSrc = append(opts.Header, goSrc...) - } - formatStart := time.Now() - fmtSrc, err := format.Source(goSrc) - logTiming(ctx, "generate.package."+pkg.PkgPath+".format", formatStart) - if err != nil { - // This is likely a bug from a poorly generated source file. - // Add an error but also the unformatted source. - res.Errs = append(res.Errs, err) - } else { - goSrc = fmtSrc - } - res.Content = goSrc - if cacheKey != "" && len(res.Errs) == 0 { - writeCache(cacheKey, res.Content) - } - logTiming(ctx, "generate.package."+pkg.PkgPath+".total", pkgStart) - return res -} - -// allGeneratedOK reports whether every package result succeeded. -func allGeneratedOK(results []GenerateResult) bool { - if len(results) == 0 { - return false - } - for _, res := range results { - if len(res.Errs) > 0 { - return false - } - } - return true -} - -// detectOutputDir returns a shared directory for the provided file paths. -func detectOutputDir(paths []string) (string, error) { - if len(paths) == 0 { - return "", errors.New("no files to derive output directory from") - } - dir := filepath.Dir(paths[0]) - for _, p := range paths[1:] { - if dir2 := filepath.Dir(p); dir2 != dir { - return "", fmt.Errorf("found conflicting directories %q and %q", dir, dir2) - } - } - return dir, nil -} diff --git a/internal/wire/generate_package_test.go b/internal/wire/generate_package_test.go deleted file mode 100644 index 51b15b1..0000000 --- a/internal/wire/generate_package_test.go +++ /dev/null @@ -1,137 +0,0 @@ -// Copyright 2026 The Wire Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package wire - -import ( - "context" - "os" - "path/filepath" - "strings" - "testing" - - "golang.org/x/tools/go/packages" -) - -func TestGenerateForPackageOptionAndDetectErrors(t *testing.T) { - res := generateForPackage(context.Background(), &packages.Package{PkgPath: "example.com/empty"}, nil, nil) - if len(res.Errs) == 0 { - t.Fatal("expected error for empty package") - } - if _, err := detectOutputDir(nil); err == nil { - t.Fatal("expected detectOutputDir error") - } -} - -func TestGenerateForPackageCacheKeyError(t *testing.T) { - tempDir := t.TempDir() - missing := filepath.Join(tempDir, "missing.go") - pkg := &packages.Package{ - PkgPath: "example.com/missing", - GoFiles: []string{missing}, - } - res := generateForPackage(context.Background(), pkg, nil, &GenerateOptions{}) - if len(res.Errs) == 0 { - t.Fatal("expected cache key error") - } -} - -func TestGenerateForPackageCacheHit(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - tempDir := t.TempDir() - osTempDir = func() string { return tempDir } - - file := writeTempFile(t, tempDir, "hit.go", "package hit\n") - pkg := &packages.Package{ - PkgPath: "example.com/hit", - GoFiles: []string{file}, - } - opts := &GenerateOptions{} - key, err := cacheKeyForPackage(pkg, opts) - if err != nil || key == "" { - t.Fatalf("cacheKeyForPackage failed: %v", err) - } - writeCache(key, []byte("cached")) - res := generateForPackage(context.Background(), pkg, nil, opts) - if string(res.Content) != "cached" { - t.Fatalf("expected cached content, got %q", res.Content) - } -} - -func TestGenerateForPackageFormatError(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - tempDir := t.TempDir() - osTempDir = func() string { return tempDir } - - repoRoot := mustRepoRoot(t) - writeTempFile(t, tempDir, "go.mod", strings.Join([]string{ - "module example.com/app", - "", - "go 1.19", - "", - "require github.com/goforj/wire v0.0.0", - "replace github.com/goforj/wire => " + repoRoot, - "", - }, "\n")) - appDir := filepath.Join(tempDir, "app") - if err := os.MkdirAll(appDir, 0755); err != nil { - t.Fatalf("MkdirAll failed: %v", err) - } - writeTempFile(t, appDir, "wire.go", strings.Join([]string{ - "//go:build wireinject", - "// +build wireinject", - "", - "package app", - "", - "import \"github.com/goforj/wire\"", - "", - "func Init() string {", - "\twire.Build(NewMessage)", - "\treturn \"\"", - "}", - "", - "func NewMessage() string { return \"ok\" }", - "", - }, "\n")) - - ctx := context.Background() - env := append(os.Environ(), "GOWORK=off") - pkgs, loader, errs := load(ctx, tempDir, env, "", []string{"./app"}) - if len(errs) > 0 || len(pkgs) != 1 { - t.Fatalf("load errors: %v", errs) - } - opts := &GenerateOptions{Header: []byte("invalid")} - res := generateForPackage(ctx, pkgs[0], loader, opts) - if len(res.Errs) == 0 { - t.Fatal("expected format.Source error") - } -} - -func TestAllGeneratedOK(t *testing.T) { - if allGeneratedOK(nil) { - t.Fatal("expected empty results to be false") - } - if allGeneratedOK([]GenerateResult{{Errs: []error{context.DeadlineExceeded}}}) { - t.Fatal("expected errors to be false") - } - if !allGeneratedOK([]GenerateResult{{}}) { - t.Fatal("expected success results to be true") - } -} diff --git a/internal/wire/import_bench_test.go b/internal/wire/import_bench_test.go new file mode 100644 index 0000000..39dd862 --- /dev/null +++ b/internal/wire/import_bench_test.go @@ -0,0 +1,1524 @@ +package wire + +import ( + "archive/tar" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "os" + "os/exec" + "path/filepath" + "runtime" + "strings" + "testing" + "time" +) + +const ( + importBenchEnv = "WIRE_IMPORT_BENCH_TABLE" + importBenchBreakdown = "WIRE_IMPORT_BENCH_BREAKDOWN" + importBenchScenarios = "WIRE_IMPORT_BENCH_SCENARIOS" + importBenchScenarioBD = "WIRE_IMPORT_BENCH_SCENARIO_BREAKDOWN" + importBenchProfile = "WIRE_IMPORT_BENCH_PROFILE" + importBenchProfileRun = "WIRE_IMPORT_BENCH_PROFILE_RUN" + importBenchVariant = "WIRE_IMPORT_BENCH_VARIANT" + importBenchCPUProfile = "WIRE_IMPORT_BENCH_CPU_PROFILE" + stockWireCommit = "9c25c9016f6825302537c4efdd5e897976f9c826" + stockWireModulePath = "github.com/google/wire" + currentWireModulePath = "github.com/goforj/wire" +) + +type importBenchRow struct { + imports int + stockCold time.Duration + currentCold time.Duration + currentWarm time.Duration +} + +type importBenchScenarioRow struct { + profile string + localCount int + stdlibCount int + externalCount int + name string + stock time.Duration + current time.Duration +} + +type benchCaches struct { + home string + goCache string +} + +type benchGraphCounts struct { + local int + stdlib int + external int +} + +const importBenchTrials = 3 + +func TestPrintImportScaleBenchmarkTable(t *testing.T) { + if os.Getenv(importBenchEnv) != "1" { + t.Skipf("%s not set", importBenchEnv) + } + repoRoot, err := importBenchRepoRoot() + if err != nil { + t.Fatal(err) + } + currentBin := buildWireBinary(t, repoRoot, "current-wire") + stockDir := extractStockWire(t, repoRoot, stockWireCommit) + stockBin := buildWireBinary(t, stockDir, "stock-wire") + stockCaches := newBenchCaches(t) + currentCaches := newBenchCaches(t) + + sizes := []int{10, 100, 1000} + rows := make([]importBenchRow, 0, len(sizes)) + for _, n := range sizes { + stockFixture := createImportBenchFixture(t, n, stockWireModulePath, stockDir) + currentFixture := createImportBenchFixture(t, n, currentWireModulePath, repoRoot) + rows = append(rows, importBenchRow{ + imports: n, + stockCold: medianDuration(runColdTrials(t, stockBin, stockFixture, stockCaches, importBenchTrials)), + currentCold: medianDuration(runColdTrials(t, currentBin, currentFixture, currentCaches, importBenchTrials)), + currentWarm: medianDuration(runWarmTrials(t, currentBin, currentFixture, currentCaches, importBenchTrials)), + }) + } + printImportBenchTable(t, rows) +} + +func TestPrintImportScaleBenchmarkBreakdown(t *testing.T) { + if os.Getenv(importBenchBreakdown) != "1" { + t.Skipf("%s not set", importBenchBreakdown) + } + repoRoot, err := importBenchRepoRoot() + if err != nil { + t.Fatal(err) + } + currentBin := buildWireBinary(t, repoRoot, "current-wire") + stockDir := extractStockWire(t, repoRoot, stockWireCommit) + stockBin := buildWireBinary(t, stockDir, "stock-wire") + stockCaches := newBenchCaches(t) + currentCaches := newBenchCaches(t) + + const imports = 1000 + stockFixture := createImportBenchFixture(t, imports, stockWireModulePath, stockDir) + currentFixture := createImportBenchFixture(t, imports, currentWireModulePath, repoRoot) + + stockCold := medianDuration(runColdTrials(t, stockBin, stockFixture, stockCaches, importBenchTrials)) + currentCold := medianDuration(runColdTrials(t, currentBin, currentFixture, currentCaches, importBenchTrials)) + currentWarm := medianDuration(runWarmTrials(t, currentBin, currentFixture, currentCaches, importBenchTrials)) + + fmt.Printf("repo size: %d\n", imports) + fmt.Printf("stock cold: %s\n", formatMs(stockCold)) + fmt.Printf("current cold: %s\n", formatMs(currentCold)) + fmt.Printf("current unchanged: %s\n", formatMs(currentWarm)) + fmt.Printf("cold speedup: %s\n", formatSpeedup(stockCold, currentCold)) + fmt.Printf("unchanged speedup: %s\n", formatSpeedup(stockCold, currentWarm)) + fmt.Printf("cold gap: %s\n", formatMs(currentCold-stockCold)) + + prewarmGoBenchCache(t, currentFixture, currentCaches) + _, output := runWireBenchCommandOutput(t, currentBin, currentFixture, currentCaches, "-timings") + fmt.Println("current cold timings:") + printScenarioTimingLines(output) +} + +func TestPrintImportScenarioBenchmarkTable(t *testing.T) { + if os.Getenv(importBenchScenarios) != "1" { + t.Skipf("%s not set", importBenchScenarios) + } + repoRoot, err := importBenchRepoRoot() + if err != nil { + t.Fatal(err) + } + currentBin := buildWireBinary(t, repoRoot, "current-wire") + stockDir := extractStockWire(t, repoRoot, stockWireCommit) + stockBin := buildWireBinary(t, stockDir, "stock-wire") + + profiles := importBenchAppProfiles() + if filter := os.Getenv(importBenchProfile); filter != "" { + filtered := make([]appBenchProfile, 0, len(profiles)) + for _, profile := range profiles { + if profile.label == filter { + filtered = append(filtered, profile) + } + } + if len(filtered) == 0 { + t.Fatalf("%s=%q did not match any benchmark profile", importBenchProfile, filter) + } + profiles = filtered + } + rows := make([]importBenchScenarioRow, 0, len(profiles)*6) + for _, profile := range profiles { + shapeFixture := createAppShapeBenchFixture(t, profile.localPkgs, profile.depPkgs, profile.external, currentWireModulePath, repoRoot) + shapeCounts := goListGraphCounts(t, shapeFixture, "example.com/appbench", newBenchCaches(t)) + rows = append(rows, + importBenchScenarioRow{ + profile: profile.label, + localCount: shapeCounts.local, + stdlibCount: shapeCounts.stdlib, + externalCount: shapeCounts.external, + name: "cold run", + stock: medianDuration(runAppColdTrials(t, stockBin, profile.localPkgs, profile.depPkgs, profile.external, stockWireModulePath, stockDir, importBenchTrials)), + current: medianDuration(runAppColdTrials(t, currentBin, profile.localPkgs, profile.depPkgs, profile.external, currentWireModulePath, repoRoot, importBenchTrials)), + }, + importBenchScenarioRow{ + profile: profile.label, + localCount: shapeCounts.local, + stdlibCount: shapeCounts.stdlib, + externalCount: shapeCounts.external, + name: "unchanged rerun", + stock: medianDuration(runAppWarmTrials(t, stockBin, profile.localPkgs, profile.depPkgs, profile.external, stockWireModulePath, stockDir, importBenchTrials)), + current: medianDuration(runAppWarmTrials(t, currentBin, profile.localPkgs, profile.depPkgs, profile.external, currentWireModulePath, repoRoot, importBenchTrials)), + }, + importBenchScenarioRow{ + profile: profile.label, + localCount: shapeCounts.local, + stdlibCount: shapeCounts.stdlib, + externalCount: shapeCounts.external, + name: "body-only local edit", + stock: medianDuration(runAppScenarioTrials(t, stockBin, profile.localPkgs, profile.depPkgs, profile.external, stockWireModulePath, stockDir, "body", importBenchTrials)), + current: medianDuration(runAppScenarioTrials(t, currentBin, profile.localPkgs, profile.depPkgs, profile.external, currentWireModulePath, repoRoot, "body", importBenchTrials)), + }, + importBenchScenarioRow{ + profile: profile.label, + localCount: shapeCounts.local, + stdlibCount: shapeCounts.stdlib, + externalCount: shapeCounts.external, + name: "shape change", + stock: medianDuration(runAppScenarioTrials(t, stockBin, profile.localPkgs, profile.depPkgs, profile.external, stockWireModulePath, stockDir, "shape", importBenchTrials)), + current: medianDuration(runAppScenarioTrials(t, currentBin, profile.localPkgs, profile.depPkgs, profile.external, currentWireModulePath, repoRoot, "shape", importBenchTrials)), + }, + importBenchScenarioRow{ + profile: profile.label, + localCount: shapeCounts.local, + stdlibCount: shapeCounts.stdlib, + externalCount: shapeCounts.external, + name: "import change", + stock: medianDuration(runAppScenarioTrials(t, stockBin, profile.localPkgs, profile.depPkgs, profile.external, stockWireModulePath, stockDir, "import", importBenchTrials)), + current: medianDuration(runAppScenarioTrials(t, currentBin, profile.localPkgs, profile.depPkgs, profile.external, currentWireModulePath, repoRoot, "import", importBenchTrials)), + }, + importBenchScenarioRow{ + profile: profile.label, + localCount: shapeCounts.local, + stdlibCount: shapeCounts.stdlib, + externalCount: shapeCounts.external, + name: "known import toggle", + stock: medianDuration(runAppKnownToggleTrials(t, stockBin, profile.localPkgs, profile.depPkgs, profile.external, stockWireModulePath, stockDir, importBenchTrials)), + current: medianDuration(runAppKnownToggleTrials(t, currentBin, profile.localPkgs, profile.depPkgs, profile.external, currentWireModulePath, repoRoot, importBenchTrials)), + }, + ) + } + printImportScenarioBenchTable(t, rows) +} + +func TestPrintImportScenarioBenchmarkBreakdown(t *testing.T) { + if os.Getenv(importBenchScenarioBD) != "1" { + t.Skipf("%s not set", importBenchScenarioBD) + } + repoRoot, err := importBenchRepoRoot() + if err != nil { + t.Fatal(err) + } + currentBin := buildWireBinary(t, repoRoot, "current-wire") + stockDir := extractStockWire(t, repoRoot, stockWireCommit) + stockBin := buildWireBinary(t, stockDir, "stock-wire") + + const ( + localPkgs = 10 + depPkgs = 1000 + ) + + stockPkgDir := createAppShapeBenchFixture(t, localPkgs, depPkgs, false, stockWireModulePath, stockDir) + currentPkgDir := createAppShapeBenchFixture(t, localPkgs, depPkgs, false, currentWireModulePath, repoRoot) + stockCaches := newBenchCaches(t) + currentCaches := newBenchCaches(t) + + prewarmGoBenchCache(t, stockPkgDir, stockCaches) + _ = runWireBenchCommand(t, stockBin, stockPkgDir, stockCaches) + writeAppShapeControllerFile(t, filepath.Dir(stockPkgDir), 0, "shape") + _ = runWireBenchCommand(t, stockBin, stockPkgDir, stockCaches) + writeAppShapeControllerFile(t, filepath.Dir(stockPkgDir), 0, "base") + stockDur := runWireBenchCommand(t, stockBin, stockPkgDir, stockCaches) + + prewarmGoBenchCache(t, currentPkgDir, currentCaches) + _ = runWireBenchCommand(t, currentBin, currentPkgDir, currentCaches) + writeAppShapeControllerFile(t, filepath.Dir(currentPkgDir), 0, "shape") + _ = runWireBenchCommand(t, currentBin, currentPkgDir, currentCaches) + writeAppShapeControllerFile(t, filepath.Dir(currentPkgDir), 0, "base") + currentDur, currentOutput := runWireBenchCommandOutput(t, currentBin, currentPkgDir, currentCaches, "-timings") + + fmt.Printf("scenario: local=%d dep=%d known import toggle\n", localPkgs, depPkgs) + fmt.Printf("stock: %s\n", formatMs(stockDur)) + fmt.Printf("current: %s\n", formatMs(currentDur)) + fmt.Printf("speedup: %s\n", formatSpeedup(stockDur, currentDur)) + fmt.Println("current timings:") + printScenarioTimingLines(currentOutput) +} + +func TestProfileCurrentWireScenarioRun(t *testing.T) { + if os.Getenv(importBenchProfileRun) != "1" { + t.Skipf("%s not set", importBenchProfileRun) + } + profile := os.Getenv(importBenchProfile) + variant := os.Getenv(importBenchVariant) + cpuProfile := os.Getenv(importBenchCPUProfile) + if profile == "" { + t.Fatalf("%s must be set", importBenchProfile) + } + if variant == "" { + t.Fatalf("%s must be set", importBenchVariant) + } + if cpuProfile == "" { + t.Fatalf("%s must be set", importBenchCPUProfile) + } + repoRoot, err := importBenchRepoRoot() + if err != nil { + t.Fatal(err) + } + currentBin := buildWireBinary(t, repoRoot, "current-wire") + profileCfg, err := importBenchAppProfile(profile) + if err != nil { + t.Fatal(err) + } + pkgDir := createAppShapeBenchFixture(t, profileCfg.localPkgs, profileCfg.depPkgs, profileCfg.external, currentWireModulePath, repoRoot) + caches := newBenchCaches(t) + prewarmGoBenchCache(t, pkgDir, caches) + root := filepath.Dir(pkgDir) + + switch variant { + case "unchanged": + _ = runWireBenchCommand(t, currentBin, pkgDir, caches) + case "body", "shape", "import": + resetAppShapeBenchFixture(t, pkgDir, profileCfg.localPkgs) + _ = runWireBenchCommand(t, currentBin, pkgDir, caches) + writeAppShapeControllerFile(t, root, 0, variant) + case "known-toggle": + resetAppShapeBenchFixture(t, pkgDir, profileCfg.localPkgs) + _ = runWireBenchCommand(t, currentBin, pkgDir, caches) + writeAppShapeControllerFile(t, root, 0, "shape") + _ = runWireBenchCommand(t, currentBin, pkgDir, caches) + writeAppShapeControllerFile(t, root, 0, "base") + default: + t.Fatalf("unknown %s %q", importBenchVariant, variant) + } + + dur, output := runWireBenchCommandOutput(t, currentBin, pkgDir, caches, "-cpuprofile="+cpuProfile, "-timings") + fmt.Printf("profile: %s\n", profile) + fmt.Printf("variant: %s\n", variant) + fmt.Printf("duration: %s\n", formatMs(dur)) + fmt.Printf("cpuprofile: %s\n", cpuProfile) + printScenarioTimingLines(output) +} + +func BenchmarkCurrentWireLocalProfile(b *testing.B) { + repoRoot, err := importBenchRepoRoot() + if err != nil { + b.Fatal(err) + } + currentBin := buildWireBinary(b, repoRoot, "current-wire") + const ( + features = 10 + depPkgs = 25 + external = false + ) + + for _, variant := range []string{"unchanged", "body", "shape", "import", "known-toggle"} { + b.Run(variant, func(b *testing.B) { + benchmarkCurrentWireAppScenario(b, currentBin, repoRoot, features, depPkgs, external, variant) + }) + } +} + +type appBenchProfile struct { + localPkgs int + depPkgs int + external bool + label string +} + +func importBenchAppProfiles() []appBenchProfile { + return []appBenchProfile{ + {localPkgs: 10, depPkgs: 25, label: "local"}, + {localPkgs: 10, depPkgs: 1000, label: "local-high"}, + {localPkgs: 10, depPkgs: 25, external: true, label: "external-low"}, + {localPkgs: 10, depPkgs: 100, external: true, label: "external-high"}, + } +} + +func importBenchAppProfile(label string) (appBenchProfile, error) { + for _, profile := range importBenchAppProfiles() { + if profile.label == label { + return profile, nil + } + } + return appBenchProfile{}, fmt.Errorf("%s=%q did not match any benchmark profile", importBenchProfile, label) +} + +func runAppColdTrials(t *testing.T, bin string, features, depPkgs int, external bool, wireModulePath, wireReplaceDir string, trials int) []time.Duration { + t.Helper() + durations := make([]time.Duration, 0, trials) + pkgDir := createAppShapeBenchFixture(t, features, depPkgs, external, wireModulePath, wireReplaceDir) + for i := 0; i < trials; i++ { + caches := newBenchCaches(t) + prewarmGoBenchCache(t, pkgDir, caches) + durations = append(durations, runWireBenchCommand(t, bin, pkgDir, caches)) + } + return durations +} + +func benchmarkCurrentWireAppScenario(b *testing.B, bin, repoRoot string, features, depPkgs int, external bool, variant string) { + b.Helper() + pkgDir := createAppShapeBenchFixture(b, features, depPkgs, external, currentWireModulePath, repoRoot) + caches := newBenchCaches(b) + root := filepath.Dir(pkgDir) + prewarmGoBenchCache(b, pkgDir, caches) + b.ResetTimer() + for i := 0; i < b.N; i++ { + b.StopTimer() + resetAppShapeBenchFixture(b, pkgDir, features) + switch variant { + case "body", "shape", "import": + _ = runWireBenchCommand(b, bin, pkgDir, caches) + writeAppShapeControllerFile(b, root, 0, variant) + case "known-toggle": + _ = runWireBenchCommand(b, bin, pkgDir, caches) + writeAppShapeControllerFile(b, root, 0, "shape") + _ = runWireBenchCommand(b, bin, pkgDir, caches) + writeAppShapeControllerFile(b, root, 0, "base") + case "unchanged": + _ = runWireBenchCommand(b, bin, pkgDir, caches) + default: + b.Fatalf("unknown benchmark variant %q", variant) + } + b.StartTimer() + _ = runWireBenchCommand(b, bin, pkgDir, caches) + } +} + +func runAppWarmTrials(t *testing.T, bin string, features, depPkgs int, external bool, wireModulePath, wireReplaceDir string, trials int) []time.Duration { + t.Helper() + durations := make([]time.Duration, 0, trials) + pkgDir := createAppShapeBenchFixture(t, features, depPkgs, external, wireModulePath, wireReplaceDir) + caches := newBenchCaches(t) + for i := 0; i < trials; i++ { + resetAppShapeBenchFixture(t, pkgDir, features) + prewarmGoBenchCache(t, pkgDir, caches) + _ = runWireBenchCommand(t, bin, pkgDir, caches) + durations = append(durations, runWireBenchCommand(t, bin, pkgDir, caches)) + } + return durations +} + +func runAppScenarioTrials(t *testing.T, bin string, features, depPkgs int, external bool, wireModulePath, wireReplaceDir, variant string, trials int) []time.Duration { + t.Helper() + durations := make([]time.Duration, 0, trials) + pkgDir := createAppShapeBenchFixture(t, features, depPkgs, external, wireModulePath, wireReplaceDir) + caches := newBenchCaches(t) + root := filepath.Dir(pkgDir) + for i := 0; i < trials; i++ { + resetAppShapeBenchFixture(t, pkgDir, features) + prewarmGoBenchCache(t, pkgDir, caches) + _ = runWireBenchCommand(t, bin, pkgDir, caches) + writeAppShapeControllerFile(t, root, 0, variant) + durations = append(durations, runWireBenchCommand(t, bin, pkgDir, caches)) + } + return durations +} + +func runAppKnownToggleTrials(t *testing.T, bin string, features, depPkgs int, external bool, wireModulePath, wireReplaceDir string, trials int) []time.Duration { + t.Helper() + durations := make([]time.Duration, 0, trials) + pkgDir := createAppShapeBenchFixture(t, features, depPkgs, external, wireModulePath, wireReplaceDir) + caches := newBenchCaches(t) + root := filepath.Dir(pkgDir) + for i := 0; i < trials; i++ { + resetAppShapeBenchFixture(t, pkgDir, features) + prewarmGoBenchCache(t, pkgDir, caches) + _ = runWireBenchCommand(t, bin, pkgDir, caches) + writeAppShapeControllerFile(t, root, 0, "shape") + _ = runWireBenchCommand(t, bin, pkgDir, caches) + writeAppShapeControllerFile(t, root, 0, "base") + durations = append(durations, runWireBenchCommand(t, bin, pkgDir, caches)) + } + return durations +} + +func buildWireBinary(t testing.TB, dir, name string) string { + t.Helper() + if runtime.GOOS == "windows" && filepath.Ext(name) != ".exe" { + name += ".exe" + } + out := filepath.Join(t.TempDir(), name) + cmd := exec.Command("go", "build", "-o", out, "./cmd/wire") + cmd.Dir = dir + cmd.Env = benchEnv(t.TempDir(), filepath.Join(t.TempDir(), "gocache")) + output, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("build wire binary in %s: %v\n%s", dir, err, output) + } + return out +} + +func newBenchCaches(t testing.TB) benchCaches { + t.Helper() + return benchCaches{ + home: t.TempDir(), + goCache: filepath.Join(t.TempDir(), "gocache"), + } +} + +func extractStockWire(t testing.TB, repoRoot, commit string) string { + t.Helper() + tmp := t.TempDir() + cmd := exec.Command("git", "archive", "--format=tar", commit) + cmd.Dir = repoRoot + stdout, err := cmd.StdoutPipe() + if err != nil { + t.Fatal(err) + } + if err := cmd.Start(); err != nil { + t.Fatalf("git archive start: %v", err) + } + tr := tar.NewReader(stdout) + for { + hdr, err := tr.Next() + if err == io.EOF { + break + } + if err != nil { + t.Fatalf("read stock tar: %v", err) + } + target := filepath.Join(tmp, hdr.Name) + switch hdr.Typeflag { + case tar.TypeDir: + if err := os.MkdirAll(target, os.FileMode(hdr.Mode)); err != nil { + t.Fatalf("mkdir %s: %v", target, err) + } + case tar.TypeReg: + if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil { + t.Fatalf("mkdir parent %s: %v", target, err) + } + f, err := os.OpenFile(target, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, os.FileMode(hdr.Mode)) + if err != nil { + t.Fatalf("create %s: %v", target, err) + } + if _, err := io.Copy(f, tr); err != nil { + _ = f.Close() + t.Fatalf("write %s: %v", target, err) + } + if err := f.Close(); err != nil { + t.Fatalf("close %s: %v", target, err) + } + } + } + if err := cmd.Wait(); err != nil { + t.Fatalf("git archive wait: %v", err) + } + return tmp +} + +func createImportBenchFixture(t testing.TB, imports int, wireModulePath, wireReplaceDir string) string { + t.Helper() + root := t.TempDir() + if err := os.WriteFile(filepath.Join(root, "go.mod"), []byte(importBenchGoMod(wireModulePath, wireReplaceDir)), 0o644); err != nil { + t.Fatal(err) + } + for i := 0; i < imports; i++ { + dir := filepath.Join(root, fmt.Sprintf("dep%04d", i)) + if err := os.MkdirAll(dir, 0o755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(dir, "dep.go"), []byte(importBenchDepFile(i, "base")), 0o644); err != nil { + t.Fatal(err) + } + } + if err := os.MkdirAll(filepath.Join(root, "app"), 0o755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(root, "app", "wire.go"), []byte(importBenchWireFile(imports, wireModulePath)), 0o644); err != nil { + t.Fatal(err) + } + return filepath.Join(root, "app") +} + +func createAppShapeBenchFixture(t testing.TB, features, depPkgs int, external bool, wireModulePath, wireReplaceDir string) string { + t.Helper() + root := t.TempDir() + modulePath := "example.com/appbench" + if err := os.WriteFile(filepath.Join(root, "go.mod"), []byte(appShapeGoMod(modulePath, wireModulePath, wireReplaceDir, external)), 0o644); err != nil { + t.Fatal(err) + } + if external { + seedAppShapeExternalGoSum(t, root) + } + for i := 0; i < depPkgs; i++ { + writeAppShapeFile(t, filepath.Join(root, "internal", fmt.Sprintf("dep%04d", i), "dep.go"), appShapeDepFile(i)) + } + writeAppShapeFile(t, filepath.Join(root, "internal", "logger", "logger.go"), appShapeLoggerFile(modulePath)) + writeAppShapeFile(t, filepath.Join(root, "internal", "cache", "cache.go"), appShapeCacheFile(modulePath)) + writeAppShapeFile(t, filepath.Join(root, "internal", "db", "db.go"), appShapeDBFile(modulePath)) + writeAppShapeFile(t, filepath.Join(root, "internal", "config", "config.go"), appShapeConfigFile(modulePath)) + writeAppShapeFile(t, filepath.Join(root, "internal", "metrics", "metrics.go"), appShapeMetricsFile(modulePath)) + writeAppShapeFile(t, filepath.Join(root, "internal", "httpx", "httpx.go"), appShapeHTTPXFile(modulePath)) + if external { + writeAppShapeFile(t, filepath.Join(root, "internal", "extsink", "extsink.go"), appShapeExtSinkFile(modulePath)) + } + writeAppShapeFile(t, filepath.Join(root, "wire", "app.go"), appShapeAppFile(modulePath, features)) + writeAppShapeFile(t, filepath.Join(root, "wire", "wire.go"), appShapeWireFile(modulePath, wireModulePath, features, external)) + for i := 0; i < features; i++ { + writeAppShapeFile(t, filepath.Join(root, "internal", fmt.Sprintf("feature%04d", i), "feature.go"), appShapeFeatureFile(modulePath, wireModulePath, i, depPkgs, external)) + writeAppShapeControllerFile(t, root, i, "base") + } + return filepath.Join(root, "wire") +} + +func writeAppShapeFile(t testing.TB, path, content string) { + t.Helper() + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(path, []byte(content), 0o644); err != nil { + t.Fatal(err) + } +} + +func writeAppShapeControllerFile(t testing.TB, root string, index int, variant string) { + t.Helper() + path := filepath.Join(root, "internal", fmt.Sprintf("feature%04d", index), "controller.go") + if err := os.WriteFile(path, []byte(appShapeControllerFile("example.com/appbench", index, variant)), 0o644); err != nil { + t.Fatal(err) + } +} + +func seedAppShapeExternalGoSum(t testing.TB, root string) { + t.Helper() + const source = "/private/tmp/test/go.sum" + data, err := os.ReadFile(source) + if err != nil { + return + } + if err := os.WriteFile(filepath.Join(root, "go.sum"), data, 0o644); err != nil { + t.Fatalf("write seeded go.sum: %v", err) + } +} + +func resetAppShapeBenchFixture(t testing.TB, pkgDir string, features int) { + t.Helper() + root := filepath.Dir(pkgDir) + for i := 0; i < features; i++ { + writeAppShapeControllerFile(t, root, i, "base") + } +} + +func appShapeGoMod(modulePath, wireModulePath, wireReplaceDir string, external bool) string { + extraRequires := "" + if external { + extraRequires = ` + github.com/alecthomas/kong v1.14.0 + github.com/charmbracelet/lipgloss v1.1.0 + github.com/fsnotify/fsnotify v1.7.0 + github.com/glebarez/sqlite v1.11.0 + github.com/goforj/cache v0.1.5 + github.com/goforj/crypt v1.1.0 + github.com/goforj/env/v2 v2.3.0 + github.com/goforj/httpx v1.1.0 + github.com/goforj/null/v6 v6.0.2 + github.com/goforj/queue v0.1.5 + github.com/goforj/queue/driver/redisqueue v0.1.5 + github.com/goforj/scheduler v1.4.0 + github.com/goforj/storage v0.2.5 + github.com/goforj/storage/driver/localstorage v0.2.5 + github.com/goforj/storage/driver/redisstorage v0.2.5 + github.com/goforj/str v1.3.0 + github.com/google/go-cmp v0.6.0 + github.com/google/subcommands v1.2.0 + github.com/google/uuid v1.6.0 + github.com/gorilla/websocket v1.5.3 + github.com/hibiken/asynq v0.26.0 + github.com/imroc/req/v3 v3.57.0 + github.com/labstack/echo/v4 v4.15.1 + github.com/pmezard/go-difflib v1.0.0 + github.com/redis/go-redis/v9 v9.17.2 + github.com/rs/zerolog v1.34.0 + github.com/shirou/gopsutil/v4 v4.26.2 + golang.org/x/mod v0.33.0 + golang.org/x/net v0.50.0 + golang.org/x/sync v0.19.0 + golang.org/x/sys v0.41.0 + golang.org/x/term v0.40.0 + golang.org/x/tools v0.42.0 + gopkg.in/yaml.v3 v3.0.1 + gorm.io/driver/mysql v1.6.0 + gorm.io/driver/postgres v1.6.0 + gorm.io/gorm v1.31.1` + } + return fmt.Sprintf(`module %s + +go 1.19 + +require ( + %s v0.0.0%s +) + +replace %s => %s +`, modulePath, wireModulePath, extraRequires, wireModulePath, wireReplaceDir) +} + +func appShapeLoggerFile(modulePath string) string { + return `package logger + +import ( + "context" + "encoding/json" + "io" + "os" + "sync" + "time" +) + +type Logger struct { + sink io.Writer + mu sync.Mutex +} + +func NewLogger() *Logger { return &Logger{sink: os.Stdout} } + +func (l *Logger) Log(ctx context.Context, msg string, attrs map[string]string) { + l.mu.Lock() + defer l.mu.Unlock() + _, _ = json.Marshal(map[string]any{ + "ctx": ctx != nil, + "msg": msg, + "attrs": attrs, + "time": time.Now().UTC().Format(time.RFC3339Nano), + }) +} +` +} + +func appShapeCacheFile(modulePath string) string { + return `package cache + +type Manager struct{} + +func NewManager() *Manager { return &Manager{} } +` +} + +func appShapeDBFile(modulePath string) string { + return `package db + +import ( + "context" + "database/sql" + "net/url" + "path/filepath" +) + +type DB struct { + driver string + dsn string +} + +func NewDB() *DB { + _ = filepath.Join("var", "lib", "appbench") + _ = sql.LevelDefault + u := &url.URL{Scheme: "postgres", Host: "localhost", Path: "/appbench"} + return &DB{driver: "postgres", dsn: u.String()} +} + +func (db *DB) PingContext(context.Context) error { return nil } +` +} + +func appShapeDepFile(index int) string { + return fmt.Sprintf(`package dep%04d + +import ( + "crypto/sha256" + "encoding/hex" + "fmt" + "path/filepath" + "strings" +) + +type Value struct { + Name string +} + +func Provide() Value { + sum := sha256.Sum256([]byte(fmt.Sprintf("dep-%%04d", %d))) + return Value{ + Name: filepath.Join("deps", strings.ToLower(hex.EncodeToString(sum[:])))[:16], + } +} +`, index, index) +} + +func appShapeConfigFile(modulePath string) string { + return `package config + +import ( + "encoding/json" + "os" + "strconv" +) + +type Config struct { + Port int + Service string +} + +func NewConfig() *Config { + cfg := &Config{Port: 8080, Service: "appbench"} + if v := os.Getenv("APPBENCH_PORT"); v != "" { + if port, err := strconv.Atoi(v); err == nil { + cfg.Port = port + } + } + _, _ = json.Marshal(cfg) + return cfg +} +` +} + +func appShapeMetricsFile(modulePath string) string { + return `package metrics + +import ( + "expvar" + "fmt" + "sync/atomic" +) + +type Metrics struct { + requests atomic.Int64 + name string +} + +func NewMetrics() *Metrics { + expvar.NewString("appbench_name").Set("appbench") + return &Metrics{name: fmt.Sprintf("appbench_%s", "requests")} +} +` +} + +func appShapeHTTPXFile(modulePath string) string { + return `package httpx + +import ( + "context" + "net/http" + "net/http/httptest" +) + +type Client struct { + client *http.Client +} + +func NewClient() *Client { + req := httptest.NewRequest(http.MethodGet, "http://example.com", nil) + _ = req.WithContext(context.Background()) + return &Client{client: &http.Client{}} +} +` +} + +func appShapeExtSinkFile(modulePath string) string { + return `package extsink + +import ( + "context" + "fmt" + "os" + + _ "github.com/alecthomas/kong" + _ "github.com/charmbracelet/lipgloss" + _ "github.com/charmbracelet/lipgloss/table" + "github.com/fsnotify/fsnotify" + _ "github.com/glebarez/sqlite" + _ "github.com/goforj/cache" + _ "github.com/goforj/crypt" + _ "github.com/goforj/env/v2" + _ "github.com/goforj/httpx" + _ "github.com/goforj/null/v6" + _ "github.com/goforj/queue" + _ "github.com/goforj/queue/driver/redisqueue" + _ "github.com/goforj/scheduler" + _ "github.com/goforj/storage" + _ "github.com/goforj/storage/driver/localstorage" + _ "github.com/goforj/storage/driver/redisstorage" + _ "github.com/goforj/str" + "github.com/google/go-cmp/cmp" + "github.com/google/subcommands" + _ "github.com/google/uuid" + _ "github.com/gorilla/websocket" + _ "github.com/hibiken/asynq" + _ "github.com/imroc/req/v3" + _ "github.com/labstack/echo/v4" + _ "github.com/labstack/echo/v4/middleware" + "github.com/pmezard/go-difflib/difflib" + _ "github.com/redis/go-redis/v9" + _ "github.com/rs/zerolog" + _ "github.com/shirou/gopsutil/v4/cpu" + _ "github.com/shirou/gopsutil/v4/disk" + _ "github.com/shirou/gopsutil/v4/host" + _ "github.com/shirou/gopsutil/v4/mem" + _ "github.com/shirou/gopsutil/v4/net" + _ "github.com/shirou/gopsutil/v4/process" + "golang.org/x/mod/modfile" + _ "golang.org/x/net/http2" + _ "golang.org/x/net/http2/h2c" + "golang.org/x/sync/errgroup" + "golang.org/x/sys/unix" + _ "golang.org/x/term" + "golang.org/x/tools/go/packages" + _ "gopkg.in/yaml.v3" + _ "gorm.io/driver/mysql" + _ "gorm.io/driver/postgres" + _ "gorm.io/gorm" +) + +type Sink struct { + label string +} + +func NewSink() *Sink { + _ = cmp.Equal("a", "b") + _ = difflib.UnifiedDiff{} + _, _ = modfile.Parse("go.mod", []byte("module example.com/appbench"), nil) + _, _ = packages.Load(&packages.Config{Mode: packages.NeedName}, "fmt") + var g errgroup.Group + g.Go(func() error { return nil }) + _ = unix.Getpid() + _ = fsnotify.Event{Name: os.TempDir()} + _ = subcommands.ExitSuccess + return &Sink{label: fmt.Sprintf("sink:%v", context.Background() != nil)} +} +` +} + +func appShapeFeatureFile(modulePath, wireModulePath string, index, depPkgs int, external bool) string { + pkg := fmt.Sprintf("feature%04d", index) + var depImports strings.Builder + var depUse strings.Builder + for i := 0; i < depPkgs; i++ { + depImports.WriteString(fmt.Sprintf("\tdep%04d %q\n", i, fmt.Sprintf("%s/internal/dep%04d", modulePath, i))) + depUse.WriteString(fmt.Sprintf("\t_ = dep%04d.Provide()\n", i)) + } + externalImport := "" + externalArg := "" + externalField := "" + externalUse := "" + if external { + externalImport = fmt.Sprintf("\t%q\n", modulePath+"/internal/extsink") + externalArg = ", sink *extsink.Sink" + externalField = "\tsink *extsink.Sink\n" + externalUse = "\t_ = sink\n" + } + return fmt.Sprintf(`package %s + +import ( + "context" + "encoding/json" + "fmt" + "net/url" + "strconv" + "time" + wire %q + %q + %q + %q + %q + %q +%s +%s +) + +type Repo struct { + db *db.DB + config *config.Config + metrics *metrics.Metrics +%s} + +type Service struct { + repo *Repo + logger *logger.Logger + client *httpx.Client +} + +func NewRepo(dbConn *db.DB, cfg *config.Config, m *metrics.Metrics, l *logger.Logger%s) *Repo { + _, _ = json.Marshal(map[string]any{"feature": %d, "service": cfg.Service}) + l.Log(context.Background(), "repo.init", map[string]string{"feature": strconv.Itoa(%d)}) +%s return &Repo{db: dbConn, config: cfg, metrics: m} +} + +func NewService(repo *Repo, l *logger.Logger, client *httpx.Client) *Service { + _, _ = url.Parse(fmt.Sprintf("https://example.com/%%04d", %d)) + _ = time.Second + return &Service{repo: repo, logger: l, client: client} +} + +var Set = wire.NewSet(NewRepo, NewService, NewController) +`, pkg, wireModulePath, modulePath+"/internal/config", modulePath+"/internal/db", modulePath+"/internal/httpx", modulePath+"/internal/logger", modulePath+"/internal/metrics", depImports.String(), externalImport, externalField, externalArg, index, index, depUse.String()+externalUse, index) +} + +func appShapeControllerFile(modulePath string, index int, variant string) string { + pkg := fmt.Sprintf("feature%04d", index) + imports := []string{ + `"context"`, + `"fmt"`, + `"net/http"`, + `"strconv"`, + `"` + modulePath + `/internal/logger"`, + } + if variant == "shape" { + imports = append(imports, `"`+modulePath+`/internal/db"`) + } + if variant == "import" { + imports = append(imports, `"strings"`) + } + bodyLine := "" + switch variant { + case "body": + bodyLine = "\t_ = \"body-edit\"\n" + case "import": + bodyLine = "\t_ = strings.TrimSpace(\" import-edit \")\n" + } + extraField := "" + extraArg := "" + extraInit := "" + if variant == "shape" { + extraField = "\tdb *db.DB\n" + extraArg = ", d *db.DB" + extraInit = "\t\tdb: d,\n" + } + return fmt.Sprintf(`package %s + +import ( + %s +) + +type Controller struct { + logger *logger.Logger + service *Service +%s} + +func NewController(l *logger.Logger, s *Service%s) *Controller { +%s l.Log(context.Background(), "controller.init", map[string]string{"feature": strconv.Itoa(%d)}) + _ = http.MethodGet + _ = fmt.Sprintf("feature-%%d", %d) + return &Controller{ + logger: l, + service: s, +%s } +} +`, pkg, strings.Join(imports, "\n\t"), extraField, extraArg, bodyLine, index, index, extraInit) +} + +func appShapeAppFile(modulePath string, features int) string { + var b strings.Builder + b.WriteString("package wire\n\n") + if features > 0 { + b.WriteString("import (\n") + for i := 0; i < features; i++ { + b.WriteString(fmt.Sprintf("\tfeature%04d %q\n", i, fmt.Sprintf("%s/internal/feature%04d", modulePath, i))) + } + b.WriteString(")\n\n") + } + b.WriteString("type App struct{}\n\n") + b.WriteString("func NewApp(") + for i := 0; i < features; i++ { + if i > 0 { + b.WriteString(", ") + } + b.WriteString(fmt.Sprintf("_ *feature%04d.Controller", i)) + } + b.WriteString(") *App {\n\treturn &App{}\n}\n") + return b.String() +} + +func appShapeWireFile(modulePath, wireModulePath string, features int, external bool) string { + var b strings.Builder + b.WriteString("//go:build wireinject\n\n") + b.WriteString("package wire\n\n") + b.WriteString("import (\n") + b.WriteString(fmt.Sprintf("\twire %q\n", wireModulePath)) + b.WriteString(fmt.Sprintf("\t%q\n", modulePath+"/internal/config")) + b.WriteString(fmt.Sprintf("\t%q\n", modulePath+"/internal/db")) + if external { + b.WriteString(fmt.Sprintf("\t%q\n", modulePath+"/internal/extsink")) + } + b.WriteString(fmt.Sprintf("\t%q\n", modulePath+"/internal/httpx")) + b.WriteString(fmt.Sprintf("\t%q\n", modulePath+"/internal/logger")) + b.WriteString(fmt.Sprintf("\t%q\n", modulePath+"/internal/metrics")) + for i := 0; i < features; i++ { + b.WriteString(fmt.Sprintf("\t%q\n", fmt.Sprintf("%s/internal/feature%04d", modulePath, i))) + } + b.WriteString(")\n\n") + b.WriteString("func Initialize() *App {\n\twire.Build(\n") + b.WriteString("\t\tconfig.NewConfig,\n") + b.WriteString("\t\tlogger.NewLogger,\n") + b.WriteString("\t\tdb.NewDB,\n") + if external { + b.WriteString("\t\textsink.NewSink,\n") + } + b.WriteString("\t\thttpx.NewClient,\n") + b.WriteString("\t\tmetrics.NewMetrics,\n") + for i := 0; i < features; i++ { + b.WriteString(fmt.Sprintf("\t\tfeature%04d.Set,\n", i)) + } + b.WriteString("\t\tNewApp,\n\t)\n\treturn nil\n}\n") + return b.String() +} + +func runBodyEditTrials(t *testing.T, bin string, imports int, wireModulePath, wireReplaceDir string, trials int) []time.Duration { + t.Helper() + durations := make([]time.Duration, 0, trials) + caches := newBenchCaches(t) + for i := 0; i < trials; i++ { + pkgDir := createImportBenchFixture(t, imports, wireModulePath, wireReplaceDir) + prewarmGoBenchCache(t, pkgDir, caches) + root := filepath.Dir(pkgDir) + _ = runWireBenchCommand(t, bin, pkgDir, caches) + writeImportBenchDepFile(t, root, 0, "body") + durations = append(durations, runWireBenchCommand(t, bin, pkgDir, caches)) + } + return durations +} + +func runShapeEditTrials(t *testing.T, bin string, imports int, wireModulePath, wireReplaceDir string, trials int) []time.Duration { + t.Helper() + durations := make([]time.Duration, 0, trials) + caches := newBenchCaches(t) + for i := 0; i < trials; i++ { + pkgDir := createImportBenchFixture(t, imports, wireModulePath, wireReplaceDir) + prewarmGoBenchCache(t, pkgDir, caches) + root := filepath.Dir(pkgDir) + _ = runWireBenchCommand(t, bin, pkgDir, caches) + writeImportBenchDepFile(t, root, 0, "shape") + durations = append(durations, runWireBenchCommand(t, bin, pkgDir, caches)) + } + return durations +} + +func runImportChangeTrials(t *testing.T, bin string, imports int, wireModulePath, wireReplaceDir string, trials int) []time.Duration { + t.Helper() + durations := make([]time.Duration, 0, trials) + caches := newBenchCaches(t) + for i := 0; i < trials; i++ { + pkgDir := createImportBenchFixture(t, imports+1, wireModulePath, wireReplaceDir) + prewarmGoBenchCache(t, pkgDir, caches) + root := filepath.Dir(pkgDir) + writeImportBenchWireFile(t, root, imports, wireModulePath) + _ = runWireBenchCommand(t, bin, pkgDir, caches) + writeImportBenchWireFile(t, root, imports+1, wireModulePath) + durations = append(durations, runWireBenchCommand(t, bin, pkgDir, caches)) + } + return durations +} + +func runKnownImportToggleTrials(t *testing.T, bin string, imports int, wireModulePath, wireReplaceDir string, trials int) []time.Duration { + t.Helper() + durations := make([]time.Duration, 0, trials) + caches := newBenchCaches(t) + for i := 0; i < trials; i++ { + pkgDir := createImportBenchFixture(t, imports+1, wireModulePath, wireReplaceDir) + prewarmGoBenchCache(t, pkgDir, caches) + root := filepath.Dir(pkgDir) + writeImportBenchWireFile(t, root, imports, wireModulePath) + _ = runWireBenchCommand(t, bin, pkgDir, caches) + writeImportBenchWireFile(t, root, imports+1, wireModulePath) + _ = runWireBenchCommand(t, bin, pkgDir, caches) + writeImportBenchWireFile(t, root, imports, wireModulePath) + durations = append(durations, runWireBenchCommand(t, bin, pkgDir, caches)) + } + return durations +} + +func runWireBenchCommand(t testing.TB, bin, pkgDir string, caches benchCaches) time.Duration { + t.Helper() + d, _ := runWireBenchCommandOutput(t, bin, pkgDir, caches) + return d +} + +func runWireBenchCommandOutput(t testing.TB, bin, pkgDir string, caches benchCaches, extraArgs ...string) (time.Duration, string) { + t.Helper() + args := []string{"gen"} + args = append(args, extraArgs...) + cmd := exec.Command(bin, args...) + cmd.Dir = pkgDir + cmd.Env = append(benchEnv(caches.home, caches.goCache), "WIRE_LOADER_ARTIFACTS=1") + var stderr bytes.Buffer + cmd.Stdout = io.Discard + cmd.Stderr = &stderr + start := time.Now() + if err := cmd.Run(); err != nil { + t.Fatalf("run %s in %s: %v\n%s", bin, pkgDir, err, stderr.String()) + } + return time.Since(start), stderr.String() +} + +func prewarmGoBenchCache(t testing.TB, pkgDir string, caches benchCaches) { + t.Helper() + prepareBenchModule(t, pkgDir, caches) + cmd := exec.Command("go", "list", "-deps", "./...") + cmd.Dir = pkgDir + cmd.Env = benchEnv(caches.home, caches.goCache) + output, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("prewarm go cache in %s: %v\n%s", pkgDir, err, output) + } +} + +func goListGraphCounts(t testing.TB, pkgDir, modulePath string, caches benchCaches) benchGraphCounts { + t.Helper() + prepareBenchModule(t, pkgDir, caches) + cmd := exec.Command("go", "list", "-deps", "-json", "./...") + cmd.Dir = pkgDir + cmd.Env = benchEnv(caches.home, caches.goCache) + output, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("go list graph counts in %s: %v\n%s", pkgDir, err, output) + } + dec := json.NewDecoder(bytes.NewReader(output)) + seen := make(map[string]struct{}) + var counts benchGraphCounts + for { + var pkg struct { + ImportPath string + Standard bool + } + if err := dec.Decode(&pkg); err != nil { + if err == io.EOF { + break + } + t.Fatalf("decode graph counts for %s: %v", pkgDir, err) + } + if pkg.ImportPath == "" { + continue + } + if _, ok := seen[pkg.ImportPath]; ok { + continue + } + seen[pkg.ImportPath] = struct{}{} + switch { + case pkg.Standard: + counts.stdlib++ + case pkg.ImportPath == modulePath || strings.HasPrefix(pkg.ImportPath, modulePath+"/"): + counts.local++ + default: + counts.external++ + } + } + return counts +} + +func prepareBenchModule(t testing.TB, pkgDir string, caches benchCaches) { + t.Helper() + marker := filepath.Join(filepath.Dir(pkgDir), ".bench-module-ready") + if _, err := os.Stat(marker); err == nil { + return + } + cmd := exec.Command("go", "mod", "tidy") + cmd.Dir = filepath.Dir(pkgDir) + cmd.Env = benchEnv(caches.home, caches.goCache) + output, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("prepare bench module in %s: %v\n%s", pkgDir, err, output) + } + if err := os.WriteFile(marker, []byte("ok\n"), 0o644); err != nil { + t.Fatalf("write module marker %s: %v", marker, err) + } +} + +func runColdTrials(t *testing.T, bin, pkgDir string, caches benchCaches, trials int) []time.Duration { + t.Helper() + durations := make([]time.Duration, 0, trials) + for i := 0; i < trials; i++ { + prewarmGoBenchCache(t, pkgDir, caches) + durations = append(durations, runWireBenchCommand(t, bin, pkgDir, caches)) + } + return durations +} + +func runWarmTrials(t *testing.T, bin, pkgDir string, caches benchCaches, trials int) []time.Duration { + t.Helper() + durations := make([]time.Duration, 0, trials) + for i := 0; i < trials; i++ { + prewarmGoBenchCache(t, pkgDir, caches) + _ = runWireBenchCommand(t, bin, pkgDir, caches) + durations = append(durations, runWireBenchCommand(t, bin, pkgDir, caches)) + } + return durations +} + +func medianDuration(durations []time.Duration) time.Duration { + if len(durations) == 0 { + return 0 + } + sorted := append([]time.Duration(nil), durations...) + for i := 1; i < len(sorted); i++ { + for j := i; j > 0 && sorted[j] < sorted[j-1]; j-- { + sorted[j], sorted[j-1] = sorted[j-1], sorted[j] + } + } + return sorted[len(sorted)/2] +} + +func benchEnv(home, goCache string) []string { + env := append([]string(nil), os.Environ()...) + env = append(env, + "HOME="+home, + "GOCACHE="+goCache, + "GOMODCACHE="+benchModCache(), + "GOSUMDB=off", + ) + return env +} + +func benchModCache() string { + if path := os.Getenv("GOMODCACHE"); path != "" { + return path + } + return filepath.Join(os.TempDir(), "gomodcache") +} + +func importBenchGoMod(wireModulePath, wireReplaceDir string) string { + return fmt.Sprintf(`module example.com/importbench + +go 1.19 + +require %s v0.0.0 + +replace %s => %s +`, wireModulePath, wireModulePath, wireReplaceDir) +} + +func importBenchWireFile(imports int, wireModulePath string) string { + var b strings.Builder + b.WriteString("//go:build wireinject\n\n") + b.WriteString("package app\n\n") + b.WriteString("import (\n") + b.WriteString(fmt.Sprintf("\twire %q\n", wireModulePath)) + for i := 0; i < imports; i++ { + b.WriteString(fmt.Sprintf("\t%[1]q\n", fmt.Sprintf("example.com/importbench/dep%04d", i))) + } + b.WriteString(")\n\n") + b.WriteString("type App struct{}\n\n") + b.WriteString("func provideApp(") + for i := 0; i < imports; i++ { + if i > 0 { + b.WriteString(", ") + } + b.WriteString(fmt.Sprintf("d%d *dep%04d.T", i, i)) + } + b.WriteString(") *App {\n\treturn &App{}\n}\n\n") + b.WriteString("func Initialize() *App {\n\twire.Build(wire.NewSet(\n") + for i := 0; i < imports; i++ { + b.WriteString(fmt.Sprintf("\t\tdep%04d.Provide,\n", i)) + } + b.WriteString("\t\tprovideApp,\n\t))\n\treturn nil\n}\n") + return b.String() +} + +func importBenchDepFile(i int, variant string) string { + switch variant { + case "body": + return fmt.Sprintf("package dep%04d\n\ntype T struct{}\n\nfunc Provide() *T {\n\t_ = \"body-edit\"\n\treturn &T{}\n}\n", i) + case "shape": + return fmt.Sprintf("package dep%04d\n\ntype T struct{ Extra int }\n\nfunc Provide() *T { return &T{} }\n", i) + default: + return fmt.Sprintf("package dep%04d\n\ntype T struct{}\n\nfunc Provide() *T { return &T{} }\n", i) + } +} + +func writeImportBenchWireFile(t testing.TB, root string, imports int, wireModulePath string) { + t.Helper() + path := filepath.Join(root, "app", "wire.go") + if err := os.WriteFile(path, []byte(importBenchWireFile(imports, wireModulePath)), 0o644); err != nil { + t.Fatal(err) + } +} + +func writeImportBenchDepFile(t testing.TB, root string, index int, variant string) { + t.Helper() + path := filepath.Join(root, fmt.Sprintf("dep%04d", index), "dep.go") + if err := os.WriteFile(path, []byte(importBenchDepFile(index, variant)), 0o644); err != nil { + t.Fatal(err) + } +} + +func printImportBenchTable(t *testing.T, rows []importBenchRow) { + t.Helper() + fmt.Println("+-----------+-----------+--------------+-------------------+--------------+-------------------+") + fmt.Println("| repo size | stock | current cold | current unchanged | cold speedup | unchanged speedup |") + fmt.Println("+-----------+-----------+--------------+-------------------+--------------+-------------------+") + for _, row := range rows { + fmt.Printf("| %-9d | %-9s | %-12s | %-17s | %-12s | %-17s |\n", + row.imports, + formatMs(row.stockCold), + formatMs(row.currentCold), + formatMs(row.currentWarm), + formatSpeedup(row.stockCold, row.currentCold), + formatSpeedup(row.stockCold, row.currentWarm), + ) + } + fmt.Println("+-----------+-----------+--------------+-------------------+--------------+-------------------+") +} + +func printImportScenarioBenchTable(t *testing.T, rows []importBenchScenarioRow) { + t.Helper() + profileWidth := len("profile") + localWidth := len("local") + stdlibWidth := len("stdlib") + externalWidth := len("external") + changeTypeWidth := len("change type") + stockWidth := len("stock") + currentWidth := len("current") + speedupWidth := len("speedup") + for _, row := range rows { + profileWidth = maxInt(profileWidth, len(row.profile)) + localWidth = maxInt(localWidth, len(fmt.Sprintf("%d", row.localCount))) + stdlibWidth = maxInt(stdlibWidth, len(fmt.Sprintf("%d", row.stdlibCount))) + externalWidth = maxInt(externalWidth, len(fmt.Sprintf("%d", row.externalCount))) + changeTypeWidth = maxInt(changeTypeWidth, len(row.name)) + stockWidth = maxInt(stockWidth, len(formatMs(row.stock))) + currentWidth = maxInt(currentWidth, len(formatMs(row.current))) + speedupWidth = maxInt(speedupWidth, len(formatSpeedup(row.stock, row.current))) + } + sep := fmt.Sprintf("+-%s-+-%s-+-%s-+-%s-+-%s-+-%s-+-%s-+-%s-+", + strings.Repeat("-", profileWidth), + strings.Repeat("-", localWidth), + strings.Repeat("-", stdlibWidth), + strings.Repeat("-", externalWidth), + strings.Repeat("-", changeTypeWidth), + strings.Repeat("-", stockWidth), + strings.Repeat("-", currentWidth), + strings.Repeat("-", speedupWidth), + ) + fmt.Println(sep) + fmt.Printf("| %*s | %-*s | %-*s | %-*s | %-*s | %-*s | %-*s | %-*s |\n", + profileWidth, "profile", + localWidth, "local", + stdlibWidth, "stdlib", + externalWidth, "external", + changeTypeWidth, "change type", + stockWidth, "stock", + currentWidth, "current", + speedupWidth, "speedup", + ) + fmt.Println(sep) + for _, row := range rows { + fmt.Printf("| %*s | %-*d | %-*d | %-*d | %-*s | %-*s | %-*s | %-*s |\n", + profileWidth, row.profile, + localWidth, row.localCount, + stdlibWidth, row.stdlibCount, + externalWidth, row.externalCount, + changeTypeWidth, row.name, + stockWidth, formatMs(row.stock), + currentWidth, formatMs(row.current), + speedupWidth, formatSpeedup(row.stock, row.current), + ) + } + fmt.Println(sep) + fmt.Println() + fmt.Println("change types:") + fmt.Println(" cold run: first wire gen on a fresh Wire cache for that repo shape") + fmt.Println(" unchanged rerun: run wire gen again without changing any files") + fmt.Println(" body-only local edit: change local function body/content without changing imports, types, or constructor signatures") + fmt.Println(" shape change: change local provider/type shape such as constructor params, fields, or return shape") + fmt.Println(" import change: add or remove a local import, which can change discovered package shape") + fmt.Println(" known import toggle: switch back to a previously seen import/shape state in the same repo") +} + +func maxInt(a, b int) int { + if a > b { + return a + } + return b +} + +func printScenarioTimingLines(output string) { + for _, line := range strings.Split(output, "\n") { + if !strings.Contains(line, "wire: timing:") { + continue + } + if strings.Contains(line, "loader.custom.root.discovery=") || + strings.Contains(line, "loader.discovery.") || + strings.Contains(line, "load.packages.load=") || + strings.Contains(line, "load.debug") || + strings.Contains(line, "loader.custom.typed.artifact_read=") || + strings.Contains(line, "loader.custom.typed.artifact_decode=") || + strings.Contains(line, "loader.custom.typed.artifact_import_link=") || + strings.Contains(line, "loader.custom.typed.artifact_write=") || + strings.Contains(line, "loader.custom.typed.root_load.wall=") || + strings.Contains(line, "loader.custom.typed.discovery.wall=") || + strings.Contains(line, "loader.custom.typed.artifact_hits=") || + strings.Contains(line, "loader.custom.typed.artifact_misses=") || + strings.Contains(line, "loader.custom.typed.artifact_writes=") || + strings.Contains(line, "generate.package.") || + strings.Contains(line, "wire.Generate=") || + strings.Contains(line, "total=") { + fmt.Println(line) + } + } +} + +func formatMs(d time.Duration) string { + return fmt.Sprintf("%.1fms", float64(d)/float64(time.Millisecond)) +} + +func formatSpeedup(oldDur, newDur time.Duration) string { + if newDur == 0 { + return "inf" + } + return fmt.Sprintf("%.2fx", float64(oldDur)/float64(newDur)) +} + +func TestImportBenchFixtureGenerates(t *testing.T) { + repoRoot, err := importBenchRepoRoot() + if err != nil { + t.Fatal(err) + } + bin := buildWireBinary(t, repoRoot, "fixture-wire") + fixture := createImportBenchFixture(t, 10, currentWireModulePath, repoRoot) + caches := newBenchCaches(t) + prewarmGoBenchCache(t, fixture, caches) + _ = runWireBenchCommand(t, bin, fixture, caches) +} + +func TestImportBenchUsesStockArchive(t *testing.T) { + repoRoot, err := importBenchRepoRoot() + if err != nil { + t.Fatal(err) + } + check := exec.Command("git", "cat-file", "-e", stockWireCommit+"^{commit}") + check.Dir = repoRoot + if err := check.Run(); err != nil { + t.Skipf("stock archive commit %s not available in checkout", stockWireCommit) + } + stockDir := extractStockWire(t, repoRoot, stockWireCommit) + if _, err := os.Stat(filepath.Join(stockDir, "cmd", "wire", "main.go")); err != nil { + t.Fatalf("stock archive missing cmd/wire: %v", err) + } + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + cmd := exec.CommandContext(ctx, "go", "list", "./cmd/wire") + cmd.Dir = stockDir + cmd.Env = benchEnv(t.TempDir(), filepath.Join(t.TempDir(), "gocache")) + if out, err := cmd.CombinedOutput(); err != nil { + t.Fatalf("stock archive not buildable: %v\n%s", err, out) + } +} + +func importBenchRepoRoot() (string, error) { + wd, err := os.Getwd() + if err != nil { + return "", err + } + return filepath.Clean(filepath.Join(wd, "..", "..")), nil +} diff --git a/internal/wire/load_debug.go b/internal/wire/load_debug.go new file mode 100644 index 0000000..d3d5fc1 --- /dev/null +++ b/internal/wire/load_debug.go @@ -0,0 +1,329 @@ +// Copyright 2026 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package wire + +import ( + "context" + "fmt" + "os" + "path/filepath" + "sort" + "strings" + "sync" + "time" + + "golang.org/x/tools/go/packages" +) + +type parseFileStats struct { + mu sync.Mutex + calls int + primaryCalls int + depCalls int + cacheHits int + cacheMisses int + errors int + total time.Duration +} + +func (ps *parseFileStats) record(primary bool, dur time.Duration, err error, cacheHit bool) { + ps.mu.Lock() + defer ps.mu.Unlock() + ps.calls++ + if primary { + ps.primaryCalls++ + } else { + ps.depCalls++ + } + if cacheHit { + ps.cacheHits++ + } else { + ps.cacheMisses++ + } + ps.total += dur + if err != nil { + ps.errors++ + } +} + +func (ps *parseFileStats) snapshot() parseFileStats { + ps.mu.Lock() + defer ps.mu.Unlock() + return parseFileStats{ + calls: ps.calls, + primaryCalls: ps.primaryCalls, + depCalls: ps.depCalls, + cacheHits: ps.cacheHits, + cacheMisses: ps.cacheMisses, + errors: ps.errors, + total: ps.total, + } +} + +type loadScopeStats struct { + roots int + totalPackages int + compiledFiles int + syntaxFiles int + packagesWithSyntax int + packagesWithTypes int + packagesWithTypesInfo int + localPackages int + localSyntaxPackages int + externalPackages int + externalSyntaxPkgs int + unknownPackages int + topCompiled []string + topSyntax []string +} + +type packageMetric struct { + path string + count int +} + +func logLoadDebug(ctx context.Context, scope string, mode packages.LoadMode, subject string, wd string, pkgs []*packages.Package, parseStats *parseFileStats) { + if timing(ctx) == nil { + return + } + stats := summarizeLoadScope(wd, pkgs) + debugf(ctx, "load.debug scope=%s subject=%s mode=%s roots=%d total_pkgs=%d compiled_files=%d syntax_files=%d syntax_pkgs=%d typed_pkgs=%d types_info_pkgs=%d local_pkgs=%d local_syntax_pkgs=%d external_pkgs=%d external_syntax_pkgs=%d unknown_pkgs=%d", + scope, + subject, + formatLoadMode(mode), + stats.roots, + stats.totalPackages, + stats.compiledFiles, + stats.syntaxFiles, + stats.packagesWithSyntax, + stats.packagesWithTypes, + stats.packagesWithTypesInfo, + stats.localPackages, + stats.localSyntaxPackages, + stats.externalPackages, + stats.externalSyntaxPkgs, + stats.unknownPackages, + ) + if len(stats.topCompiled) > 0 { + debugf(ctx, "load.debug scope=%s top_compiled_files=%s", scope, strings.Join(stats.topCompiled, ", ")) + } + if len(stats.topSyntax) > 0 { + debugf(ctx, "load.debug scope=%s top_syntax_files=%s", scope, strings.Join(stats.topSyntax, ", ")) + } + if parseStats != nil { + snap := parseStats.snapshot() + debugf(ctx, "load.debug scope=%s parse.calls=%d parse.primary=%d parse.deps=%d parse.cache_hits=%d parse.cache_misses=%d parse.errors=%d parse.cumulative=%s", + scope, + snap.calls, + snap.primaryCalls, + snap.depCalls, + snap.cacheHits, + snap.cacheMisses, + snap.errors, + snap.total, + ) + } +} + +func summarizeLoadScope(wd string, pkgs []*packages.Package) loadScopeStats { + all := collectAllPackages(pkgs) + stats := loadScopeStats{ + roots: len(pkgs), + totalPackages: len(all), + } + moduleRoot := findModuleRoot(wd) + var compiled []packageMetric + var syntax []packageMetric + for _, pkg := range all { + if pkg == nil { + continue + } + compiledCount := len(pkg.CompiledGoFiles) + syntaxCount := len(pkg.Syntax) + stats.compiledFiles += compiledCount + stats.syntaxFiles += syntaxCount + if syntaxCount > 0 { + stats.packagesWithSyntax++ + } + if pkg.Types != nil { + stats.packagesWithTypes++ + } + if pkg.TypesInfo != nil { + stats.packagesWithTypesInfo++ + } + class := classifyPackageLocation(moduleRoot, pkg) + switch class { + case "local": + stats.localPackages++ + if syntaxCount > 0 { + stats.localSyntaxPackages++ + } + case "external": + stats.externalPackages++ + if syntaxCount > 0 { + stats.externalSyntaxPkgs++ + } + default: + stats.unknownPackages++ + } + if compiledCount > 0 { + compiled = append(compiled, packageMetric{path: pkg.PkgPath, count: compiledCount}) + } + if syntaxCount > 0 { + syntax = append(syntax, packageMetric{path: pkg.PkgPath, count: syntaxCount}) + } + } + stats.topCompiled = topPackageMetrics(compiled) + stats.topSyntax = topPackageMetrics(syntax) + return stats +} + +func collectAllPackages(pkgs []*packages.Package) map[string]*packages.Package { + all := make(map[string]*packages.Package) + stack := append([]*packages.Package(nil), pkgs...) + for len(stack) > 0 { + p := stack[len(stack)-1] + stack = stack[:len(stack)-1] + if p == nil || all[p.PkgPath] != nil { + continue + } + all[p.PkgPath] = p + for _, imp := range p.Imports { + stack = append(stack, imp) + } + } + return all +} + +func classifyPackageLocation(moduleRoot string, pkg *packages.Package) string { + if moduleRoot == "" || pkg == nil { + return "unknown" + } + for _, name := range pkg.CompiledGoFiles { + if isWithinRoot(moduleRoot, name) { + return "local" + } + return "external" + } + for _, name := range pkg.GoFiles { + if isWithinRoot(moduleRoot, name) { + return "local" + } + return "external" + } + return "unknown" +} + +func isWithinRoot(root, name string) bool { + cleanRoot := canonicalPath(root) + cleanName := canonicalPath(name) + if cleanName == cleanRoot { + return true + } + rel, err := filepath.Rel(cleanRoot, cleanName) + if err != nil { + return false + } + return rel != ".." && !strings.HasPrefix(rel, ".."+string(os.PathSeparator)) +} + +func canonicalPath(path string) string { + clean := filepath.Clean(path) + if resolved, err := filepath.EvalSymlinks(clean); err == nil && resolved != "" { + return filepath.Clean(resolved) + } + return clean +} + +func topPackageMetrics(metrics []packageMetric) []string { + sort.Slice(metrics, func(i, j int) bool { + if metrics[i].count == metrics[j].count { + return metrics[i].path < metrics[j].path + } + return metrics[i].count > metrics[j].count + }) + if len(metrics) > 5 { + metrics = metrics[:5] + } + out := make([]string, 0, len(metrics)) + for _, m := range metrics { + out = append(out, fmt.Sprintf("%s(%d)", m.path, m.count)) + } + return out +} + +func findModuleRoot(wd string) string { + dir := filepath.Clean(wd) + for { + if _, err := os.Stat(filepath.Join(dir, "go.mod")); err == nil { + return dir + } + parent := filepath.Dir(dir) + if parent == dir { + return "" + } + dir = parent + } +} + +func formatLoadMode(mode packages.LoadMode) string { + flags := []struct { + bit packages.LoadMode + name string + }{ + {packages.NeedName, "NeedName"}, + {packages.NeedFiles, "NeedFiles"}, + {packages.NeedCompiledGoFiles, "NeedCompiledGoFiles"}, + {packages.NeedImports, "NeedImports"}, + {packages.NeedDeps, "NeedDeps"}, + {packages.NeedExportsFile, "NeedExportsFile"}, + {packages.NeedTypes, "NeedTypes"}, + {packages.NeedSyntax, "NeedSyntax"}, + {packages.NeedTypesInfo, "NeedTypesInfo"}, + {packages.NeedTypesSizes, "NeedTypesSizes"}, + {packages.NeedModule, "NeedModule"}, + {packages.NeedEmbedFiles, "NeedEmbedFiles"}, + {packages.NeedEmbedPatterns, "NeedEmbedPatterns"}, + } + var parts []string + for _, flag := range flags { + if mode&flag.bit != 0 { + parts = append(parts, flag.name) + } + } + if len(parts) == 0 { + return "0" + } + return strings.Join(parts, "|") +} + +func primaryFileSet(files map[string]struct{}) map[string]struct{} { + if len(files) == 0 { + return nil + } + out := make(map[string]struct{}, len(files)) + for name := range files { + out[filepath.Clean(name)] = struct{}{} + } + return out +} + +func isPrimaryFile(primary map[string]struct{}, filename string) bool { + if len(primary) == 0 { + return false + } + _, ok := primary[filepath.Clean(filename)] + return ok +} diff --git a/internal/wire/loader_test.go b/internal/wire/loader_test.go deleted file mode 100644 index 1fbd96c..0000000 --- a/internal/wire/loader_test.go +++ /dev/null @@ -1,148 +0,0 @@ -// Copyright 2026 The Wire Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package wire - -import ( - "context" - "os" - "path/filepath" - "strings" - "testing" -) - -func TestLoadAndGenerateModule(t *testing.T) { - repoRoot := mustRepoRoot(t) - root := t.TempDir() - - writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ - "module example.com/app", - "", - "go 1.19", - "", - "require github.com/goforj/wire v0.0.0", - "replace github.com/goforj/wire => " + repoRoot, - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "app", "app.go"), strings.Join([]string{ - "package app", - "", - }, "\n")) - writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ - "//go:build wireinject", - "// +build wireinject", - "", - "package app", - "", - "import (", - "\t\"example.com/app/dep\"", - "\t\"github.com/goforj/wire\"", - ")", - "", - "func Init() *dep.Foo {", - "\twire.Build(dep.New)", - "\treturn nil", - "}", - "", - }, "\n")) - writeFile(t, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ - "package dep", - "", - "type Foo struct{}", - "", - "func New() *Foo {", - "\treturn &Foo{}", - "}", - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "noop", "noop.go"), strings.Join([]string{ - "package noop", - "", - "type Thing struct{}", - "", - }, "\n")) - - env := append(os.Environ(), "GOWORK=off") - ctx := context.Background() - - info, errs := Load(ctx, root, env, "", []string{"./app"}) - if len(errs) > 0 { - t.Fatalf("Load returned errors: %v", errs) - } - if info == nil { - t.Fatal("Load returned nil info") - } - if len(info.Injectors) != 1 || info.Injectors[0].FuncName != "Init" { - t.Fatalf("Load returned unexpected injectors: %+v", info.Injectors) - } - - gens, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("Generate returned errors: %v", errs) - } - if len(gens) != 1 { - t.Fatalf("Generate returned %d results, want 1", len(gens)) - } - if len(gens[0].Errs) > 0 { - t.Fatalf("Generate result had errors: %v", gens[0].Errs) - } - if len(gens[0].Content) == 0 { - t.Fatal("Generate returned empty output for wire package") - } - if gens[0].OutputPath == "" { - t.Fatal("Generate returned empty output path") - } - - noops, errs := Generate(ctx, root, env, []string{"./noop"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("Generate noop returned errors: %v", errs) - } - if len(noops) != 1 { - t.Fatalf("Generate noop returned %d results, want 1", len(noops)) - } - if len(noops[0].Errs) > 0 { - t.Fatalf("Generate noop result had errors: %v", noops[0].Errs) - } - if noops[0].OutputPath == "" { - t.Fatal("Generate noop returned empty output path") - } - if len(noops[0].Content) != 0 { - t.Fatal("Generate noop returned unexpected output") - } -} - -func mustRepoRoot(t *testing.T) string { - t.Helper() - wd, err := os.Getwd() - if err != nil { - t.Fatalf("Getwd failed: %v", err) - } - repoRoot := filepath.Clean(filepath.Join(wd, "..", "..")) - if _, err := os.Stat(filepath.Join(repoRoot, "go.mod")); err != nil { - t.Fatalf("repo root not found at %s: %v", repoRoot, err) - } - return repoRoot -} - -func writeFile(t *testing.T, path string, content string) { - t.Helper() - if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil { - t.Fatalf("MkdirAll failed: %v", err) - } - if err := os.WriteFile(path, []byte(content), 0644); err != nil { - t.Fatalf("WriteFile failed: %v", err) - } -} diff --git a/internal/wire/loader_timing_bridge.go b/internal/wire/loader_timing_bridge.go new file mode 100644 index 0000000..2de0245 --- /dev/null +++ b/internal/wire/loader_timing_bridge.go @@ -0,0 +1,17 @@ +package wire + +import ( + "context" + "time" + + "github.com/goforj/wire/internal/loader" +) + +func withLoaderTiming(ctx context.Context) context.Context { + if t := timing(ctx); t != nil { + return loader.WithTiming(ctx, func(label string, d time.Duration) { + t(label, d) + }) + } + return ctx +} diff --git a/internal/wire/cache_hooks.go b/internal/wire/loader_validation.go similarity index 50% rename from internal/wire/cache_hooks.go rename to internal/wire/loader_validation.go index 9d4be6d..cde4d60 100644 --- a/internal/wire/cache_hooks.go +++ b/internal/wire/loader_validation.go @@ -15,27 +15,15 @@ package wire import ( - "encoding/json" - "os" -) - -var ( - osCreateTemp = os.CreateTemp - osMkdirAll = os.MkdirAll - osReadFile = os.ReadFile - osRemove = os.Remove - osRemoveAll = os.RemoveAll - osRename = os.Rename - osStat = os.Stat - osTempDir = os.TempDir + "context" - jsonMarshal = json.Marshal - jsonUnmarshal = json.Unmarshal - - cacheKeyForPackageFunc = cacheKeyForPackage - detectOutputDirFunc = detectOutputDir - buildCacheFilesFunc = buildCacheFiles - buildCacheFilesFromMetaFunc = buildCacheFilesFromMeta - rootPackageFilesFunc = rootPackageFiles - hashFilesFunc = hashFiles + "github.com/goforj/wire/internal/loader" ) + +func effectiveLoaderMode(_ context.Context, _ string, env []string) loader.Mode { + mode := loader.ModeFromEnv(env) + if mode != loader.ModeAuto { + return mode + } + return loader.ModeAuto +} diff --git a/internal/wire/output_cache.go b/internal/wire/output_cache.go new file mode 100644 index 0000000..42fcaa4 --- /dev/null +++ b/internal/wire/output_cache.go @@ -0,0 +1,273 @@ +package wire + +import ( + "context" + "crypto/sha256" + "encoding/gob" + "encoding/hex" + "os" + "path/filepath" + "runtime" + "sort" + "strconv" + "strings" + + "golang.org/x/tools/go/packages" + + "github.com/goforj/wire/internal/cachepaths" + "github.com/goforj/wire/internal/loader" +) + +const ( + outputCacheDirEnv = cachepaths.OutputCacheDirEnv + outputCacheEnabledEnv = "WIRE_OUTPUT_CACHE" +) + +type outputCacheEntry struct { + Version int + Content []byte +} + +type outputCacheCandidate struct { + path string + outputPath string +} + +func prepareGenerateOutputCache(ctx context.Context, wd string, env []string, patterns []string, opts *GenerateOptions) (map[string]outputCacheCandidate, []GenerateResult, *loader.DiscoverySnapshot, bool) { + if !outputCacheEnabled(ctx, wd, env) { + debugf(ctx, "generate.output_cache=disabled") + return nil, nil, nil, false + } + rootResult, err := loader.New().LoadRootGraph(withLoaderTiming(ctx), loader.RootLoadRequest{ + WD: wd, + Env: env, + Tags: opts.Tags, + Patterns: append([]string(nil), patterns...), + NeedDeps: true, + Mode: effectiveLoaderMode(ctx, wd, env), + }) + if err != nil || rootResult == nil || len(rootResult.Packages) == 0 { + if err != nil { + debugf(ctx, "generate.output_cache=load_root_error") + } else { + debugf(ctx, "generate.output_cache=no_roots") + } + return nil, nil, nil, false + } + candidates := make(map[string]outputCacheCandidate, len(rootResult.Packages)) + results := make([]GenerateResult, 0, len(rootResult.Packages)) + for _, pkg := range rootResult.Packages { + outDir, err := detectOutputDir(pkg.GoFiles) + if err != nil { + debugf(ctx, "generate.output_cache=bad_output_dir") + return candidates, nil, rootResult.Discovery, false + } + key, err := outputCacheKey(wd, opts, pkg) + if err != nil { + debugf(ctx, "generate.output_cache=key_error") + return candidates, nil, rootResult.Discovery, false + } + path, err := outputCachePath(env, key) + if err != nil { + debugf(ctx, "generate.output_cache=path_error") + return candidates, nil, rootResult.Discovery, false + } + candidates[pkg.PkgPath] = outputCacheCandidate{ + path: path, + outputPath: filepath.Join(outDir, opts.PrefixOutputFile+"wire_gen.go"), + } + entry, ok := readOutputCache(path) + if !ok { + debugf(ctx, "generate.output_cache=miss") + return candidates, nil, rootResult.Discovery, false + } + results = append(results, GenerateResult{ + PkgPath: pkg.PkgPath, + OutputPath: filepath.Join(outDir, opts.PrefixOutputFile+"wire_gen.go"), + Content: entry.Content, + }) + } + debugf(ctx, "generate.output_cache=hit") + return candidates, results, rootResult.Discovery, len(results) == len(rootResult.Packages) +} + +func writeGenerateOutputCache(candidates map[string]outputCacheCandidate, generated []GenerateResult) { + for _, gen := range generated { + candidate, ok := candidates[gen.PkgPath] + if !ok || candidate.path == "" || len(gen.Errs) > 0 || len(gen.Content) == 0 { + continue + } + _ = writeOutputCache(candidate.path, &outputCacheEntry{ + Version: 1, + Content: append([]byte(nil), gen.Content...), + }) + } +} + +func outputCacheEnabled(ctx context.Context, wd string, env []string) bool { + if effectiveLoaderMode(ctx, wd, env) == loader.ModeFallback { + return false + } + if envValue(env, outputCacheEnabledEnv) == "0" { + return false + } + return envValue(env, "WIRE_LOADER_ARTIFACTS") != "0" +} + +func outputCachePath(env []string, key string) (string, error) { + dir, err := outputCacheDir(env) + if err != nil { + return "", err + } + return filepath.Join(dir, key+".gob"), nil +} + +func outputCacheDir(env []string) (string, error) { + return cachepaths.Dir(env, outputCacheDirEnv, "output-cache") +} + +func readOutputCache(path string) (*outputCacheEntry, bool) { + f, err := os.Open(path) + if err != nil { + return nil, false + } + defer f.Close() + var entry outputCacheEntry + if err := gob.NewDecoder(f).Decode(&entry); err != nil { + return nil, false + } + if entry.Version != 1 { + return nil, false + } + return &entry, true +} + +func writeOutputCache(path string, entry *outputCacheEntry) error { + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + return err + } + f, err := os.Create(path) + if err != nil { + return err + } + defer f.Close() + return gob.NewEncoder(f).Encode(entry) +} + +func outputCacheKey(wd string, opts *GenerateOptions, root *packages.Package) (string, error) { + sum := sha256.New() + sum.Write([]byte("wire-output-cache-v1\n")) + sum.Write([]byte(runtime.Version())) + sum.Write([]byte{'\n'}) + sum.Write([]byte(canonicalWirePath(wd))) + sum.Write([]byte{'\n'}) + sum.Write(opts.Header) + sum.Write([]byte{'\n'}) + sum.Write([]byte(opts.Tags)) + sum.Write([]byte{'\n'}) + sum.Write([]byte(root.PkgPath)) + sum.Write([]byte{'\n'}) + workspace := detectWireModuleRoot(wd) + pkgs := reachablePackages(root) + for _, pkg := range pkgs { + sum.Write([]byte(pkg.PkgPath)) + sum.Write([]byte{'\n'}) + if isLocalWirePackage(workspace, pkg) { + files := append([]string(nil), pkg.GoFiles...) + sort.Strings(files) + for _, name := range files { + info, err := os.Stat(name) + if err != nil { + return "", err + } + sum.Write([]byte(name)) + sum.Write([]byte{'\n'}) + sum.Write([]byte(strconv.FormatInt(info.Size(), 10))) + sum.Write([]byte{'\n'}) + sum.Write([]byte(strconv.FormatInt(info.ModTime().UnixNano(), 10))) + sum.Write([]byte{'\n'}) + if pkg.PkgPath == root.PkgPath { + src, err := os.ReadFile(name) + if err != nil { + return "", err + } + sum.Write(src) + sum.Write([]byte{'\n'}) + } + } + continue + } + sum.Write([]byte(pkg.ExportFile)) + sum.Write([]byte{'\n'}) + } + return hex.EncodeToString(sum.Sum(nil)), nil +} + +func reachablePackages(root *packages.Package) []*packages.Package { + seen := map[string]bool{} + var out []*packages.Package + var walk func(*packages.Package) + walk = func(pkg *packages.Package) { + if pkg == nil || seen[pkg.PkgPath] { + return + } + seen[pkg.PkgPath] = true + out = append(out, pkg) + paths := make([]string, 0, len(pkg.Imports)) + for path := range pkg.Imports { + paths = append(paths, path) + } + sort.Strings(paths) + for _, path := range paths { + walk(pkg.Imports[path]) + } + } + walk(root) + sort.Slice(out, func(i, j int) bool { return out[i].PkgPath < out[j].PkgPath }) + return out +} + +func isLocalWirePackage(workspace string, pkg *packages.Package) bool { + if pkg == nil || len(pkg.GoFiles) == 0 { + return false + } + dir := filepath.Dir(pkg.GoFiles[0]) + dir = canonicalWirePath(dir) + workspace = canonicalWirePath(workspace) + if dir == workspace { + return true + } + return len(dir) > len(workspace) && dir[:len(workspace)] == workspace && dir[len(workspace)] == filepath.Separator +} + +func detectWireModuleRoot(start string) string { + start = canonicalWirePath(start) + for dir := start; dir != "" && dir != "." && dir != string(filepath.Separator); dir = filepath.Dir(dir) { + if info, err := os.Stat(filepath.Join(dir, "go.mod")); err == nil && !info.IsDir() { + return dir + } + next := filepath.Dir(dir) + if next == dir { + break + } + } + return start +} + +func canonicalWirePath(path string) string { + path = filepath.Clean(path) + if resolved, err := filepath.EvalSymlinks(path); err == nil && resolved != "" { + return filepath.Clean(resolved) + } + return path +} + +func envValue(env []string, key string) string { + for i := len(env) - 1; i >= 0; i-- { + name, value, ok := strings.Cut(env[i], "=") + if ok && name == key { + return value + } + } + return "" +} diff --git a/internal/wire/output_cache_test.go b/internal/wire/output_cache_test.go new file mode 100644 index 0000000..a74621b --- /dev/null +++ b/internal/wire/output_cache_test.go @@ -0,0 +1,38 @@ +package wire + +import ( + "context" + "testing" +) + +func TestOutputCacheEnabled(t *testing.T) { + ctx := context.Background() + tests := []struct { + name string + env []string + want bool + }{ + { + name: "enabled with artifacts", + env: []string{"WIRE_LOADER_ARTIFACTS=1"}, + want: true, + }, + { + name: "disabled without artifacts", + env: []string{"WIRE_LOADER_ARTIFACTS=0"}, + want: false, + }, + { + name: "disabled by dedicated env", + env: []string{"WIRE_LOADER_ARTIFACTS=1", "WIRE_OUTPUT_CACHE=0"}, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := outputCacheEnabled(ctx, t.TempDir(), tt.env); got != tt.want { + t.Fatalf("outputCacheEnabled(..., %v) = %v, want %v", tt.env, got, tt.want) + } + }) + } +} diff --git a/internal/wire/parse.go b/internal/wire/parse.go index fc1b353..2e9c428 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -19,6 +19,7 @@ import ( "errors" "fmt" "go/ast" + "go/parser" "go/token" "go/types" "os" @@ -30,6 +31,8 @@ import ( "golang.org/x/tools/go/ast/astutil" "golang.org/x/tools/go/packages" "golang.org/x/tools/go/types/typeutil" + + "github.com/goforj/wire/internal/loader" ) // A providerSetSrc captures the source for a type provided by a ProviderSet. @@ -251,7 +254,7 @@ type Field struct { // takes precedence. func Load(ctx context.Context, wd string, env []string, tags string, patterns []string) (*Info, []error) { loadStart := time.Now() - pkgs, loader, errs := load(ctx, wd, env, tags, patterns) + pkgs, errs := load(ctx, wd, env, tags, patterns, nil) logTiming(ctx, "load.packages", loadStart) if len(errs) > 0 { return nil, errs @@ -264,19 +267,13 @@ func Load(ctx context.Context, wd string, env []string, tags string, patterns [] Fset: fset, Sets: make(map[ProviderSetID]*ProviderSet), } - oc := newObjectCache(pkgs, loader) + oc := newObjectCache(pkgs) ec := new(errorCollector) for _, pkg := range pkgs { if isWireImport(pkg.PkgPath) { // The marker function package confuses analysis. continue } - if loaded, errs := oc.ensurePackage(pkg.PkgPath); len(errs) > 0 { - ec.add(errs...) - continue - } else if loaded != nil { - pkg = loaded - } pkgStart := time.Now() scope := pkg.Types.Scope() setStart := time.Now() @@ -364,46 +361,49 @@ func Load(ctx context.Context, wd string, env []string, tags string, patterns [] // env is nil or empty, it is interpreted as an empty set of variables. // In case of duplicate environment variables, the last one in the list // takes precedence. -func load(ctx context.Context, wd string, env []string, tags string, patterns []string) ([]*packages.Package, *lazyLoader, []error) { +func load(ctx context.Context, wd string, env []string, tags string, patterns []string, discovery *loader.DiscoverySnapshot) ([]*packages.Package, []error) { fset := token.NewFileSet() - baseCfg := &packages.Config{ - Context: ctx, - Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps, - Dir: wd, + loaderMode := effectiveLoaderMode(ctx, wd, env) + parseStats := &parseFileStats{} + loadStart := time.Now() + result, err := loader.New().LoadPackages(withLoaderTiming(ctx), loader.PackageLoadRequest{ + WD: wd, Env: env, - BuildFlags: []string{"-tags=wireinject"}, + Tags: tags, + Patterns: append([]string(nil), patterns...), + Mode: packages.LoadAllSyntax, + LoaderMode: loaderMode, Fset: fset, + Discovery: discovery, + ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { + start := time.Now() + file, err := parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) + parseStats.record(false, time.Since(start), err, false) + return file, err + }, + }) + logTiming(ctx, "load.packages.load", loadStart) + var typedPkgs []*packages.Package + if result != nil { + typedPkgs = result.Packages + debugf(ctx, "load.packages.backend=%s", result.Backend) + if result.FallbackReason != loader.FallbackReasonNone { + debugf(ctx, "load.packages.fallback_reason=%s", result.FallbackReason) + if result.FallbackDetail != "" { + debugf(ctx, "load.packages.fallback_detail=%s", result.FallbackDetail) + } + } } - if len(tags) > 0 { - baseCfg.BuildFlags[0] += " " + tags - } - escaped := make([]string, len(patterns)) - for i := range patterns { - escaped[i] = "pattern=" + patterns[i] - } - baseLoadStart := time.Now() - pkgs, err := packages.Load(baseCfg, escaped...) - logTiming(ctx, "load.packages.base.load", baseLoadStart) + logLoadDebug(ctx, "typed", packages.LoadAllSyntax, strings.Join(patterns, ","), wd, typedPkgs, parseStats) if err != nil { - return nil, nil, []error{err} + return nil, []error{err} } - baseErrsStart := time.Now() - errs := collectLoadErrors(pkgs) - logTiming(ctx, "load.packages.base.collect_errors", baseErrsStart) + errs := collectLoadErrors(typedPkgs) + logTiming(ctx, "load.packages.collect_errors", loadStart) if len(errs) > 0 { - return nil, nil, errs - } - - baseFiles := collectPackageFiles(pkgs) - loader := &lazyLoader{ - ctx: ctx, - wd: wd, - env: env, - tags: tags, - fset: fset, - baseFiles: baseFiles, + return nil, errs } - return pkgs, loader, nil + return typedPkgs, nil } func collectLoadErrors(pkgs []*packages.Package) []error { @@ -456,7 +456,6 @@ type objectCache struct { packages map[string]*packages.Package objects map[objRef]objCacheEntry hasher typeutil.Hasher - loader *lazyLoader } type objRef struct { @@ -469,7 +468,7 @@ type objCacheEntry struct { errs []error } -func newObjectCache(pkgs []*packages.Package, loader *lazyLoader) *objectCache { +func newObjectCache(pkgs []*packages.Package) *objectCache { if len(pkgs) == 0 { panic("object cache must have packages to draw from") } @@ -478,10 +477,6 @@ func newObjectCache(pkgs []*packages.Package, loader *lazyLoader) *objectCache { packages: make(map[string]*packages.Package), objects: make(map[objRef]objCacheEntry), hasher: typeutil.MakeHasher(), - loader: loader, - } - if oc.fset == nil && loader != nil { - oc.fset = loader.fset } // Depth-first search of all dependencies to gather import path to // packages.Package mapping. go/packages guarantees that for a single @@ -515,24 +510,6 @@ func (oc *objectCache) registerPackages(pkgs []*packages.Package, replace bool) } } -func (oc *objectCache) ensurePackage(pkgPath string) (*packages.Package, []error) { - if pkg := oc.packages[pkgPath]; pkg != nil && pkg.TypesInfo != nil && len(pkg.Syntax) > 0 { - return pkg, nil - } - if oc.loader == nil { - if pkg := oc.packages[pkgPath]; pkg != nil { - return pkg, nil - } - return nil, []error{fmt.Errorf("package %q is missing type information", pkgPath)} - } - loaded, errs := oc.loader.load(pkgPath) - if len(errs) > 0 { - return nil, errs - } - oc.registerPackages(loaded, true) - return oc.packages[pkgPath], nil -} - // get converts a Go object into a Wire structure. It may return a *Provider, an // *IfaceBinding, a *ProviderSet, a *Value, or a []*Field. func (oc *objectCache) get(obj types.Object) (val interface{}, errs []error) { @@ -543,9 +520,6 @@ func (oc *objectCache) get(obj types.Object) (val interface{}, errs []error) { if ent, cached := oc.objects[ref]; cached { return ent.val, append([]error(nil), ent.errs...) } - if _, errs := oc.ensurePackage(ref.importPath); len(errs) > 0 { - return nil, errs - } defer func() { oc.objects[ref] = objCacheEntry{ val: val, @@ -573,14 +547,160 @@ func (oc *objectCache) get(obj types.Object) (val interface{}, errs []error) { } } +func providerInputsForAllowedStructFields(st *types.Struct) []ProviderInput { + fields := make([]*types.Var, 0, st.NumFields()) + for i := 0; i < st.NumFields(); i++ { + if isPrevented(st.Tag(i)) { + continue + } + fields = append(fields, st.Field(i)) + } + return providerInputsForVars(fields) +} + +func providerInputsForVars(vars []*types.Var) []ProviderInput { + args := make([]ProviderInput, 0, len(vars)) + for _, v := range vars { + args = append(args, providerInputForVar(v)) + } + return args +} + +func providerInputForVar(v *types.Var) ProviderInput { + return ProviderInput{ + Type: v.Type(), + FieldName: v.Name(), + } +} + +func newField(parent types.Type, v *types.Var, includePointer bool) *Field { + return &Field{ + Parent: parent, + Name: v.Name(), + Pkg: v.Pkg(), + Pos: v.Pos(), + Out: fieldOutputTypes(v.Type(), includePointer), + } +} + +func typeAndPointer(typ types.Type) []types.Type { + return []types.Type{typ, applyTypePointers(typ, 1)} +} + +func fieldOutputTypes(typ types.Type, includePointer bool) []types.Type { + out := []types.Type{typ} + if includePointer { + out = append(out, applyTypePointers(typ, 1)) + } + return out +} + +func newStructProvider(typeName types.Object, out []types.Type) *Provider { + return &Provider{ + Pkg: typeName.Pkg(), + Name: typeName.Name(), + Pos: typeName.Pos(), + IsStruct: true, + Out: out, + } +} + +func (oc *objectCache) lookupPackageObject(importPath, name string) (types.Object, error) { + pkg := oc.packages[importPath] + if pkg == nil || pkg.Types == nil { + return nil, fmt.Errorf("missing typed package for %s", importPath) + } + return pkg.Types.Scope().Lookup(name), nil +} + +func (oc *objectCache) lookupPackageFunc(importPath, name string) (*types.Func, error) { + obj, err := oc.lookupPackageObject(importPath, name) + if err != nil { + return nil, err + } + fn, ok := obj.(*types.Func) + if !ok || fn == nil { + return nil, fmt.Errorf("%s.%s is not a provider function", importPath, name) + } + return fn, nil +} + +func applyTypePointers(typ types.Type, count int) types.Type { + for i := 0; i < count; i++ { + typ = types.NewPointer(typ) + } + return typ +} + +func namedStructType(typeName types.Object) (types.Type, *types.Struct, bool) { + out := typeName.Type() + st, ok := out.Underlying().(*types.Struct) + return out, st, ok +} + +func structFromFieldsParent(parent types.Type) (*types.Struct, bool, error) { + ptr, ok := parent.(*types.Pointer) + if !ok { + return nil, false, fmt.Errorf("parent type %s is not a pointer", types.TypeString(parent, nil)) + } + switch t := ptr.Elem().Underlying().(type) { + case *types.Pointer: + st, ok := t.Elem().Underlying().(*types.Struct) + if !ok { + return nil, false, fmt.Errorf("parent type %s does not point to a struct", types.TypeString(parent, nil)) + } + return st, true, nil + case *types.Struct: + return t, false, nil + default: + return nil, false, fmt.Errorf("parent type %s does not point to a struct", types.TypeString(parent, nil)) + } +} + +func lookupStructField(st *types.Struct, name string) *types.Var { + for i := 0; i < st.NumFields(); i++ { + if st.Field(i).Name() == name { + return st.Field(i) + } + } + return nil +} + +func requiredStructField(st *types.Struct, name string) (*types.Var, error) { + v := lookupStructField(st, name) + if v == nil { + return nil, fmt.Errorf("field %q not found", name) + } + return v, nil +} + +func lookupQuotedStructField(st *types.Struct, quotedName string) (*types.Var, int) { + for i := 0; i < st.NumFields(); i++ { + if strings.EqualFold(strconv.Quote(st.Field(i).Name()), quotedName) { + return st.Field(i), i + } + } + return nil, -1 +} + // varDecl finds the declaration that defines the given variable. func (oc *objectCache) varDecl(obj *types.Var) *ast.ValueSpec { // TODO(light): Walk files to build object -> declaration mapping, if more performant. // Recommended by https://golang.org/s/types-tutorial pkg := oc.packages[obj.Pkg().Path()] + if pkg == nil { + return nil + } + return valueSpecForVar(oc.fset, pkg.Syntax, obj) +} + +func valueSpecForVar(fset *token.FileSet, files []*ast.File, obj *types.Var) *ast.ValueSpec { pos := obj.Pos() - for _, f := range pkg.Syntax { - tokenFile := oc.fset.File(f.Pos()) + for _, f := range files { + tokenFile := fset.File(f.Pos()) + if tokenFile == nil { + continue + } if base := tokenFile.Base(); base <= int(pos) && int(pos) < base+tokenFile.Size() { path, _ := astutil.PathEnclosingInterval(f, pos, pos) for _, node := range path { @@ -698,15 +818,22 @@ func (oc *objectCache) processNewSet(info *types.Info, pkgPath string, call *ast if len(ec.errors) > 0 { return nil, ec.errors } + if errs := oc.finalizeProviderSet(pset); len(errs) > 0 { + return nil, errs + } + return pset, nil +} + +func (oc *objectCache) finalizeProviderSet(pset *ProviderSet) []error { var errs []error pset.providerMap, pset.srcMap, errs = buildProviderMap(oc.fset, oc.hasher, pset) if len(errs) > 0 { - return nil, errs + return errs } if errs := verifyAcyclic(pset.providerMap, oc.hasher); len(errs) > 0 { - return nil, errs + return errs } - return pset, nil + return nil } // structArgType attempts to interpret an expression as a simple struct type. @@ -834,8 +961,7 @@ func funcOutput(sig *types.Signature) (outputSignature, error) { // It will not support any new feature introduced after v0.2. Please use the new // wire.Struct syntax for those. func processStructLiteralProvider(fset *token.FileSet, typeName *types.TypeName) (*Provider, []error) { - out := typeName.Type() - st, ok := out.Underlying().(*types.Struct) + out, st, ok := namedStructType(typeName) if !ok { return nil, []error{fmt.Errorf("%v does not name a struct", typeName)} } @@ -846,14 +972,9 @@ func processStructLiteralProvider(fset *token.FileSet, typeName *types.TypeName) notePosition(fset.Position(pos), fmt.Errorf("using struct literal to inject %s is deprecated and will be removed in the next release; use wire.Struct instead", typeName.Type()))) - provider := &Provider{ - Pkg: typeName.Pkg(), - Name: typeName.Name(), - Pos: pos, - Args: make([]ProviderInput, st.NumFields()), - IsStruct: true, - Out: []types.Type{out, types.NewPointer(out)}, - } + provider := newStructProvider(typeName, typeAndPointer(out)) + provider.Pos = pos + provider.Args = make([]ProviderInput, st.NumFields()) for i := 0; i < st.NumFields(); i++ { f := st.Field(i) provider.Args[i] = ProviderInput{ @@ -894,36 +1015,19 @@ func processStructProvider(fset *token.FileSet, info *types.Info, call *ast.Call stExpr := call.Args[0].(*ast.CallExpr) typeName := qualifiedIdentObject(info, stExpr.Args[0]) // should be either an identifier or selector - provider := &Provider{ - Pkg: typeName.Pkg(), - Name: typeName.Name(), - Pos: typeName.Pos(), - IsStruct: true, - Out: []types.Type{structPtr.Elem(), structPtr}, - } + provider := newStructProvider(typeName, []types.Type{structPtr.Elem(), structPtr}) if allFields(call) { - for i := 0; i < st.NumFields(); i++ { - if isPrevented(st.Tag(i)) { - continue - } - f := st.Field(i) - provider.Args = append(provider.Args, ProviderInput{ - Type: f.Type(), - FieldName: f.Name(), - }) - } + provider.Args = providerInputsForAllowedStructFields(st) } else { - provider.Args = make([]ProviderInput, len(call.Args)-1) + fields := make([]*types.Var, 0, len(call.Args)-1) for i := 1; i < len(call.Args); i++ { v, err := checkField(call.Args[i], st) if err != nil { return nil, notePosition(fset.Position(call.Pos()), err) } - provider.Args[i-1] = ProviderInput{ - Type: v.Type(), - FieldName: v.Name(), - } + fields = append(fields, v) } + provider.Args = providerInputsForVars(fields) } for i := 0; i < len(provider.Args); i++ { for j := 0; j < i; j++ { @@ -1089,22 +1193,10 @@ func processFieldsOf(fset *token.FileSet, info *types.Info, call *ast.CallExpr) return nil, notePosition(fset.Position(call.Pos()), fmt.Errorf(firstArgReqFormat, types.TypeString(structType, nil))) } - - var struc *types.Struct - isPtrToStruct := false - switch t := structPtr.Elem().Underlying().(type) { - case *types.Pointer: - struc, ok = t.Elem().Underlying().(*types.Struct) - if !ok { - return nil, notePosition(fset.Position(call.Pos()), - fmt.Errorf(firstArgReqFormat, types.TypeString(struc, nil))) - } - isPtrToStruct = true - case *types.Struct: - struc = t - default: + struc, isPtrToStruct, err := structFromFieldsParent(structPtr) + if err != nil { return nil, notePosition(fset.Position(call.Pos()), - fmt.Errorf(firstArgReqFormat, types.TypeString(t, nil))) + fmt.Errorf(firstArgReqFormat, types.TypeString(structType, nil))) } if struc.NumFields() < len(call.Args)-1 { return nil, notePosition(fset.Position(call.Pos()), @@ -1117,19 +1209,7 @@ func processFieldsOf(fset *token.FileSet, info *types.Info, call *ast.CallExpr) if err != nil { return nil, notePosition(fset.Position(call.Pos()), err) } - out := []types.Type{v.Type()} - if isPtrToStruct { - // If the field is from a pointer to a struct, then - // wire.Fields also provides a pointer to the field. - out = append(out, types.NewPointer(v.Type())) - } - fields = append(fields, &Field{ - Parent: structPtr.Elem(), - Name: v.Name(), - Pkg: v.Pkg(), - Pos: v.Pos(), - Out: out, - }) + fields = append(fields, newField(structPtr.Elem(), v, isPtrToStruct)) } return fields, nil } @@ -1141,13 +1221,12 @@ func checkField(f ast.Expr, st *types.Struct) (*types.Var, error) { if !ok { return nil, fmt.Errorf("%v must be a string with the field name", f) } - for i := 0; i < st.NumFields(); i++ { - if strings.EqualFold(strconv.Quote(st.Field(i).Name()), b.Value) { - if isPrevented(st.Tag(i)) { - return nil, fmt.Errorf("%s is prevented from injecting by wire", b.Value) - } - return st.Field(i), nil + v, i := lookupQuotedStructField(st, b.Value) + if v != nil { + if isPrevented(st.Tag(i)) { + return nil, fmt.Errorf("%s is prevented from injecting by wire", b.Value) } + return v, nil } return nil, fmt.Errorf("%s is not a field of %s", b.Value, st.String()) } @@ -1322,5 +1401,11 @@ func bindShouldUsePointer(info *types.Info, call *ast.CallExpr) bool { fun := call.Fun.(*ast.SelectorExpr) // wire.Bind pkgName := fun.X.(*ast.Ident) // wire wireName := info.ObjectOf(pkgName).(*types.PkgName) // wire package - return wireName.Imported().Scope().Lookup("bindToUsePointer") != nil + if imported := wireName.Imported(); imported != nil { + if isWireImport(imported.Path()) { + return true + } + return imported.Scope().Lookup("bindToUsePointer") != nil + } + return false } diff --git a/internal/wire/parse_coverage_test.go b/internal/wire/parse_coverage_test.go index 516d1d5..7c7a3b7 100644 --- a/internal/wire/parse_coverage_test.go +++ b/internal/wire/parse_coverage_test.go @@ -333,18 +333,18 @@ func TestAllFields(t *testing.T) { } } -func TestObjectCacheEnsurePackage(t *testing.T) { +func TestNewObjectCacheRegistersPackages(t *testing.T) { t.Parallel() fset := token.NewFileSet() pkg := &packages.Package{PkgPath: "example.com/p", Fset: fset} - oc := newObjectCache([]*packages.Package{pkg}, nil) + oc := newObjectCache([]*packages.Package{pkg}) - if got, errs := oc.ensurePackage(pkg.PkgPath); len(errs) != 0 || got != pkg { - t.Fatalf("expected existing package without errors, got pkg=%v errs=%v", got, errs) + if got := oc.packages[pkg.PkgPath]; got != pkg { + t.Fatalf("expected package to be registered, got %v", got) } - if _, errs := oc.ensurePackage("missing.example.com"); len(errs) == 0 { - t.Fatal("expected missing package error") + if got := oc.packages["missing.example.com"]; got != nil { + t.Fatalf("expected missing package to remain absent, got %v", got) } } diff --git a/internal/wire/parser_lazy_loader.go b/internal/wire/parser_lazy_loader.go deleted file mode 100644 index b3d7011..0000000 --- a/internal/wire/parser_lazy_loader.go +++ /dev/null @@ -1,133 +0,0 @@ -// Copyright 2026 The Wire Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package wire - -import ( - "context" - "go/ast" - "go/parser" - "go/token" - "path/filepath" - "time" - - "golang.org/x/tools/go/packages" -) - -type lazyLoader struct { - ctx context.Context - wd string - env []string - tags string - fset *token.FileSet - baseFiles map[string]map[string]struct{} -} - -func collectPackageFiles(pkgs []*packages.Package) map[string]map[string]struct{} { - all := collectAllPackages(pkgs) - out := make(map[string]map[string]struct{}, len(all)) - for path, pkg := range all { - if pkg == nil { - continue - } - files := make(map[string]struct{}, len(pkg.CompiledGoFiles)) - for _, name := range pkg.CompiledGoFiles { - files[filepath.Clean(name)] = struct{}{} - } - if len(files) > 0 { - out[path] = files - } - } - return out -} - -func collectAllPackages(pkgs []*packages.Package) map[string]*packages.Package { - all := make(map[string]*packages.Package) - stack := append([]*packages.Package(nil), pkgs...) - for len(stack) > 0 { - p := stack[len(stack)-1] - stack = stack[:len(stack)-1] - if p == nil || all[p.PkgPath] != nil { - continue - } - all[p.PkgPath] = p - for _, imp := range p.Imports { - stack = append(stack, imp) - } - } - return all -} - -func (ll *lazyLoader) load(pkgPath string) ([]*packages.Package, []error) { - return ll.loadWithMode(pkgPath, ll.fullMode(), "load.packages.lazy.load") -} - -func (ll *lazyLoader) fullMode() packages.LoadMode { - return packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax -} - -func (ll *lazyLoader) loadWithMode(pkgPath string, mode packages.LoadMode, timingLabel string) ([]*packages.Package, []error) { - cfg := &packages.Config{ - Context: ll.ctx, - Mode: mode, - Dir: ll.wd, - Env: ll.env, - BuildFlags: []string{"-tags=wireinject"}, - Fset: ll.fset, - ParseFile: ll.parseFileFor(pkgPath), - } - if len(ll.tags) > 0 { - cfg.BuildFlags[0] += " " + ll.tags - } - loadStart := time.Now() - pkgs, err := packages.Load(cfg, "pattern="+pkgPath) - logTiming(ll.ctx, timingLabel, loadStart) - if err != nil { - return nil, []error{err} - } - errs := collectLoadErrors(pkgs) - if len(errs) > 0 { - return nil, errs - } - return pkgs, nil -} - -func (ll *lazyLoader) parseFileFor(pkgPath string) func(*token.FileSet, string, []byte) (*ast.File, error) { - primary := ll.baseFiles[pkgPath] - return func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { - mode := parser.SkipObjectResolution - if primary != nil { - if _, ok := primary[filepath.Clean(filename)]; ok { - mode = parser.ParseComments | parser.SkipObjectResolution - } - } - file, err := parser.ParseFile(fset, filename, src, mode) - if err != nil { - return nil, err - } - if primary == nil { - return file, nil - } - if _, ok := primary[filepath.Clean(filename)]; ok { - return file, nil - } - for _, decl := range file.Decls { - if fn, ok := decl.(*ast.FuncDecl); ok { - fn.Body = nil - fn.Doc = nil - } - } - return file, nil - } -} diff --git a/internal/wire/parser_lazy_loader_test.go b/internal/wire/parser_lazy_loader_test.go deleted file mode 100644 index 31838ea..0000000 --- a/internal/wire/parser_lazy_loader_test.go +++ /dev/null @@ -1,151 +0,0 @@ -// Copyright 2026 The Wire Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package wire - -import ( - "context" - "go/ast" - "go/token" - "os" - "path/filepath" - "strings" - "testing" -) - -func TestLazyLoaderParseFileFor(t *testing.T) { - t.Helper() - fset := token.NewFileSet() - pkgPath := "example.com/pkg" - root := t.TempDir() - primary := filepath.Join(root, "primary.go") - secondary := filepath.Join(root, "secondary.go") - ll := &lazyLoader{ - fset: fset, - baseFiles: map[string]map[string]struct{}{ - pkgPath: {filepath.Clean(primary): {}}, - }, - } - src := strings.Join([]string{ - "package pkg", - "", - "// Doc comment", - "func Foo() {", - "\tprintln(\"hi\")", - "}", - "", - }, "\n") - - parse := ll.parseFileFor(pkgPath) - file, err := parse(fset, primary, []byte(src)) - if err != nil { - t.Fatalf("parse primary: %v", err) - } - fn := firstFuncDecl(t, file) - if fn.Body == nil { - t.Fatal("expected primary file to keep function body") - } - if fn.Doc == nil { - t.Fatal("expected primary file to keep doc comment") - } - - file, err = parse(fset, secondary, []byte(src)) - if err != nil { - t.Fatalf("parse secondary: %v", err) - } - fn = firstFuncDecl(t, file) - if fn.Body != nil { - t.Fatal("expected secondary file to strip function body") - } - if fn.Doc != nil { - t.Fatal("expected secondary file to strip doc comment") - } -} - -func TestLoadModuleUsesWireinjectTagsForDeps(t *testing.T) { - repoRoot := mustRepoRoot(t) - root := t.TempDir() - - writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ - "module example.com/app", - "", - "go 1.19", - "", - "require github.com/goforj/wire v0.0.0", - "replace github.com/goforj/wire => " + repoRoot, - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ - "//go:build wireinject", - "// +build wireinject", - "", - "package app", - "", - "import (", - "\t\"example.com/app/dep\"", - "\t\"github.com/goforj/wire\"", - ")", - "", - "func Init() *dep.Foo {", - "\twire.Build(dep.New)", - "\treturn nil", - "}", - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ - "package dep", - "", - "type Foo struct{}", - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "dep", "dep_inject.go"), strings.Join([]string{ - "//go:build wireinject", - "// +build wireinject", - "", - "package dep", - "", - "func New() *Foo {", - "\treturn &Foo{}", - "}", - "", - }, "\n")) - - env := append(os.Environ(), "GOWORK=off") - ctx := context.Background() - - info, errs := Load(ctx, root, env, "", []string{"./app"}) - if len(errs) > 0 { - t.Fatalf("Load returned errors: %v", errs) - } - if info == nil { - t.Fatal("Load returned nil info") - } - if len(info.Injectors) != 1 || info.Injectors[0].FuncName != "Init" { - t.Fatalf("Load returned unexpected injectors: %+v", info.Injectors) - } -} - -func firstFuncDecl(t *testing.T, file *ast.File) *ast.FuncDecl { - t.Helper() - for _, decl := range file.Decls { - if fn, ok := decl.(*ast.FuncDecl); ok { - return fn - } - } - t.Fatal("expected function declaration in file") - return nil -} diff --git a/internal/wire/profile_bench_test.go b/internal/wire/profile_bench_test.go new file mode 100644 index 0000000..31dc7b7 --- /dev/null +++ b/internal/wire/profile_bench_test.go @@ -0,0 +1,32 @@ +package wire + +import ( + "context" + "os" + "testing" +) + +func BenchmarkGenerateRealAppWarmArtifacts(b *testing.B) { + root := os.Getenv("WIRE_REAL_APP_ROOT") + if root == "" { + b.Skip("WIRE_REAL_APP_ROOT not set") + } + artifactDir := b.TempDir() + env := append(os.Environ(), + "WIRE_LOADER_ARTIFACTS=1", + "WIRE_LOADER_ARTIFACT_DIR="+artifactDir, + ) + ctx := context.Background() + + // Warm the artifact cache once before measurement. + if _, errs := Generate(ctx, root, env, []string{"."}, &GenerateOptions{}); len(errs) > 0 { + b.Fatalf("warm Generate errors: %v", errs) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + if _, errs := Generate(ctx, root, env, []string{"."}, &GenerateOptions{}); len(errs) > 0 { + b.Fatalf("Generate errors: %v", errs) + } + } +} diff --git a/internal/wire/time_compat.go b/internal/wire/time_compat.go new file mode 100644 index 0000000..6f0c9c4 --- /dev/null +++ b/internal/wire/time_compat.go @@ -0,0 +1,22 @@ +// Copyright 2026 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package wire + +import "time" + +var ( + timeNow = time.Now + timeSince = time.Since +) diff --git a/internal/wire/timing.go b/internal/wire/timing.go index 376d573..84c9022 100644 --- a/internal/wire/timing.go +++ b/internal/wire/timing.go @@ -16,6 +16,7 @@ package wire import ( "context" + "fmt" "time" ) @@ -49,3 +50,9 @@ func logTiming(ctx context.Context, label string, start time.Time) { t(label, time.Since(start)) } } + +func debugf(ctx context.Context, format string, args ...interface{}) { + if t := timing(ctx); t != nil { + t(fmt.Sprintf(format, args...), 0) + } +} diff --git a/internal/wire/wire.go b/internal/wire/wire.go index aa3efe3..9f5bb9e 100644 --- a/internal/wire/wire.go +++ b/internal/wire/wire.go @@ -21,10 +21,12 @@ import ( "context" "fmt" "go/ast" + "go/format" "go/printer" "go/token" "go/types" "io/ioutil" + "os" "path/filepath" "sort" "strconv" @@ -53,10 +55,27 @@ type GenerateResult struct { // Commit writes the generated file to disk. func (gen GenerateResult) Commit() error { + _, err := gen.CommitWithStatus() + return err +} + +// CommitWithStatus writes the generated file to disk when the content changed. +// It returns whether the file was written. +func (gen GenerateResult) CommitWithStatus() (bool, error) { if len(gen.Content) == 0 { - return nil + return false, nil + } + current, err := os.ReadFile(gen.OutputPath) + if err == nil && bytes.Equal(current, gen.Content) { + return false, nil } - return ioutil.WriteFile(gen.OutputPath, gen.Content, 0666) + if err != nil && !os.IsNotExist(err) { + return false, err + } + if err := ioutil.WriteFile(gen.OutputPath, gen.Content, 0666); err != nil { + return false, err + } + return true, nil } // GenerateOptions holds options for Generate. @@ -83,25 +102,74 @@ func Generate(ctx context.Context, wd string, env []string, patterns []string, o if opts == nil { opts = &GenerateOptions{} } - if cached, ok := readManifestResults(wd, env, patterns, opts); ok { + cacheCandidates, cached, discovery, ok := prepareGenerateOutputCache(ctx, wd, env, patterns, opts) + if ok { return cached, nil } loadStart := time.Now() - pkgs, loader, errs := load(ctx, wd, env, opts.Tags, patterns) + pkgs, errs := load(ctx, wd, env, opts.Tags, patterns, discovery) logTiming(ctx, "generate.load", loadStart) if len(errs) > 0 { return nil, errs } generated := make([]GenerateResult, len(pkgs)) for i, pkg := range pkgs { - generated[i] = generateForPackage(ctx, pkg, loader, opts) - } - if allGeneratedOK(generated) { - writeManifest(wd, env, patterns, opts, pkgs) + pkgStart := time.Now() + generated[i].PkgPath = pkg.PkgPath + dirStart := time.Now() + outDir, err := detectOutputDir(pkg.GoFiles) + logTiming(ctx, "generate.package."+pkg.PkgPath+".output_dir", dirStart) + if err != nil { + generated[i].Errs = append(generated[i].Errs, err) + continue + } + generated[i].OutputPath = filepath.Join(outDir, opts.PrefixOutputFile+"wire_gen.go") + g := newGen(pkg) + oc := newObjectCache([]*packages.Package{pkg}) + injectorStart := time.Now() + injectorFiles, genErrs := generateInjectors(oc, g, pkg) + logTiming(ctx, "generate.package."+pkg.PkgPath+".injectors", injectorStart) + if len(genErrs) > 0 { + generated[i].Errs = genErrs + continue + } + copyStart := time.Now() + copyNonInjectorDecls(g, injectorFiles, pkg.TypesInfo) + logTiming(ctx, "generate.package."+pkg.PkgPath+".copy_non_injectors", copyStart) + frameStart := time.Now() + goSrc := g.frame(opts.Tags) + logTiming(ctx, "generate.package."+pkg.PkgPath+".frame", frameStart) + if len(opts.Header) > 0 { + goSrc = append(opts.Header, goSrc...) + } + formatStart := time.Now() + fmtSrc, err := format.Source(goSrc) + logTiming(ctx, "generate.package."+pkg.PkgPath+".format", formatStart) + if err != nil { + generated[i].Errs = append(generated[i].Errs, err) + } else { + goSrc = fmtSrc + } + generated[i].Content = goSrc + logTiming(ctx, "generate.package."+pkg.PkgPath+".total", pkgStart) } + writeGenerateOutputCache(cacheCandidates, generated) return generated, nil } +func detectOutputDir(paths []string) (string, error) { + if len(paths) == 0 { + return "", fmt.Errorf("no files to derive output directory from") + } + dir := filepath.Dir(paths[0]) + for _, p := range paths[1:] { + if dir2 := filepath.Dir(p); dir2 != dir { + return "", fmt.Errorf("found conflicting directories %q and %q", dir, dir2) + } + } + return dir, nil +} + // generateInjectors generates the injectors for a given package. func generateInjectors(oc *objectCache, g *gen, pkg *packages.Package) (injectorFiles []*ast.File, _ []error) { injectorFiles = make([]*ast.File, 0, len(pkg.Syntax)) diff --git a/internal/wire/wire_test.go b/internal/wire/wire_test.go index 14080df..23db303 100644 --- a/internal/wire/wire_test.go +++ b/internal/wire/wire_test.go @@ -26,7 +26,9 @@ import ( "io/ioutil" "os" "os/exec" + "path" "path/filepath" + "sort" "strings" "testing" "unicode" @@ -111,6 +113,7 @@ func TestWire(t *testing.T) { t.Log(e.Error()) gotErrStrings[i] = scrubError(gopath, e.Error()) } + gotErrStrings = filterLegacyCompilerErrors(gotErrStrings) if !test.wantWireError { t.Fatal("Did not expect errors. To -record an error, create want/wire_errs.txt.") } @@ -191,6 +194,92 @@ func TestGenerateResultCommit(t *testing.T) { } } +func TestGenerateResultCommitWithStatus(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + path := filepath.Join(dir, "wire_gen.go") + gen := GenerateResult{ + OutputPath: path, + Content: []byte("package p\n"), + } + + wrote, err := gen.CommitWithStatus() + if err != nil { + t.Fatalf("first CommitWithStatus failed: %v", err) + } + if !wrote { + t.Fatal("expected first CommitWithStatus call to write") + } + + wrote, err = gen.CommitWithStatus() + if err != nil { + t.Fatalf("second CommitWithStatus failed: %v", err) + } + if wrote { + t.Fatal("expected second CommitWithStatus call to report unchanged") + } +} + +func TestGenerateRealAppArtifactParity(t *testing.T) { + root := os.Getenv("WIRE_REAL_APP_ROOT") + if root == "" { + t.Skip("WIRE_REAL_APP_ROOT not set") + } + artifactDir := t.TempDir() + ctx := context.Background() + + run := func(env []string) ([]GenerateResult, []string) { + t.Helper() + gens, errs := Generate(ctx, root, env, []string{"."}, &GenerateOptions{}) + errStrings := make([]string, len(errs)) + for i, err := range errs { + errStrings[i] = err.Error() + } + sort.Strings(errStrings) + return gens, errStrings + } + + baseGens, baseErrs := run(os.Environ()) + artifactEnv := append(os.Environ(), + "WIRE_LOADER_ARTIFACTS=1", + "WIRE_LOADER_ARTIFACT_DIR="+artifactDir, + ) + _, warmErrs := run(artifactEnv) + if diff := cmp.Diff(baseErrs, warmErrs); diff != "" { + t.Fatalf("artifact warm-up errors mismatch (-base +warm):\n%s", diff) + } + artifactGens, artifactErrs := run(artifactEnv) + if diff := cmp.Diff(baseErrs, artifactErrs); diff != "" { + t.Fatalf("artifact errors mismatch (-base +artifact):\n%s", diff) + } + if len(baseGens) != len(artifactGens) { + t.Fatalf("generated file count = %d, want %d", len(artifactGens), len(baseGens)) + } + for i := range baseGens { + if baseGens[i].PkgPath != artifactGens[i].PkgPath { + t.Fatalf("generated package[%d] = %q, want %q", i, artifactGens[i].PkgPath, baseGens[i].PkgPath) + } + if diff := cmp.Diff(string(baseGens[i].Content), string(artifactGens[i].Content)); diff != "" { + t.Fatalf("generated content mismatch for %q (-base +artifact):\n%s", baseGens[i].PkgPath, diff) + } + baseGenErrs := comparableGenerateErrors(baseGens[i].Errs) + artifactGenErrs := comparableGenerateErrors(artifactGens[i].Errs) + if diff := cmp.Diff(baseGenErrs, artifactGenErrs); diff != "" { + t.Fatalf("generate errs mismatch for %q (-base +artifact):\n%s", baseGens[i].PkgPath, diff) + } + } +} + +func comparableGenerateErrors(errs []error) []string { + out := make([]string, len(errs)) + for i, err := range errs { + out[i] = err.Error() + } + sort.Strings(out) + return out +} + func TestZeroValue(t *testing.T) { t.Parallel() @@ -453,6 +542,7 @@ func isIdent(s string) bool { // "C:\GOPATH" and running on Windows, the string // "C:\GOPATH\src\foo\bar.go:15:4" would be rewritten to "foo/bar.go:x:y". func scrubError(gopath string, s string) string { + s = normalizeHeaderRelativeError(s) sb := new(strings.Builder) query := gopath + string(os.PathSeparator) + "src" + string(os.PathSeparator) for { @@ -489,7 +579,106 @@ func scrubError(gopath string, s string) string { sb.WriteString(linecol) s = s[linecolLen:] } - return sb.String() + return strings.TrimRight(sb.String(), "\n") +} + +func normalizeHeaderRelativeError(s string) string { + const headerPrefix = "-: # " + if !strings.HasPrefix(s, headerPrefix) { + return s + } + pkgAndRest := strings.TrimPrefix(s, headerPrefix) + newline := strings.IndexByte(pkgAndRest, '\n') + if newline == -1 { + return s + } + pkg := strings.TrimSpace(pkgAndRest[:newline]) + rest := strings.TrimLeft(pkgAndRest[newline+1:], "\n") + if pkg == "" || rest == "" { + return s + } + + firstLineEnd := strings.IndexByte(rest, '\n') + if firstLineEnd == -1 { + firstLineEnd = len(rest) + } + firstLine := rest[:firstLineEnd] + rewritten, ok := canonicalizeRelativeErrorPath(pkg, firstLine) + if !ok { + return s + } + return normalizeLegacyUndefinedQualifiedName(rewritten + rest[firstLineEnd:]) +} + +func canonicalizeRelativeErrorPath(pkg, line string) (string, bool) { + goExt := strings.Index(line, ".go") + if goExt == -1 { + return "", false + } + goExt += len(".go") + linecol, n := scrubLineColumn(line[goExt:]) + if n == 0 { + return "", false + } + file := line[:goExt] + suffix := line[goExt+n:] + file = strings.ReplaceAll(file, "\\", "/") + file = strings.TrimPrefix(file, "./") + file = strings.TrimPrefix(file, "/") + baseDir := path.Base(pkg) + if strings.HasPrefix(file, pkg+"/") { + return file + linecol + suffix, true + } + if strings.HasPrefix(file, baseDir+"/") { + file = pkg + "/" + strings.TrimPrefix(file, baseDir+"/") + return file + linecol + suffix, true + } + if !strings.Contains(file, "/") { + return pkg + "/" + file + linecol + suffix, true + } + return "", false +} + +func normalizeLegacyUndefinedQualifiedName(s string) string { + const marker = ": undefined: " + idx := strings.Index(s, marker) + if idx == -1 { + return s + } + qualified := s[idx+len(marker):] + end := len(qualified) + for i, r := range qualified { + if r == '\n' || r == '\r' || r == '\t' || r == ' ' { + end = i + break + } + } + qualified = qualified[:end] + dot := strings.IndexByte(qualified, '.') + if dot == -1 || dot == 0 || dot == len(qualified)-1 { + return s + } + pkgName := qualified[:dot] + name := qualified[dot+1:] + if name == "" || !isLowerIdent(name) { + return s + } + return s[:idx] + ": name " + name + " not exported by package " + pkgName +} + +func isLowerIdent(s string) bool { + if s == "" { + return false + } + for i, r := range s { + if i == 0 && !unicode.IsLower(r) { + return false + } + if !unicode.IsLetter(r) && !unicode.IsDigit(r) && r != '_' { + return false + } + } + return true } func scrubLineColumn(s string) (replacement string, n int) { @@ -521,6 +710,46 @@ func scrubLineColumn(s string) (replacement string, n int) { return ":x:y", n } +func filterLegacyCompilerErrors(errs []string) []string { + hasCanonicalPath := false + for _, err := range errs { + if strings.HasPrefix(err, "example.com/") { + hasCanonicalPath = true + break + } + } + if !hasCanonicalPath { + return errs + } + + filtered := errs[:0] + for _, err := range errs { + if strings.HasPrefix(err, "-: # ") { + continue + } + filtered = append(filtered, err) + } + return filtered +} + +func TestScrubErrorCanonicalizesHeaderRelativePath(t *testing.T) { + const gopath = "/tmp/wire_test" + got := scrubError(gopath, "-: # example.com/foo\nfoo/wire.go:26:33: not enough arguments in call to wire.InterfaceValue") + want := "example.com/foo/wire.go:x:y: not enough arguments in call to wire.InterfaceValue" + if got != want { + t.Fatalf("scrubError() = %q, want %q", got, want) + } +} + +func TestScrubErrorCanonicalizesHeaderRootRelativePath(t *testing.T) { + const gopath = "/tmp/wire_test" + got := scrubError(gopath, "-: # example.com/foo\n/wire.go:27:17: name foo not exported by package bar") + want := "example.com/foo/wire.go:x:y: name foo not exported by package bar" + if got != want { + t.Fatalf("scrubError() = %q, want %q", got, want) + } +} + type testCase struct { name string pkg string diff --git a/scripts/import-benchmarks.sh b/scripts/import-benchmarks.sh new file mode 100755 index 0000000..232ccd9 --- /dev/null +++ b/scripts/import-benchmarks.sh @@ -0,0 +1,48 @@ +#!/usr/bin/env bash + +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +cd "$ROOT_DIR" + +export GOCACHE="${GOCACHE:-/tmp/gocache}" +export GOMODCACHE="${GOMODCACHE:-/tmp/gomodcache}" + +usage() { + cat <<'EOF' +Usage: + scripts/import-benchmarks.sh table + scripts/import-benchmarks.sh scenarios [profile] + scripts/import-benchmarks.sh breakdown + +Commands: + table Print the 10/100/1000 import stock-vs-current benchmark table. + scenarios Print the stock-vs-current change-type scenario table. + Optional profiles: local, local-high, external-low, external-high. + breakdown Print a focused 1000-import cold/unchanged breakdown. +EOF +} + +case "${1:-}" in + table) + WIRE_IMPORT_BENCH_TABLE=1 go test ./internal/wire -run TestPrintImportScaleBenchmarkTable -count=1 -v + ;; + scenarios) + if [[ -n "${2:-}" ]]; then + WIRE_IMPORT_BENCH_SCENARIOS=1 WIRE_IMPORT_BENCH_PROFILE="${2}" go test ./internal/wire -run TestPrintImportScenarioBenchmarkTable -count=1 -v + else + WIRE_IMPORT_BENCH_SCENARIOS=1 go test ./internal/wire -run TestPrintImportScenarioBenchmarkTable -count=1 -v + fi + ;; + breakdown) + WIRE_IMPORT_BENCH_BREAKDOWN=1 go test ./internal/wire -run TestPrintImportScaleBenchmarkBreakdown -count=1 -v + ;; + ""|-h|--help|help) + usage + ;; + *) + echo "Unknown command: ${1}" >&2 + usage >&2 + exit 1 + ;; +esac diff --git a/scripts/incremental-scenarios.sh b/scripts/incremental-scenarios.sh new file mode 100755 index 0000000..b59e970 --- /dev/null +++ b/scripts/incremental-scenarios.sh @@ -0,0 +1,137 @@ +#!/usr/bin/env bash + +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +cd "$ROOT_DIR" + +export GOCACHE="${GOCACHE:-/tmp/gocache}" +export GOMODCACHE="${GOMODCACHE:-/tmp/gomodcache}" + +usage() { + cat <<'EOF' +Usage: + scripts/incremental-scenarios.sh test + scripts/incremental-scenarios.sh matrix + scripts/incremental-scenarios.sh table + scripts/incremental-scenarios.sh budgets + scripts/incremental-scenarios.sh bench + scripts/incremental-scenarios.sh large-table + scripts/incremental-scenarios.sh large-breakdown + scripts/incremental-scenarios.sh report + scripts/incremental-scenarios.sh all + +Commands: + test Run the full internal/wire test suite. + matrix Run the incremental scenario matrix correctness test. + table Print the incremental scenario timing table. + budgets Enforce the incremental scenario performance budgets. + bench Run the incremental scenario benchmark suite. + large-table Print the large-repo comparison timing table. + large-breakdown Print the large-repo shape-change breakdown table. + report Run the main timing report: scenario table, budgets, and large-repo table. + all Run matrix, table, budgets, and the large-repo table in sequence. +EOF +} + +print_section() { + local title="$1" + printf '\n== %s ==\n' "$title" +} + +print_test_table() { + local output_file="$1" + awk ' + /^\+[-+]+\+$/ { in_table=1 } + in_table && !/^--- PASS:/ && !/^PASS$/ && !/^ok[[:space:]]/ { print } + /^--- PASS:/ && in_table { exit } + ' "$output_file" +} + +run_test_table() { + local env_var="$1" + local test_name="$2" + local output_file + output_file="$(mktemp)" + env "$env_var"=1 go test ./internal/wire -run "$test_name" -count=1 -v >"$output_file" + print_test_table "$output_file" + rm -f "$output_file" +} + +run_test() { + go test ./internal/wire -count=1 +} + +run_matrix() { + go test ./internal/wire -run TestGenerateIncrementalScenarioMatrix -count=1 +} + +run_table() { + run_test_table WIRE_BENCH_SCENARIOS TestPrintIncrementalScenarioBenchmarkTable +} + +run_budgets() { + WIRE_PERF_BUDGETS=1 go test ./internal/wire -run TestIncrementalScenarioPerformanceBudgets -count=1 >/dev/null + echo "PASS" +} + +run_bench() { + go test ./internal/wire -run '^$' -bench BenchmarkGenerateIncrementalScenarioMatrix -benchmem -count=1 +} + +run_large_table() { + run_test_table WIRE_BENCH_TABLE TestPrintLargeRepoBenchmarkComparisonTable +} + +run_large_breakdown() { + run_test_table WIRE_BENCH_BREAKDOWN TestPrintLargeRepoShapeChangeBreakdownTable +} + +run_report() { + print_section "Scenario Timing Table" + run_table + print_section "Scenario Performance Budgets" + run_budgets + print_section "Large Repo Comparison Table" + run_large_table +} + +cmd="${1:-}" +case "$cmd" in + test) + run_test + ;; + matrix) + run_matrix + ;; + table) + run_table + ;; + budgets) + run_budgets + ;; + bench) + run_bench + ;; + large-table) + run_large_table + ;; + large-breakdown) + run_large_breakdown + ;; + report) + run_report + ;; + all) + run_matrix + run_report + ;; + ""|-h|--help|help) + usage + ;; + *) + echo "Unknown command: $cmd" >&2 + usage >&2 + exit 1 + ;; +esac